Lagrange Interpolation for Polynomial Construction and Evaluation
Given a polynomial $f$ of degree $k$ and $k+1$ data points $(x_i, y_i)$ with pairwise distinct $x_i$ and $f(x_i) = y_i$, the goal is to reconstruct $f$. While Gaussian elimination solves this in $O(k^3)$, Lagrange interpolation provides an $O(k^2)$ construction:
$$ f(x) = \sum_{i=1}^{k+1} y_i \prod_{\substack{j=1 \\ j \neq i}}^{k+1} \frac{x - x_j}{x_i - x_j} $$
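Proof: Substituting $x = x_i$, the product in term $i$ becomes 1, while every other term vanishes because its product contains the factor $x_i - x_i = 0$ in the numerator. The right-hand side is therefore a polynomial of degree at most $k$ that agrees with $f$ at $k+1$ distinct points. Such a polynomial is unique, so it equals $f$.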
This formula evaluates the polynomial at any point in $O(k^2)$ time. Below is an implementation in modular arithmetic modulo the prime 998244353, where division is performed with modular inverses (Fermat's little theorem):
#include <iostream>
using namespace std;
// Prime modulus: every value below lives in Z / 998244353 Z.
constexpr int MOD = 998244353;
// Capacity of the point arrays (maximum number of interpolation points).
constexpr int MAXN = 2005;

// Interpolation nodes x_i and their values y_i; main() reduces them into [0, MOD).
int x_vals[MAXN];
int y_vals[MAXN];
// Number of points read from input, and the abscissa at which to evaluate.
int num_points;
int query;
// Computes base^exp mod MOD by square-and-multiply exponentiation.
// Intended for 0 <= base < MOD; any exp <= 0 yields 1.
int mod_pow(int base, int exp) {
    long long acc = 1;
    long long sq = base;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) {
            acc = acc * sq % MOD;
        }
        sq = sq * sq % MOD;
    }
    return static_cast<int>(acc);
}
// Evaluates at x the polynomial of degree < num_points that passes through
// (x_vals[i], y_vals[i]) for i = 0 .. num_points-1, working modulo MOD.
//
// Lagrange form:  f(x) = sum_i y_i * prod_{j != i} (x - x_j) / (x_i - x_j).
// For each i the numerator and denominator products are accumulated
// separately, and a single modular inverse is taken at the end. The cost is
// O(num_points^2) multiplications plus O(num_points log MOD) for the inverses,
// instead of one O(log MOD) inverse per (i, j) pair.
//
// Expects x and all x_vals / y_vals already reduced into [0, MOD).
// The x_vals must be pairwise distinct mod MOD. If two coincide, the zero
// denominator has no inverse (mod_pow returns 0 for it), and the affected
// terms silently contribute 0.
int lagrange_eval(int x) {
    long long total = 0;
    for (int i = 0; i < num_points; ++i) {
        long long numerator = 1;
        long long denominator = 1;
        for (int j = 0; j < num_points; ++j) {
            if (j == i) continue;
            numerator = numerator * ((x - x_vals[j] + MOD) % MOD) % MOD;
            denominator = denominator * ((x_vals[i] - x_vals[j] + MOD) % MOD) % MOD;
        }
        // Fermat's little theorem: d^(MOD-2) == d^(-1) for prime MOD and d != 0.
        const long long basis =
            numerator * mod_pow(static_cast<int>(denominator), MOD - 2) % MOD;
        total = (total + y_vals[i] * basis) % MOD;
    }
    return static_cast<int>(total);
}
int main() {
cin >> num_points >> query;
query %= MOD;
if (query < 0) query += MOD;
for (int i = 0; i < num_points; ++i) {
cin >> x_vals[i] >> y_vals[i];
x_vals[i] %= MOD; y_vals[i] %= MOD;
if (x_vals[i] < 0) x_vals[i] += MOD;
if (y_vals[i] < 0) y_vals[i] += MOD;
}
cout << lagrange_eval(query) << endl;
return 0;
}
Application: Sum of Powers
Problem: Compute $S_k(n) = \sum_{i=1}^{n} i^k$ for $n \leq 10^9$ and $k \leq 10^6$.
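Solution: $S_k(n)$ is a polynomial in $n$ of degree $k+1$. This can be proven with finite differences. The first difference $S_k(n) - S_k(n-1) = n^k$ is a polynomial of degree $k$, so $S_k$ itself has degree $k+1$. Integration gives a matching heuristic for the leading term: $\int_1^n x^k \, dx = \frac{n^{k+1} - 1}{k+1}$. To evaluate $S_k(n)$ efficiently, precompute $S_k(m)$ for $m = 0, 1, \dots, k+1$, which are the $k+2$ points a degree-$(k+1)$ polynomial needs, and apply Lagrange interpolation. The $x_i$ are the consecutive integers $0, \dots, k+1$, so each denominator $\prod_{j \neq i} (i - j)$ equals $i! \cdot (-1)^{k+1-i} (k+1-i)!$. Precomputed factorials and their inverses therefore make the evaluation $O(k)$, up to a logarithmic factor from the modular powers and inverses used in the code below.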
#include <iostream>
using namespace std;
// Prime modulus for the sum-of-powers computation.
constexpr int MOD = 1000000007;
// Array capacity: tables are indexed 0 .. k+1, so k may be at most MAXK - 2.
constexpr int MAXK = 1000000 + 5;

// fact[i] = i!  and  inv_fact[i] = (i!)^(-1)                  (mod MOD)
int fact[MAXK];
int inv_fact[MAXK];
// neg_fact[i] = (-1)^i * i!  and  inv_neg_fact[i] = its inverse (mod MOD)
int neg_fact[MAXK];
int inv_neg_fact[MAXK];
// prefix_sum[m] = 1^k + 2^k + ... + m^k = S_k(m)                (mod MOD)
int prefix_sum[MAXK];
// suffix_prod[i] = n * (n-1) * ... * (n-i)                      (mod MOD)
// Despite its name, this is a running product from the left over (n - j).
int suffix_prod[MAXK];
// Upper summation limit and exponent, read from input.
int n;
int k;
// Computes base^exp mod MOD by square-and-multiply exponentiation.
// Intended for 0 <= base < MOD; any exp <= 0 yields 1.
int mod_pow(int base, int exp) {
    long long acc = 1;
    long long sq = base;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) {
            acc = acc * sq % MOD;
        }
        sq = sq * sq % MOD;
    }
    return static_cast<int>(acc);
}
// Fills the tables used by lagrange_term() for the current n and k (all mod MOD):
//   fact / inv_fact and neg_fact / inv_neg_fact   for indices 0 .. k+1
//   prefix_sum[m] = S_k(m) = 1^k + ... + m^k       for m = 0 .. k+1
//   suffix_prod[i] = prod_{j=0}^{i} (n - j)        for i = 0 .. k+1
// Requires 0 <= k <= MAXK - 2 so that index k+1 stays in range.
void precompute() {
    fact[0] = neg_fact[0] = 1;
    for (int i = 1; i <= k + 1; ++i) {
        fact[i] = static_cast<int>(1LL * fact[i - 1] * i % MOD);
        neg_fact[i] = static_cast<int>(1LL * neg_fact[i - 1] * (MOD - i) % MOD);
    }

    prefix_sum[0] = 0;
    for (int i = 1; i <= k + 1; ++i) {
        prefix_sum[i] = (prefix_sum[i - 1] + mod_pow(i, k)) % MOD;
    }

    // One Fermat inverse at the top, then walk down:
    //   (i!)^(-1) = ((i+1)!)^(-1) * (i+1), and likewise with the factor (-(i+1)).
    inv_fact[k + 1] = mod_pow(fact[k + 1], MOD - 2);
    inv_neg_fact[k + 1] = mod_pow(neg_fact[k + 1], MOD - 2);
    for (int i = k; i >= 0; --i) {
        inv_fact[i] = static_cast<int>(1LL * inv_fact[i + 1] * (i + 1) % MOD);
        inv_neg_fact[i] = static_cast<int>(1LL * inv_neg_fact[i + 1] * (MOD - i - 1) % MOD);
    }

    // Reduce n into [0, MOD) once, in 64-bit arithmetic. Computing `n - i + MOD`
    // in int overflows for n > INT_MAX - MOD, and `n % MOD` is negative for n < 0.
    const long long n_mod = (static_cast<long long>(n) % MOD + MOD) % MOD;
    suffix_prod[0] = static_cast<int>(n_mod);
    for (int i = 1; i <= k + 1; ++i) {
        const long long factor = (n_mod - i + MOD) % MOD;  // i <= k+1 < MOD
        suffix_prod[i] = static_cast<int>(suffix_prod[i - 1] * factor % MOD);
    }
}
// Returns the Lagrange basis polynomial L_idx for the nodes 0, 1, ..., k+1,
// evaluated at x (mod MOD):
//   L_idx(x) = prod_{j != idx} (x - j) / (idx - j),   0 <= idx <= k+1.
// The denominator equals idx! * (-1)^(k+1-idx) * (k+1-idx)!, so its inverse is
// inv_fact[idx] * inv_neg_fact[k+1-idx].
//
// NOTE: the numerator comes from suffix_prod, which precompute() builds from the
// global n. This is therefore only valid for x == n (how main() calls it), and it
// assumes x < MOD.
int lagrange_term(int idx, int x) {
    // For 0 <= x <= k+1, x is itself a node, so L_idx(x) = [idx == x]. The product
    // formula below cannot be used here: prod_j (x - j) then contains the zero factor
    // (x - x), which would also turn L_x(x) into 0. Negative x (outside the intended
    // domain) takes this branch as well and yields 0.
    if (x <= k + 1) {
        return idx == x ? 1 : 0;
    }
    // Here x > k+1 >= idx, so x - idx is non-zero and invertible:
    //   prod_{j != idx} (x - j) = [prod_{j=0}^{k+1} (x - j)] / (x - idx).
    // The difference is formed in 64-bit so large x cannot overflow int.
    const int gap = static_cast<int>((static_cast<long long>(x) - idx) % MOD);
    long long value = suffix_prod[k + 1];
    value = value * mod_pow(gap, MOD - 2) % MOD;
    value = value * inv_neg_fact[k + 1 - idx] % MOD;
    value = value * inv_fact[idx] % MOD;
    return static_cast<int>(value);
}
int main() {
cin >> n >> k;
precompute();
int answer = 0;
for (int i = 0; i <= k + 1; ++i) {
int term = (1LL * prefix_sum[i] * lagrange_term(i, n)) % MOD;
answer = (answer + term) % MOD;
}
cout << answer << endl;
return 0;
}