Bitset

struct Bitset(Vec<std::arch::x86_64::__m256i>, usize);

impl Bitset {
    #[target_feature(enable = "avx2")]
    unsafe fn new(n: usize) -> Self {
        use std::arch::x86_64::*;
        let mut v = vec![];
        v.resize_with((n + 255) / 256, || unsafe { _mm256_setzero_si256() });
        Self(v, n)
    }
    fn len(&self) -> usize {
        self.1
    }
    fn get(&self, i: usize) -> bool {
        let b64 =
            unsafe { std::slice::from_raw_parts(self.0.as_ptr() as *const u64, self.0.len() * 4) };
        b64[i / 64] & (1 << (i % 64)) != 0
    }
    fn set(&mut self, i: usize, v: bool) {
        let b64 = unsafe {
            std::slice::from_raw_parts_mut(self.0.as_mut_ptr() as *mut u64, self.0.len() * 4)
        };
        if v {
            b64[i / 64] |= 1 << (i % 64);
        } else {
            b64[i / 64] &= !(1 << (i % 64));
        }
    }
    fn flip(&mut self, i: usize) {
        let b64 = unsafe {
            std::slice::from_raw_parts_mut(self.0.as_mut_ptr() as *mut u64, self.0.len() * 4)
        };
        b64[i / 64] ^= 1 << (i % 64);
    }
    fn shl(&self, x: usize) -> Self {
        let mut shl = unsafe { Self::new(self.len() + x) };
        let shl_b64 = unsafe {
            std::slice::from_raw_parts_mut(shl.0.as_mut_ptr().cast::<u64>(), (shl.len() + 63) / 64)
        };
        let b64 = unsafe {
            std::slice::from_raw_parts(self.0.as_ptr().cast::<u64>(), (self.len() + 63) / 64)
        };
        let x_chunk = x / 64;
        let x_inside = (x % 64) as u32;
        if x_inside == 0 {
            for (dest, &src) in shl_b64.iter_mut().skip(x_chunk).zip(b64) {
                *dest = src.wrapping_shl(x_inside);
            }
        } else {
            let mut low = 0;
            for (dest, &src) in shl_b64.iter_mut().skip(x_chunk).zip(b64) {
                *dest = src.wrapping_shl(x_inside) | low;
                low = src >> (64 - x_inside);
            }
            if shl_b64.len() > b64.len() + x_chunk {
                *shl_b64.last_mut().unwrap() = low;
            }
        }
        shl
    }
    #[target_feature(enable = "avx2")]
    unsafe fn or(&mut self, other: &Self) {
        for (a, &b) in self.0.iter_mut().zip(&other.0) {
            *a = std::arch::x86_64::_mm256_or_si256(*a, b);
        }
    }
    fn split_range(mut begin: usize, end: usize) -> Option<(usize, usize)> {
        let back = end & !255;
        if begin & 255 != 0 {
            begin += 256 - (begin & 255);
        }
        if begin < back {
            Some((begin, back))
        } else {
            None
        }
    }
    #[target_feature(enable = "avx2")]
    unsafe fn find_first_unset(&self, l: usize, r: usize) -> Option<usize> {
        let b64 =
            unsafe { std::slice::from_raw_parts(self.0.as_ptr() as *const u64, self.0.len() * 4) };
        if let Some((le, rb)) = Self::split_range(l, r) {
            for i in l..le {
                if b64[i >> 6] & (1 << (i & 63)) == 0 {
                    return Some(i);
                }
            }
            let mut mask = 0;
            if let Some(mut i) = self.0[le >> 8..rb >> 8].iter().position(|p| {
                use std::arch::x86_64::*;
                let v256 = unsafe { _mm256_load_si256(p) };
                let zero = unsafe { _mm256_set1_epi8(-1) };
                let nonzero = unsafe { _mm256_cmpeq_epi8(zero, v256) };
                mask = unsafe { _mm256_movemask_epi8(nonzero) } as u32;
                mask != !0
            }) {
                i += le >> 8;
                let j = (0..32).position(|s| mask & (1 << s) == 0).unwrap();
                let b8 = unsafe {
                    std::slice::from_raw_parts(self.0.as_ptr() as *const u8, self.0.len() * 32)
                };
                let k = b8[i * 32 + j].trailing_ones() as usize;
                let begin = (i * 32 + j) * 8 + k;
                return Some(begin);
            }
            for i in rb..r {
                if b64[i >> 6] & (1 << (i & 63)) == 0 {
                    return Some(i);
                }
            }
            None
        } else {
            for i in l..r {
                if b64[i >> 6] & (1 << (i & 63)) == 0 {
                    return Some(i);
                }
            }
            None
        }
    }
    #[target_feature(enable = "avx2")]
    unsafe fn find_first_set(&self, l: usize, r: usize) -> Option<usize> {
        let b64 =
            unsafe { std::slice::from_raw_parts(self.0.as_ptr() as *const u64, self.0.len() * 4) };
        if let Some((le, rb)) = Self::split_range(l, r) {
            for i in l..le {
                if b64[i >> 6] & (1 << (i & 63)) != 0 {
                    return Some(i);
                }
            }
            let mut mask = 0;
            if let Some(mut i) = self.0[le >> 8..rb >> 8].iter().position(|p| {
                use std::arch::x86_64::*;
                let v256 = unsafe { _mm256_load_si256(p) };
                let zero = unsafe { _mm256_setzero_si256() };
                let nonzero = unsafe { _mm256_cmpeq_epi8(zero, v256) };
                mask = unsafe { _mm256_movemask_epi8(nonzero) } as u32;
                mask != !0
            }) {
                i += le >> 8;
                let j = (0..32).position(|s| mask & (1 << s) == 0).unwrap();
                let b8 = unsafe {
                    std::slice::from_raw_parts(self.0.as_ptr() as *const u8, self.0.len() * 32)
                };
                let k = b8[i * 32 + j].trailing_zeros() as usize;
                let begin = (i * 32 + j) * 8 + k;
                return Some(begin);
            }
            for i in rb..r {
                if b64[i >> 6] & (1 << (i & 63)) != 0 {
                    return Some(i);
                }
            }
            None
        } else {
            for i in l..r {
                if b64[i >> 6] & (1 << (i & 63)) != 0 {
                    return Some(i);
                }
            }
            None
        }
    }
    #[target_feature(enable = "avx2")]
    unsafe fn flip_range(&mut self, l: usize, r: usize) {
        let b64 = std::slice::from_raw_parts_mut(self.0.as_mut_ptr() as *mut u64, self.0.len() * 4);
        if let Some((le, rb)) = Self::split_range(l, r) {
            use std::arch::x86_64::*;
            for i in l..le {
                b64[i >> 6] ^= 1 << (i & 63);
            }
            let one = _mm256_set1_epi8(-1);
            for v in &mut self.0[le >> 8..rb >> 8] {
                let load = _mm256_load_si256(v);
                let xor = _mm256_xor_si256(load, one);
                _mm256_store_si256(v, xor);
            }
            for i in rb..r {
                b64[i >> 6] ^= 1 << (i & 63);
            }
        } else {
            for i in l..r {
                b64[i >> 6] ^= 1 << (i & 63);
            }
        }
    }
    #[target_feature(enable = "avx2")]
    unsafe fn count_ones(&mut self, l: usize, r: usize) -> u32 {
        let b64 = std::slice::from_raw_parts_mut(self.0.as_mut_ptr() as *mut u64, self.0.len() * 4);
        let mut count = 0;
        if let Some((le, rb)) = Self::split_range(l, r) {
            use std::arch::x86_64::*;
            for i in l..le {
                if b64[i >> 6] & 1 << (i & 63) != 0 {
                    count += 1;
                }
            }
            let b0 = _mm256_set1_epi32(0x55555555);
            let b1 = _mm256_set1_epi32(0x33333333);
            let b2 = _mm256_set1_epi32(0x0f0f0f0f);
            let b3 = _mm256_set1_epi32(0x01010101);
            let mut vcount = _mm256_setzero_si256();
            for v in &mut self.0[le >> 8..rb >> 8] {
                let load = _mm256_load_si256(v);
                let c1 = _mm256_sub_epi32(load, _mm256_and_si256(b0, _mm256_srli_epi32::<1>(load)));
                let c2 = _mm256_add_epi32(
                    _mm256_and_si256(c1, b1),
                    _mm256_and_si256(_mm256_srli_epi32::<2>(c1), b1),
                );
                let c3 = _mm256_srli_epi32::<24>(_mm256_mullo_epi32(
                    _mm256_and_si256(_mm256_add_epi32(c2, _mm256_srli_epi32::<4>(c2)), b2),
                    b3,
                ));
                vcount = _mm256_add_epi32(vcount, c3);
            }
            let [a, b, c, d, e, f, g, h]: [u32; 8] = std::mem::transmute(vcount);
            count += a + b + c + d + e + f + g + h;
            for i in rb..r {
                if b64[i >> 6] & 1 << (i & 63) != 0 {
                    count += 1;
                }
            }
        } else {
            for i in l..r {
                if b64[i >> 6] & 1 << (i & 63) != 0 {
                    count += 1;
                }
            }
        }
        count
    }
}

impl std::fmt::Debug for Bitset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let b64 =
            unsafe { std::slice::from_raw_parts(self.0.as_ptr().cast::<u64>(), self.0.len() * 4) };
        let mut iter = b64.iter().rev().skip_while(|&&b| b == 0);
        if let Some(&first) = iter.next() {
            write!(f, "{first:b}")?;
            for &b in iter {
                write!(f, "{b:064b}")?;
            }
            Ok(())
        } else {
            write!(f, "0")
        }
    }
}