use libc::{cpu_set_t, sched_getaffinity, sched_setaffinity, CPU_ISSET, CPU_SET, CPU_ZERO};
use log::{debug, error, info};
use std::fs;
use std::mem;
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::Duration;

const MAX_CPUS: usize = 256;
const SLEEP_TIME_MS: u64 = 5_000;

/// Pins this process's threads to CPU cores based on their thread names.
///
/// Rules map a thread-name pattern to a list of core ids. Once `start` is
/// called, a background worker periodically re-applies the rules to every
/// thread of the process.
#[derive(Debug)]
pub struct AffinityManager {
    /// `(name pattern, core ids)` pairs, tried in insertion order.
    rules: Arc<RwLock<Vec<(String, Vec<usize>)>>>,
    /// Cores in the affinity mask observed when the manager was created.
    available_cores: Vec<usize>,
    /// True while the background worker should keep running; it is polled
    /// between sweeps.
    running: Arc<std::sync::atomic::AtomicBool>,
}

impl AffinityManager {
    pub fn new() -> Self {
        let available_cores = Self::get_process_affinity();
        info!("Available cores: {:?}", available_cores);

        Self {
            rules: Arc::new(RwLock::new(Vec::new())),
            available_cores,
            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Get the cores that the C parent process allowed us to use
    fn get_process_affinity() -> Vec<usize> {
        unsafe {
            let mut cpuset: cpu_set_t = mem::zeroed();

            if sched_getaffinity(0, mem::size_of::<cpu_set_t>(), &mut cpuset) == 0 {
                (0..MAX_CPUS) // CPU_SETSIZE is usually 1024, but 128 is reasonable
                    .filter(|&i| CPU_ISSET(i, &cpuset))
                    .collect()
            } else {
                // Fallback: assume all cores available
                (0..num_cpus::get()).collect()
            }
        }
    }

    /// Add rule for a given pattern and list of cores
    pub fn add_rule(&self, name_pattern: &str, core_ids: &Vec<usize>) {
        if !core_ids.is_empty() {
            info!(
                "Pinning threads matching '{}' to cores {:?}",
                name_pattern, core_ids
            );
            let mut rules = self.rules.write().unwrap();
            if rules.iter().any(|(pattern, _)| pattern == name_pattern) {
                panic!("Failed to add duplicate affinity rule: '{}'", name_pattern);
            }
            rules.push((name_pattern.to_string(), core_ids.clone()));
        }
    }

    pub fn available_cores(&self) -> &[usize] {
        &self.available_cores
    }

    pub fn start(&self) {
        self.running
            .store(true, std::sync::atomic::Ordering::Relaxed);
        let rules = self.rules.clone();
        let running = self.running.clone();

        thread::Builder::new()
            .name("affinity-manager".to_string())
            .spawn(move || {
                while running.load(std::sync::atomic::Ordering::Relaxed) {
                    apply_affinity_rules(&rules);
                    thread::sleep(Duration::from_millis(SLEEP_TIME_MS));
                }
            })
            .unwrap();
    }

    pub fn stop(&self) {
        self.running
            .store(false, std::sync::atomic::Ordering::Relaxed);
    }
}

/// Why the thread id for a `/proc/self/task` entry could not be resolved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AffinityError {
    /// The thread's status file could not be read. This presumably means the
    /// thread exited between listing `/proc/self/task` and reading its status.
    /// (Variant name kept as-is for compatibility; read it as "no longer active".)
    ThreadNoLongActive,
    /// The status file was read, but its `NSpid` line was missing or could not
    /// be parsed; the payload describes the problem.
    MalformedStatus(String),
}

impl std::fmt::Display for AffinityError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AffinityError::ThreadNoLongActive => {
                write!(f, "AffinityError: Thread no longer active")
            }
            AffinityError::MalformedStatus(msg) => {
                write!(f, "AffinityError: Malformed status file: {}", msg)
            }
        }
    }
}

/// Lets `AffinityError` be used with `?`, `Box<dyn Error>` and error-handling crates.
impl std::error::Error for AffinityError {}

/// Translate a `/proc/self/task` entry name into the thread id passed to
/// `sched_getaffinity` / `sched_setaffinity`.
///
/// Reads `/proc/self/task/<ns_tid>/status` and returns the last id on its
/// `NSpid` line. Per proc(5), those ids run from the PID namespace of the
/// procfs mount inward, so the last one is the id in the innermost namespace.
///
/// # Errors
///
/// * `AffinityError::ThreadNoLongActive` if the status file cannot be read.
/// * `AffinityError::MalformedStatus` if there is no `NSpid` line, the line
///   carries no id, or the id is not a valid `i32`.
fn get_kernel_tid(ns_tid: &str) -> Result<i32, AffinityError> {
    let status_path = format!("/proc/self/task/{}/status", ns_tid);
    let Ok(status) = fs::read_to_string(&status_path) else {
        return Err(AffinityError::ThreadNoLongActive);
    };

    let Some(line) = status.lines().find(|line| line.starts_with("NSpid")) else {
        let msg = format!(
            "Failed to find NSpid in status file for thread {}: {}",
            ns_tid, status_path
        );
        return Err(AffinityError::MalformedStatus(msg));
    };

    // Field 0 is the "NSpid:" label; the ids follow. Known layouts:
    //   sandbox=false: `NSpid: <tid>`
    //   sandbox=true:  `NSpid: <ns_tid> <kernel_parent_tid> <kernel_tid>`
    // Both use the last id, so take the last id for any nesting depth.
    let Some(tid_field) = line.split_whitespace().skip(1).last() else {
        let msg = format!(
            "Unexpected NSpid line format for thread {}: '{}'",
            ns_tid, line
        );
        return Err(AffinityError::MalformedStatus(msg));
    };

    // Report a bad id as an error instead of panicking the worker thread.
    tid_field.parse::<i32>().map_err(|err| {
        AffinityError::MalformedStatus(format!(
            "Invalid tid '{}' in NSpid line for thread {}: {}",
            tid_field, ns_tid, err
        ))
    })
}

fn apply_affinity_rules(rules: &RwLock<Vec<(String, Vec<usize>)>>) {
    let rules = rules.read().unwrap();
    if rules.is_empty() {
        return;
    }

    let Ok(entries) = fs::read_dir("/proc/self/task") else {
        return;
    };

    for entry in entries.flatten() {
        let tid = entry.file_name().to_string_lossy().to_string();
        let ntid = match get_kernel_tid(&tid) {
            Ok(ntid) => ntid,
            Err(AffinityError::ThreadNoLongActive) => {
                debug!(
                    "Thread with ns_tid {} is no longer active; skipping affinity set",
                    tid
                );
                continue;
            }
            Err(e) => {
                error!("Failed to get kernel tid for ns_tid {}: {}", tid, e);
                continue;
            }
        };

        let comm_path = format!("/proc/self/task/{}/comm", tid);
        let Ok(name) = fs::read_to_string(&comm_path) else {
            continue;
        };
        let name = name.trim();
        let aff = get_thread_affinity(ntid);

        for (pattern, core_ids) in rules.iter() {
            if matches_pattern(name, pattern) {
                if core_ids == &aff {
                    break;
                }
                match set_thread_affinity(ntid, core_ids) {
                    Ok(()) => {
                        debug!("Successfully set affinity for thread {} ('{}') from {:?} to {:?}, ns_tid: {}", ntid, name, aff, core_ids, tid);
                    }
                    Err(e) => {
                        error!("Error setting affinity for thread {} ('{}') from {:?} to {:?}: {}, ns_tid: {}", ntid, name, aff, core_ids, e, tid);
                    }
                }
                break;
            }
        }
    }
}

/// Return whether the thread `name` matches `pattern`.
///
/// A pattern without `*` must equal the name exactly. Otherwise each `*`
/// matches any (possibly empty) run of characters. Every other character
/// matches itself literally, with no regex syntax, so names containing `(`,
/// `+`, `?`, `[` and similar characters can be matched as written.
///
/// Matching is done with plain string operations, so no regex is compiled for
/// every thread on every sweep.
fn matches_pattern(name: &str, pattern: &str) -> bool {
    if !pattern.contains('*') {
        return name == pattern;
    }

    // `pattern` contains at least one '*', so it splits into at least two
    // segments: a literal prefix, zero or more middle literals, and a literal
    // suffix.
    let mut segments = pattern.split('*');
    let head = segments.next().unwrap_or("");
    let tail = segments.next_back().unwrap_or("");

    let Some(mut rest) = name.strip_prefix(head) else {
        return false;
    };
    // Take each middle literal at its leftmost occurrence; this leaves the
    // most room for the remaining literals and the suffix.
    for segment in segments {
        match rest.find(segment) {
            Some(pos) => rest = &rest[pos + segment.len()..],
            None => return false,
        }
    }
    // The suffix must fit in what is left, without overlapping earlier matches.
    rest.ends_with(tail)
}

/// Return the ids of the CPUs thread `tid` may run on, in ascending order.
///
/// Only ids below `MAX_CPUS` are reported. An empty vector means the mask could
/// not be read, for example because the thread has already exited.
pub fn get_thread_affinity(tid: i32) -> Vec<usize> {
    // SAFETY: cpu_set_t is a plain bitmask struct; all-zero bytes form a valid
    // (empty) set.
    let mut mask: cpu_set_t = unsafe { mem::zeroed() };
    // SAFETY: the pointer and size describe `mask`, which outlives the call.
    let status = unsafe { sched_getaffinity(tid, mem::size_of::<cpu_set_t>(), &mut mask) };
    if status != 0 {
        return Vec::new();
    }

    let mut cpus = Vec::new();
    for cpu in 0..MAX_CPUS {
        // SAFETY: `mask` is initialized, and `cpu` stays below MAX_CPUS.
        if unsafe { CPU_ISSET(cpu, &mask) } {
            cpus.push(cpu);
        }
    }
    cpus
}

/// Pin thread `tid` to exactly the cores listed in `core_ids`.
///
/// # Errors
///
/// * `ErrorKind::InvalidInput` if a core id does not fit in a `cpu_set_t`. The
///   id is rejected here instead of letting `CPU_SET` index past the end of
///   the mask.
/// * The OS error from `sched_setaffinity` otherwise, for example when none of
///   the requested cores is usable or the thread has exited.
fn set_thread_affinity(tid: i32, core_ids: &[usize]) -> Result<(), std::io::Error> {
    // Number of CPU ids a cpu_set_t can hold (one bit per CPU).
    let capacity = 8 * mem::size_of::<cpu_set_t>();
    if let Some(&bad) = core_ids.iter().find(|&&id| id >= capacity) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "core id {} exceeds cpu_set_t capacity of {} CPUs",
                bad, capacity
            ),
        ));
    }

    // SAFETY: cpu_set_t is a plain bitmask struct; all-zero bytes form a valid
    // (empty) set.
    let mut cpuset: cpu_set_t = unsafe { mem::zeroed() };
    // SAFETY: `cpuset` is initialized, and every id was checked against its
    // capacity above.
    unsafe {
        CPU_ZERO(&mut cpuset);
        for &core_id in core_ids {
            CPU_SET(core_id, &mut cpuset);
        }
    }

    // SAFETY: the pointer and size describe `cpuset`, which outlives the call.
    let ret = unsafe { sched_setaffinity(tid, mem::size_of::<cpu_set_t>(), &cpuset) };
    if ret != 0 {
        return Err(std::io::Error::last_os_error());
    }
    Ok(())
}
