// Number Theoretic Transform (NTT)

/// Parameter set for running a number theoretic transform modulo one fixed modulus.
struct NttParam {
    /// The modulus, wrapped in `ModU64` so that reductions go through `ModU64::rem`.
    m: ModU64,
    /// Root used by `run` for the forward transform (`inv == false`). `run` squares it
    /// `s - k` times to obtain the twiddle step of stage `k`.
    /// Presumably a primitive `2^s`-th root of unity modulo `m` — TODO confirm.
    u: u64,
    /// Root used by `run` for the inverse transform (`inv == true`).
    /// Presumably the modular inverse of `u` — TODO confirm.
    ui: u64,
    /// Upper bound of the squaring loop in `run` (`for _ in k..s`).
    /// Presumably log2 of the largest transform length this parameter set supports.
    s: u32,
}

/// Table of the available NTT parameter sets, one entry per modulus.
///
/// Every modulus below has the form `c * 2^s + 1`, where `s` is the value stored in the
/// entry's `s` field. The `u` / `ui` constants are presumably a primitive `2^s`-th root of
/// unity for that modulus and its inverse — NOTE(review): verify before changing any entry.
static NTT: &[NttParam] = &[NttParam {
    // 998244353 = 119 * 2^23 + 1
    m: ModU64::new(998244353),
    u: 15311432,
    ui: 469870224,
    s: 23,
}, NttParam {
    // 1107296257 = 33 * 2^25 + 1
    m: ModU64::new(1107296257),
    u: 1087287097,
    ui: 623044540,
    s: 25,
}, NttParam {
    // 1711276033 = 51 * 2^25 + 1
    m: ModU64::new(1711276033),
    u: 969788637,
    ui: 1790856,
    s: 25,
}];

impl NttParam {
    /// Transforms `a` in place with an iterative radix-2 number theoretic transform.
    ///
    /// * `a` - coefficients, each expected to be already reduced modulo `self.m`. Its
    ///   length is expected to be a power of two not exceeding `2^self.s`; other lengths
    ///   do not panic but do not produce a meaningful transform.
    /// * `inv` - `false` runs the forward transform with root `self.u`; `true` runs the
    ///   inverse transform with root `self.ui` and finally multiplies every element by the
    ///   modular inverse of the length.
    ///
    /// Empty and single-element slices are handled without panicking: an empty slice is
    /// left untouched, and a single element passes through unchanged (it is only reduced
    /// modulo `self.m` when `inv` is set).
    fn run(&self, a: &mut [u32], inv: bool) {
        let n = a.len();
        if n == 0 {
            // Nothing to transform; also avoids `1 << usize::BITS` in the stage loop,
            // because `0usize.trailing_zeros()` equals `usize::BITS`.
            return;
        }
        let modulus = self.m.1;

        // Bit-reversal permutation. For n == 2^b the shift is `usize::BITS - b`, which
        // equals `usize::BITS` when n == 1; `checked_shr` turns that full-width shift into
        // 0 instead of an overflow panic.
        let shift = n.leading_zeros() + 1;
        for i in 0..n {
            let r = i.reverse_bits().checked_shr(shift).unwrap_or(0);
            if i < r {
                a.swap(i, r);
            }
        }

        let root = if inv { self.ui } else { self.u };
        for k in 1..=n.trailing_zeros() {
            // Twiddle step for this stage: root^(2^(s - k)).
            let mut wlen = root;
            for _ in k..self.s {
                wlen = self.m.rem(wlen * wlen);
            }

            // 2^k divides n for every k in this loop, so the blocks tile `a` exactly.
            let half = 1usize << (k - 1);
            for block in a.chunks_exact_mut(half << 1) {
                let (lo, hi) = block.split_at_mut(half);
                let mut w = 1u64;
                for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                    let u = u64::from(*x);
                    let v = self.m.rem(u64::from(*y) * w);

                    // (u + v) mod m with a single conditional subtraction.
                    let mut sum = u + v;
                    if sum >= modulus {
                        sum -= modulus;
                    }
                    // (u - v) mod m, kept non-negative by adding the modulus first.
                    let mut diff = u + modulus - v;
                    if diff >= modulus {
                        diff -= modulus;
                    }

                    *x = sum as u32;
                    *y = diff as u32;
                    w = self.m.rem(w * wlen);
                }
            }
        }

        if inv {
            // Scale by n^-1 mod m; the Bezout coefficient from `egcd` may be negative,
            // so normalise it into 0..p first.
            let p = modulus as i64;
            let n_inv = egcd(n as i64, p).1.rem_euclid(p) as u64;
            for x in a.iter_mut() {
                *x = self.m.rem(u64::from(*x) * n_inv) as u32;
            }
        }
    }
}

/// Extended Euclidean algorithm.
///
/// Returns `(g, x, y)` where `g` is the last non-zero remainder of the Euclidean
/// algorithm on `(a, b)` and `a * x + b * y == g` for the original arguments.
fn egcd(mut a: i64, mut b: i64) -> (i64, i64, i64) {
    // Bezout coefficients (for the original a, b) of the current `a` and current `b`.
    let mut coef_a: (i64, i64) = (1, 0);
    let mut coef_b: (i64, i64) = (0, 1);

    while b != 0 {
        let quot = a / b;

        // Coefficients of the next remainder `a - quot * b`.
        let coef_next = (coef_a.0 - quot * coef_b.0, coef_a.1 - quot * coef_b.1);
        coef_a = coef_b;
        coef_b = coef_next;

        let remainder = a - quot * b;
        a = b;
        b = remainder;
    }

    (a, coef_a.0, coef_a.1)
}

/// A `u64` divisor bundled with a precomputed 128-bit multiplier so that remainders can
/// be computed with multiplications instead of a hardware division.
///
/// * `.0` - the multiplier built by `new`: `(u128::MAX / divisor) + 1`, with wrapping add.
/// * `.1` - the divisor itself.
#[derive(Copy, Clone)]
struct ModU64(u128, u64);

impl ModU64 {
    /// Precomputes the 128-bit multiplier for the divisor `div`.
    ///
    /// The multiplier is `u128::MAX / div + 1` using a wrapping add, so `div == 1` yields
    /// a multiplier of 0. Division by zero (`div == 0`) panics.
    const fn new(div: u64) -> Self {
        let magic = (u128::MAX / div as u128).wrapping_add(1);
        Self(magic, div)
    }

    /// Returns the top 64 bits of the 192-bit product `a * b`, i.e. `(a * b) >> 128`.
    fn multop(a: u128, b: u64) -> u64 {
        let factor = u128::from(b);
        let a_low = a & u128::from(u64::MAX);
        let a_high = a >> 64;

        // Only the carry out of the low partial product can reach the result.
        let low_carry = (a_low * factor) >> 64;
        let high_product = a_high * factor;

        ((high_product + low_carry) >> 64) as u64
    }

    /// Returns `a` reduced modulo the stored divisor, without a division instruction.
    fn rem(&self, a: u64) -> u64 {
        // Low 128 bits of multiplier * a; scaling them by the divisor and keeping the
        // top word yields the remainder.
        let fraction = self.0.wrapping_mul(u128::from(a));
        Self::multop(fraction, self.1)
    }
}