Linear Recurrence (Constant Coefficient)

Calculates nth element of linearly recurrent sequence {an}, given initial values and coefficients.

Time complexity: O(k³logn), where k is the length of coefficients.

const M: u64 = 998244353;
const M32: u32 = M as u32;
fn recurrence(init: impl AsRef<[u32]>, coef: impl AsRef<[u32]>, n: usize) -> u32 {
    let init = init.as_ref();
    let coef = coef.as_ref();
    let k = init.len();
    assert_eq!(k, coef.len());
    if n < k {
        return init[n];
    }
    let mut acc = vec![0; k * k];
    for i in 0..k {
        acc[i * k + i] = 1;
    }
    let mut mult = vec![0; k * (k - 1)];
    for i in 1..k {
        mult[(i - 1) * k + i] = 1;
    }
    mult.extend_from_slice(coef);
    let mut aux = vec![0; k * k];
    let t = (n - k + 1..=n).min_by_key(|t| t.count_ones()).unwrap();
    for s in 0..64 - t.leading_zeros() {
        if t & (1 << s) != 0 {
            multiply(k, &acc, &mult, &mut aux);
            (acc, aux) = (aux, acc);
        }
        multiply(k, &mult, &mult, &mut aux);
        (mult, aux) = (aux, mult);
    }
    let mut result = 0;
    for (&y, &x) in init.iter().zip(&acc[(n - t) * k..]) {
        result += (y as u64 * x as u64 % M) as u32;
        if result >= M32 {
            result -= M32;
        }
    }
    result
}

fn multiply(n: usize, a: &[u32], b: &[u32], c: &mut [u32]) {
    for i in 0..n {
        for j in 0..n {
            c[i * n + j] = 0;
            for k in 0..n {
                c[i * n + j] += (a[i * n + k] as u64 * b[k * n + j] as u64 % M) as u32;
                if c[i * n + j] >= M32 {
                    c[i * n + j] -= M32;
                }
            }
        }
    }
}