//! Simple Bloom Filter

use {
    bv::BitVec,
    fnv::FnvHasher,
    rand::{self, Rng},
    serde::{Deserialize, Serialize},
    solana_sanitize::{Sanitize, SanitizeError},
    solana_time_utils::AtomicInterval,
    std::{
        cmp, fmt,
        hash::Hasher,
        marker::PhantomData,
        ops::Deref,
        sync::atomic::{AtomicU64, Ordering},
    },
};

/// Generate a stable hash of `self` for each `hash_index`
/// Best effort can be made for uniqueness of each hash.
pub trait BloomHashIndex {
    fn hash_at_index(&self, hash_index: u64) -> u64;
}

#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
#[derive(Serialize, Deserialize, Default, Clone, PartialEq, Eq)]
pub struct Bloom<T: BloomHashIndex> {
    pub keys: Vec<u64>,
    pub bits: BitVec<u64>,
    num_bits_set: u64,
    _phantom: PhantomData<T>,
}

impl<T: BloomHashIndex> fmt::Debug for Bloom<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Bloom {{ keys.len: {} bits.len: {} num_set: {} bits: ",
            self.keys.len(),
            self.bits.len(),
            self.num_bits_set
        )?;
        const MAX_PRINT_BITS: u64 = 10;
        for i in 0..std::cmp::min(MAX_PRINT_BITS, self.bits.len()) {
            if self.bits.get(i) {
                write!(f, "1")?;
            } else {
                write!(f, "0")?;
            }
        }
        if self.bits.len() > MAX_PRINT_BITS {
            write!(f, "..")?;
        }
        write!(f, " }}")
    }
}

impl<T: BloomHashIndex> Sanitize for Bloom<T> {
    fn sanitize(&self) -> Result<(), SanitizeError> {
        // Avoid division by zero in self.pos(...).
        if self.bits.is_empty() {
            Err(SanitizeError::InvalidValue)
        } else {
            Ok(())
        }
    }
}

impl<T: BloomHashIndex> Bloom<T> {
    pub fn new(num_bits: usize, keys: Vec<u64>) -> Self {
        let bits = BitVec::new_fill(false, num_bits as u64);
        Bloom {
            keys,
            bits,
            num_bits_set: 0,
            _phantom: PhantomData,
        }
    }
    /// Create filter optimal for num size given the `FALSE_RATE`.
    ///
    /// The keys are randomized for picking data out of a collision resistant hash of size
    /// `keysize` bytes.
    ///
    /// See <https://hur.st/bloomfilter/>.
    pub fn random(num_items: usize, false_rate: f64, max_bits: usize) -> Self {
        let m = Self::num_bits(num_items as f64, false_rate);
        let num_bits = cmp::max(1, cmp::min(m as usize, max_bits));
        let num_keys = Self::num_keys(num_bits as f64, num_items as f64) as usize;
        let keys: Vec<u64> = (0..num_keys).map(|_| rand::thread_rng().gen()).collect();
        Self::new(num_bits, keys)
    }
    fn num_bits(num_items: f64, false_rate: f64) -> f64 {
        let n = num_items;
        let p = false_rate;
        ((n * p.ln()) / (1f64 / 2f64.powf(2f64.ln())).ln()).ceil()
    }
    fn num_keys(num_bits: f64, num_items: f64) -> f64 {
        let n = num_items;
        let m = num_bits;
        // infinity as usize is zero in rust 1.43 but 2^64-1 in rust 1.45; ensure it's zero here
        if n == 0.0 {
            0.0
        } else {
            1f64.max(((m / n) * 2f64.ln()).round())
        }
    }
    fn pos(&self, key: &T, k: u64) -> u64 {
        key.hash_at_index(k)
            .checked_rem(self.bits.len())
            .unwrap_or(0)
    }
    pub fn clear(&mut self) {
        self.bits = BitVec::new_fill(false, self.bits.len());
        self.num_bits_set = 0;
    }
    pub fn add(&mut self, key: &T) {
        for k in &self.keys {
            let pos = self.pos(key, *k);
            if !self.bits.get(pos) {
                self.num_bits_set = self.num_bits_set.saturating_add(1);
                self.bits.set(pos, true);
            }
        }
    }
    pub fn contains(&self, key: &T) -> bool {
        for k in &self.keys {
            let pos = self.pos(key, *k);
            if !self.bits.get(pos) {
                return false;
            }
        }
        true
    }
}

fn slice_hash(slice: &[u8], hash_index: u64) -> u64 {
    let mut hasher = FnvHasher::with_key(hash_index);
    hasher.write(slice);
    hasher.finish()
}

impl<T: AsRef<[u8]>> BloomHashIndex for T {
    fn hash_at_index(&self, hash_index: u64) -> u64 {
        slice_hash(self.as_ref(), hash_index)
    }
}

/// Bloom filter that can be used concurrently.
/// Concurrent reads/writes are safe, but are not atomic at the struct level,
/// this means that reads may see partial writes.
pub struct ConcurrentBloom<T> {
    num_bits: u64,
    keys: Vec<u64>,
    bits: Vec<AtomicU64>,
    _phantom: PhantomData<T>,
}

impl<T: BloomHashIndex> From<Bloom<T>> for ConcurrentBloom<T> {
    fn from(bloom: Bloom<T>) -> Self {
        ConcurrentBloom {
            num_bits: bloom.bits.len(),
            keys: bloom.keys,
            bits: bloom
                .bits
                .into_boxed_slice()
                .iter()
                .map(|&x| AtomicU64::new(x))
                .collect(),
            _phantom: PhantomData,
        }
    }
}

impl<T: BloomHashIndex> ConcurrentBloom<T> {
    fn pos(&self, key: &T, hash_index: u64) -> (usize, u64) {
        let pos = key
            .hash_at_index(hash_index)
            .checked_rem(self.num_bits)
            .unwrap_or(0);
        // Divide by 64 to figure out which of the
        // AtomicU64 bit chunks we need to modify.
        let index = pos.wrapping_shr(6);
        // (pos & 63) is equivalent to mod 64 so that we can find
        // the index of the bit within the AtomicU64 to modify.
        let mask = 1u64.wrapping_shl(u32::try_from(pos & 63).unwrap());
        (index as usize, mask)
    }

    /// Adds an item to the bloom filter and returns true if the item
    /// was not in the filter before.
    pub fn add(&self, key: &T) -> bool {
        let mut added = false;
        for k in &self.keys {
            let (index, mask) = self.pos(key, *k);
            let prev_val = self.bits[index].fetch_or(mask, Ordering::Relaxed);
            added = added || prev_val & mask == 0u64;
        }
        added
    }

    pub fn contains(&self, key: &T) -> bool {
        self.keys.iter().all(|k| {
            let (index, mask) = self.pos(key, *k);
            let bit = self.bits[index].load(Ordering::Relaxed) & mask;
            bit != 0u64
        })
    }

    pub fn clear(&self) {
        self.bits.iter().for_each(|bit| {
            bit.store(0u64, Ordering::Relaxed);
        });
    }
}

impl<T: BloomHashIndex> From<ConcurrentBloom<T>> for Bloom<T> {
    fn from(atomic_bloom: ConcurrentBloom<T>) -> Self {
        let bits: Vec<_> = atomic_bloom
            .bits
            .into_iter()
            .map(AtomicU64::into_inner)
            .collect();
        let num_bits_set = bits.iter().map(|x| x.count_ones() as u64).sum();
        let mut bits: BitVec<u64> = bits.into();
        bits.truncate(atomic_bloom.num_bits);
        Bloom {
            keys: atomic_bloom.keys,
            bits,
            num_bits_set,
            _phantom: PhantomData,
        }
    }
}

/// Wrapper around `ConcurrentBloom` and `AtomicInterval` so the bloom filter
/// can be cleared periodically.
pub struct ConcurrentBloomInterval<T: BloomHashIndex> {
    interval: AtomicInterval,
    bloom: ConcurrentBloom<T>,
}

// Directly allow all methods of `AtomicBloom` to be called on `AtomicBloomInterval`.
impl<T: BloomHashIndex> Deref for ConcurrentBloomInterval<T> {
    type Target = ConcurrentBloom<T>;
    fn deref(&self) -> &Self::Target {
        &self.bloom
    }
}

impl<T: BloomHashIndex> ConcurrentBloomInterval<T> {
    /// Create a new filter with the given parameters.
    /// See `Bloom::random` for details.
    pub fn new(num_items: usize, false_positive_rate: f64, max_bits: usize) -> Self {
        let bloom = Bloom::random(num_items, false_positive_rate, max_bits);
        Self {
            interval: AtomicInterval::default(),
            bloom: ConcurrentBloom::from(bloom),
        }
    }

    /// Reset the filter if the reset interval has elapsed.
    pub fn maybe_reset(&self, reset_interval_ms: u64) {
        if self.interval.should_update(reset_interval_ms) {
            self.bloom.clear();
        }
    }
}

#[cfg(test)]
mod test {
    use {super::*, rayon::prelude::*, solana_hash::Hash, solana_sha256_hasher::hash};

    #[test]
    fn test_bloom_filter() {
        //empty
        let bloom: Bloom<Hash> = Bloom::random(0, 0.1, 100);
        assert_eq!(bloom.keys.len(), 0);
        assert_eq!(bloom.bits.len(), 1);

        //normal
        let bloom: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 48);

        //saturated
        let bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        assert_eq!(bloom.keys.len(), 1);
        assert_eq!(bloom.bits.len(), 100);
    }
    #[test]
    fn test_add_contains() {
        let mut bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        //known keys to avoid false positives in the test
        bloom.keys = vec![0, 1, 2, 3];

        let key = hash(b"hello");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));

        let key = hash(b"world");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));
    }
    #[test]
    fn test_random() {
        let mut b1: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        let mut b2: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        b1.keys.sort_unstable();
        b2.keys.sort_unstable();
        assert_ne!(b1.keys, b2.keys);
    }
    // Bloom filter math in python
    // n number of items
    // p false rate
    // m number of bits
    // k number of keys
    //
    // n = ceil(m / (-k / log(1 - exp(log(p) / k))))
    // p = pow(1 - exp(-k / (m / n)), k)
    // m = ceil((n * log(p)) / log(1 / pow(2, log(2))));
    // k = round((m / n) * log(2));
    #[test]
    fn test_filter_math() {
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.1f64) as u64, 480u64);
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.01f64) as u64, 959u64);
        assert_eq!(Bloom::<Hash>::num_keys(1000f64, 50f64) as u64, 14u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 50f64) as u64, 28u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 25f64) as u64, 55u64);
        //ensure min keys is 1
        assert_eq!(Bloom::<Hash>::num_keys(20f64, 1000f64) as u64, 1u64);
    }

    #[test]
    fn test_debug() {
        let mut b: Bloom<Hash> = Bloom::new(3, vec![100]);
        b.add(&Hash::default());
        assert_eq!(
            format!("{b:?}"),
            "Bloom { keys.len: 1 bits.len: 3 num_set: 1 bits: 001 }"
        );

        let mut b: Bloom<Hash> = Bloom::new(1000, vec![100]);
        b.add(&Hash::default());
        b.add(&hash(&[1, 2]));
        assert_eq!(
            format!("{b:?}"),
            "Bloom { keys.len: 1 bits.len: 1000 num_set: 2 bits: 0000000000.. }"
        );
    }

    fn generate_random_hash() -> Hash {
        let mut rng = rand::thread_rng();
        let mut hash = [0u8; solana_hash::HASH_BYTES];
        rng.fill(&mut hash);
        Hash::new_from_array(hash)
    }

    #[test]
    fn test_atomic_bloom() {
        let hash_values: Vec<_> = std::iter::repeat_with(generate_random_hash)
            .take(1200)
            .collect();
        let bloom: ConcurrentBloom<_> = Bloom::<Hash>::random(1287, 0.1, 7424).into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.num_bits, 6168);
        assert_eq!(bloom.bits.len(), 97);
        hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        let bloom: Bloom<Hash> = bloom.into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 6168);
        assert!(bloom.num_bits_set > 2000);
        for hash_value in hash_values {
            assert!(bloom.contains(&hash_value));
        }
        let false_positive = std::iter::repeat_with(generate_random_hash)
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2_000, "false_positive: {false_positive}");
    }

    #[test]
    fn test_atomic_bloom_round_trip() {
        let mut rng = rand::thread_rng();
        let keys: Vec<_> = std::iter::repeat_with(|| rng.gen()).take(5).collect();
        let mut bloom = Bloom::<Hash>::new(9731, keys.clone());
        let hash_values: Vec<_> = std::iter::repeat_with(generate_random_hash)
            .take(1000)
            .collect();
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        let num_bits_set = bloom.num_bits_set;
        assert!(num_bits_set > 2000, "# bits set: {num_bits_set}");
        // Round-trip with no inserts.
        let bloom: ConcurrentBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), 9731_usize.div_ceil(64));
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Round trip, re-inserting the same hash values.
        let bloom: ConcurrentBloom<_> = bloom.into();
        hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        assert_eq!(bloom.bits.len(), 9731);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Round trip, inserting new hash values.
        let more_hash_values: Vec<_> = std::iter::repeat_with(generate_random_hash)
            .take(1000)
            .collect();
        let bloom: ConcurrentBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), 9731_usize.div_ceil(64));
        more_hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(generate_random_hash)
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {false_positive}");
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.bits.len(), 9731);
        assert!(bloom.num_bits_set > num_bits_set);
        assert!(
            bloom.num_bits_set > 4000,
            "# bits set: {}",
            bloom.num_bits_set
        );
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(generate_random_hash)
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {false_positive}");
        // Assert that the bits vector precisely match if no atomic ops were
        // used.
        let bits = bloom.bits;
        let mut bloom = Bloom::<Hash>::new(9731, keys);
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        for hash_value in &more_hash_values {
            bloom.add(hash_value);
        }
        assert_eq!(bits, bloom.bits);
    }
}
