Number Theoretic Transform: Fast Polynomial Multiplication over Finite Fields
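In the finite field of integers modulo a prime $p$, a primitive $n$-th root of unity (an element of multiplicative order exactly $n$) exists if and only if $n$ divides $p-1$. Let $g$ be a primitive root modulo $p$, that is, a generator of the multiplicative group. Then a primitive $n$-th root of unity $\zeta_n$ is given by: $$\zeta_n = g^{(p-1)/n} \bmod p$$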
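A minimal C++ sketch of this definition follows. It assumes $p < 2^{32}$ so that products of residues fit in 64-bit integers. The helper names `pow_mod`, `root_of_unity`, and `has_exact_order` are illustrative, not from any library. The last one confirms by brute force that $\zeta_n$ has order exactly $n$, which is only practical for small $n$.

```cpp
#include <cassert>
#include <cstdint>

// Modular exponentiation by repeated squaring.
// Assumes p < 2^32 so every intermediate product fits in 64 bits.
std::uint64_t pow_mod(std::uint64_t base, std::uint64_t exp, std::uint64_t p) {
    std::uint64_t result = 1;
    base %= p;
    while (exp > 0) {
        if (exp & 1) result = result * base % p;
        base = base * base % p;
        exp >>= 1;
    }
    return result;
}

// Returns zeta_n = g^((p-1)/n) mod p.
// Requires n | (p - 1) and g to be a primitive root modulo p.
std::uint64_t root_of_unity(std::uint64_t n, std::uint64_t g, std::uint64_t p) {
    assert(n > 0 && (p - 1) % n == 0);
    return pow_mod(g, (p - 1) / n, p);
}

// Brute-force check that zeta has order exactly n:
// zeta^n = 1 and zeta^k != 1 for every 0 < k < n.
bool has_exact_order(std::uint64_t zeta, std::uint64_t n, std::uint64_t p) {
    if (pow_mod(zeta, n, p) != 1) return false;
    std::uint64_t power = 1;
    for (std::uint64_t k = 1; k < n; ++k) {
        power = power * zeta % p;
        if (power == 1) return false;
    }
    return true;
}
```

For example, with $p = 17$ and $g = 3$, `root_of_unity(4, 3, 17)` gives $3^4 \bmod 17 = 13$, and $13^2 \equiv 16 \equiv -1 \pmod{17}$, so 13 has order exactly 4.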
Key properties analogous to complex roots of unity:
- $\zeta_{2n}^2 = \zeta_n$
- $\zeta_n^{k + n/2} = -\zeta_n^k$ for even $n$, because $\zeta_n^{n/2} = g^{(p-1)/2} \equiv -1 \pmod p$
- $\zeta_n^{k + n} = \zeta_n^k$
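The three identities can be checked numerically. The sketch below assumes $p = 998244353$ and $g = 3$ (the parameters of the implementation in this section) and an even $n$ with $2n \mid p - 1$. `check_properties` is a hypothetical test helper, not part of the original code.

```cpp
#include <cassert>
#include <cstdint>

using u64 = std::uint64_t;

const u64 P = 998244353;  // prime, 119 * 2^23 + 1
const u64 G = 3;          // primitive root modulo P

// base^exp mod P by repeated squaring.
u64 pow_mod(u64 base, u64 exp) {
    u64 result = 1;
    base %= P;
    while (exp > 0) {
        if (exp & 1) result = result * base % P;
        base = base * base % P;
        exp >>= 1;
    }
    return result;
}

// zeta_n = G^((P-1)/n); n must divide P - 1.
u64 zeta(u64 n) {
    assert(n > 0 && (P - 1) % n == 0);
    return pow_mod(G, (P - 1) / n);
}

// Checks the three listed identities for an even n with 2n | (P - 1).
bool check_properties(u64 n) {
    assert(n % 2 == 0);
    const u64 z = zeta(n);
    // (1) zeta_{2n}^2 == zeta_n
    if (pow_mod(zeta(2 * n), 2) != z) return false;
    for (u64 k = 0; k < n; ++k) {
        const u64 zk = pow_mod(z, k);
        // (2) zeta_n^{k + n/2} == -zeta_n^k (mod P); zk is never 0
        if (pow_mod(z, k + n / 2) != P - zk) return false;
        // (3) zeta_n^{k + n} == zeta_n^k
        if (pow_mod(z, k + n) != zk) return false;
    }
    return true;
}
```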
Consider a polynomial $P(x)$ of degree less than $n$: $$P(x) = \sum_{j=0}^{n-1} p_j x^j$$ The NTT evaluates $P(x)$ at the $n$ powers of $\zeta_n$: $$P_k = P(\zeta_n^k) = \sum_{j=0}^{n-1} p_j \zeta_n^{jk}, \quad k = 0, 1, \ldots, n-1$$
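As a reference point, the definition can be evaluated directly in $O(n^2)$ time. This sketch assumes $p = 998244353$, $g = 3$, and that $n$ divides $p - 1$. `naive_ntt` is an illustrative name; a direct evaluation like this is mainly useful for testing faster versions on small inputs.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;

const u64 P = 998244353;  // prime modulus
const u64 G = 3;          // primitive root modulo P

// base^exp mod P by repeated squaring.
u64 pow_mod(u64 base, u64 exp) {
    u64 result = 1;
    base %= P;
    while (exp > 0) {
        if (exp & 1) result = result * base % P;
        base = base * base % P;
        exp >>= 1;
    }
    return result;
}

// Direct evaluation P_k = sum_j p_j * zeta_n^{jk} mod P in O(n^2) time.
// n = coeffs.size() must divide P - 1.
std::vector<u64> naive_ntt(const std::vector<u64>& coeffs) {
    const u64 n = coeffs.size();
    assert(n > 0 && (P - 1) % n == 0);
    const u64 zeta = pow_mod(G, (P - 1) / n);
    std::vector<u64> values(n, 0);
    u64 zeta_k = 1;  // current evaluation point zeta^k
    for (u64 k = 0; k < n; ++k) {
        // Horner's rule: evaluate the polynomial at x = zeta^k.
        u64 acc = 0;
        for (u64 j = n; j-- > 0;) acc = (acc * zeta_k + coeffs[j] % P) % P;
        values[k] = acc;
        zeta_k = zeta_k * zeta % P;
    }
    return values;
}
```

For $P(x) = 1 + x$ and $n = 2$, the evaluation points are $1$ and $\zeta_2 = -1$, so `naive_ntt({1, 1})` returns `{2, 0}`.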
To compute the NTT efficiently, assume $n$ is even and split $P(x)$ into even and odd parts: $$P_{\text{even}}(x) = p_0 + p_2 x + p_4 x^2 + \cdots$$ $$P_{\text{odd}}(x) = p_1 + p_3 x + p_5 x^2 + \cdots$$ Then $$P(x) = P_{\text{even}}(x^2) + x P_{\text{odd}}(x^2)$$ Evaluating at $\zeta_n^k$: $$P(\zeta_n^k) = P_{\text{even}}(\zeta_n^{2k}) + \zeta_n^k P_{\text{odd}}(\zeta_n^{2k})$$ Since $\zeta_n^{2k} = \zeta_{n/2}^k$, this becomes: $$P(\zeta_n^k) = P_{\text{even}}(\zeta_{n/2}^k) + \zeta_n^k P_{\text{odd}}(\zeta_{n/2}^k)$$
Similarly, for the second half of the indices, $\zeta_n^{2(k + n/2)} = \zeta_n^{2k} \zeta_n^{n} = \zeta_{n/2}^k$ and $\zeta_n^{k + n/2} = -\zeta_n^k$, so: $$P(\zeta_n^{k + n/2}) = P_{\text{even}}(\zeta_{n/2}^k) - \zeta_n^k P_{\text{odd}}(\zeta_{n/2}^k)$$ for $0 \le k < n/2$. Each level therefore needs two transforms of half the size plus $O(n)$ extra work, which gives $O(n \log n)$ in total when $n$ is a power of two.
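A direct recursive transcription of this split, as a sketch: it assumes $p = 998244353$, $g = 3$, a power-of-two length, and coefficients already reduced into $[0, p)$. `recursive_ntt` is a hypothetical name. It allocates new vectors at every level, which is one reason the iterative in-place form is usually preferred.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;

const u64 P = 998244353;  // prime, 119 * 2^23 + 1
const u64 G = 3;          // primitive root modulo P

// base^exp mod P by repeated squaring.
u64 pow_mod(u64 base, u64 exp) {
    u64 result = 1;
    base %= P;
    while (exp > 0) {
        if (exp & 1) result = result * base % P;
        base = base * base % P;
        exp >>= 1;
    }
    return result;
}

// Recursive radix-2 NTT following the even/odd split:
//   P(zeta_n^k)       = E(zeta_{n/2}^k) + zeta_n^k * O(zeta_{n/2}^k)
//   P(zeta_n^{k+n/2}) = E(zeta_{n/2}^k) - zeta_n^k * O(zeta_{n/2}^k)
// a.size() must be a power of two, at most 2^23, with entries in [0, P).
std::vector<u64> recursive_ntt(const std::vector<u64>& a) {
    const std::size_t n = a.size();
    assert(n > 0 && (n & (n - 1)) == 0);
    if (n == 1) return a;  // constant polynomial: its value at zeta_1^0 = 1
    std::vector<u64> even(n / 2), odd(n / 2);
    for (std::size_t j = 0; j < n / 2; ++j) {
        even[j] = a[2 * j];
        odd[j] = a[2 * j + 1];
    }
    const std::vector<u64> e = recursive_ntt(even);  // E at powers of zeta_{n/2}
    const std::vector<u64> o = recursive_ntt(odd);   // O at powers of zeta_{n/2}
    const u64 zeta_n = pow_mod(G, (P - 1) / n);
    std::vector<u64> result(n);
    u64 w = 1;  // zeta_n^k
    for (std::size_t k = 0; k < n / 2; ++k) {
        const u64 t = w * o[k] % P;
        result[k] = (e[k] + t) % P;
        result[k + n / 2] = (e[k] + P - t) % P;
        w = w * zeta_n % P;
    }
    return result;
}
```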
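Iterative NTT Implementation (Modulo 998244353, g = 3)
The modulus $998244353 = 119 \cdot 2^{23} + 1$ is prime and has 3 as a primitive root, so $\zeta_n$ exists for every power of two $n \le 2^{23}$. The implementation below first permutes the coefficients into bit-reversed index order and then applies the butterfly from the derivation bottom-up, doubling the segment length at each stage. The inverse transform uses $\zeta_n^{-1}$ in place of $\zeta_n$ and divides the result by $n$.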
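```cpp
#include <utility>
#include <vector>
using namespace std;

using ll = long long;
const ll MOD = 998244353;     // 119 * 2^23 + 1
const ll PRIMITIVE_ROOT = 3;  // primitive root modulo MOD

// Computes base^exp mod `mod` by repeated squaring.
ll mod_exp(ll base, ll exp, ll mod) {
    ll res = 1;
    base %= mod;
    while (exp) {
        if (exp & 1) res = res * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return res;
}

// In-place NTT. poly.size() must be a power of two (at most 2^23) and
// every entry must lie in [0, MOD). If `invert` is true, computes the inverse NTT.
void ntt_transform(vector<ll>& poly, bool invert) {
    int n = (int)poly.size();
    // Reorder coefficients into bit-reversed index order.
    for (int idx = 1, rev_idx = 0; idx < n; ++idx) {
        int bit = n >> 1;
        while (rev_idx & bit) {
            rev_idx ^= bit;
            bit >>= 1;
        }
        rev_idx |= bit;
        if (idx < rev_idx) swap(poly[idx], poly[rev_idx]);
    }
    // Merge transforms of length segment_len / 2 into length segment_len.
    for (int segment_len = 2; segment_len <= n; segment_len <<= 1) {
        // zeta_{segment_len}, or its inverse for the inverse transform.
        ll root = mod_exp(PRIMITIVE_ROOT, (MOD - 1) / segment_len, MOD);
        if (invert) root = mod_exp(root, MOD - 2, MOD);
        for (int start = 0; start < n; start += segment_len) {
            ll current_factor = 1;  // root^pos
            for (int pos = 0; pos < segment_len / 2; ++pos) {
                ll even_part = poly[start + pos];
                ll odd_part = current_factor * poly[start + pos + segment_len / 2] % MOD;
                poly[start + pos] = (even_part + odd_part) % MOD;
                poly[start + pos + segment_len / 2] = (even_part - odd_part + MOD) % MOD;
                current_factor = current_factor * root % MOD;
            }
        }
    }
    // Inverse transform: divide by n (Fermat inverse, since MOD is prime).
    if (invert) {
        ll inv_n = mod_exp(n, MOD - 2, MOD);
        for (ll& coeff : poly) coeff = coeff * inv_n % MOD;
    }
}
```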
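The first loop of `ntt_transform` reorders the coefficients into bit-reversed index order. After $\log_2 n$ rounds of even/odd splitting, coefficient $p_j$ ends up at the position whose binary index is $j$ with its bits reversed, so performing this permutation up front lets the butterflies run bottom-up on contiguous segments. The sketch below, using a hypothetical helper `bit_reversal_order`, computes the same permutation bit by bit instead of incrementally:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Returns perm where perm[i] is i with its log2(n) low bits reversed.
// n must be a power of two.
std::vector<std::size_t> bit_reversal_order(std::size_t n) {
    assert(n > 0 && (n & (n - 1)) == 0);
    std::size_t log_n = 0;
    while ((std::size_t{1} << log_n) < n) ++log_n;
    std::vector<std::size_t> perm(n);
    for (std::size_t i = 0; i < n; ++i) {
        std::size_t rev = 0;
        // Move bit b of i to position log_n - 1 - b.
        for (std::size_t b = 0; b < log_n; ++b) {
            if (i & (std::size_t{1} << b)) rev |= std::size_t{1} << (log_n - 1 - b);
        }
        perm[i] = rev;
    }
    return perm;
}
```

For $n = 8$ this yields `0 4 2 6 1 5 3 7`. The permutation is its own inverse, which is why swapping only when `idx < rev_idx` is enough.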
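Polynomial Multiplication
To multiply two polynomials, pad both to a power-of-two length $N$ that is at least the number of coefficients of the product, $|a| + |b| - 1$, take the forward NTT of each, multiply the values pointwise, and apply the inverse NTT. This works because $(A \cdot B)(\zeta_N^k) = A(\zeta_N^k)\,B(\zeta_N^k)$, and since the product has degree less than $N$, its coefficients are recovered exactly modulo $p$. With this modulus, $N$ may not exceed $2^{23}$.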
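```cpp
// Returns the product of a and b modulo MOD, with |a| + |b| - 1 coefficients.
vector<ll> multiply_polynomials(const vector<ll>& a, const vector<ll>& b) {
    if (a.empty() || b.empty()) return {};
    int result_size = (int)a.size() + (int)b.size() - 1;
    // Smallest power of two that holds every coefficient of the product.
    int size = 1;
    while (size < result_size) size <<= 1;
    // Copy inputs, normalizing (possibly negative) coefficients into [0, MOD).
    vector<ll> poly_a(size, 0), poly_b(size, 0);
    for (size_t i = 0; i < a.size(); ++i) poly_a[i] = (a[i] % MOD + MOD) % MOD;
    for (size_t i = 0; i < b.size(); ++i) poly_b[i] = (b[i] % MOD + MOD) % MOD;
    ntt_transform(poly_a, false);
    ntt_transform(poly_b, false);
    // Pointwise product of the evaluations = evaluations of the product.
    for (int i = 0; i < size; ++i) poly_a[i] = poly_a[i] * poly_b[i] % MOD;
    ntt_transform(poly_a, true);
    poly_a.resize(result_size);  // drop the zero padding
    return poly_a;
}
```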
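For cross-checking the NTT-based product on small inputs, a schoolbook $O(|a|\,|b|)$ multiplication modulo the same prime is handy. This is a self-contained sketch; `naive_multiply` is an illustrative name. It returns exactly $|a| + |b| - 1$ coefficients, so compare it with `multiply_polynomials` after discarding any trailing zero padding.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using ll = long long;
const ll MOD = 998244353;  // same prime as the NTT code

// Schoolbook product modulo MOD; the result has |a| + |b| - 1 coefficients.
std::vector<ll> naive_multiply(const std::vector<ll>& a, const std::vector<ll>& b) {
    if (a.empty() || b.empty()) return {};
    std::vector<ll> result(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i) {
        for (std::size_t j = 0; j < b.size(); ++j) {
            // Normalize possibly negative inputs into [0, MOD) before multiplying.
            ll x = (a[i] % MOD + MOD) % MOD;
            ll y = (b[j] % MOD + MOD) % MOD;
            result[i + j] = (result[i + j] + x * y) % MOD;
        }
    }
    return result;
}
```

For instance, `naive_multiply({1, 2}, {3, 4})` returns `{3, 10, 8}`, since $(1 + 2x)(3 + 4x) = 3 + 10x + 8x^2$.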