「Learning Notes」多项式的多点求值与快速插值

本文会随着笔者的水平提升持续更新.
若发现文中有叙述不严谨之处,欢迎指出。

多项式的多点求值

Description

给出一个多项式 f(x) n 个点 x_{1},x_{2},...,x_{n} ,求

f\left(x_{1}\right),f\left(x_{2}\right),...,f\left(x_{n}\right)

Method

考虑使用分治来将问题规模减半.

将给定的点分为两部分:

\begin{aligned} X_{0}&=\left\{x_{1},x_{2},...,x_{\left\lfloor\frac{n}{2}\right\rfloor}\right\}\\ X_{1}&=\left\{x_{\left\lfloor\frac{n}{2}\right\rfloor+1},x_{\left\lfloor\frac{n}{2}\right\rfloor+2},...,x_{n}\right\} \end{aligned}

构造多项式

g_{0}(x)=\prod_{x_{i}\in X_{0}}\left(x-x_{i}\right)

则有 \forall x\in X_{0}:g_{0}(x)=0

考虑将 f(x) 表示为 g_{0}(x)Q(x)+f_{0}(x) 的形式,即:

f_{0}(x)\equiv f(x)\pmod{g_{0}(x)}

则有 \forall x\in X_{0}:f(x)=g_{0}(x)Q(x)+f_{0}(x)=f_{0}(x) X_{1} 同理.

至此,问题的规模被减半,可以使用分治+多项式取模解决.

时间复杂度

T(n)=2T\left(\frac{n}{2}\right)+O(n \log{n})=O\left(n\log^{2}{n}\right)

Code

啥?你要代码?不存在的,调了一万年现在还是 WA

放一份 _rqy 的代码 / 逃

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#include <cctype>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>

typedef long long LL;

const int mod = 998244353;
const int g = 3;
const int N = 100050 * 4;

LL pow_mod(LL a, LL b) {
b = (b % (mod - 1) + mod - 1) % (mod - 1);
LL ans = 1;
for (; b; b >>= 1, a = a * a % mod)
if (b & 1) ans = ans * a % mod;
return ans;
}

typedef std::vector<LL> VLL;

int len, rev[N];
LL omega[N], omega_i[N];

void InitNTT(int n) {
int p = 0;
while ((1 << p) <= n) ++p;
len = 1 << p;
for (int i = 1; i < len; ++i)
rev[i] = ((i & 1) << (p - 1)) | (rev[i >> 1] >> 1);
omega[0] = omega_i[0] = 1;
LL w = omega[1] = omega_i[len - 1] = pow_mod(g, (mod - 1) / len);
for (int i = 2; i < len; ++i)
omega[i] = omega_i[len - i] = omega[i - 1] * w % mod;
}

void NTT(LL *A, LL *omega) {
for (int i = 0; i < len; ++i) if (rev[i] > i) std::swap(A[i], A[rev[i]]);
for (int h = 2; h <= len; h <<= 1)
for (int j = 0, t = len / h; j < len; j += h) {
LL *w = omega;
for (LL *l = A + j, *r = l + h / 2, *p = r; l != p; ++l, ++r) {
LL _t1 = *l, _t2 = *r * *w % mod;
*l = (_t1 + _t2) % mod;
*r = (_t1 - _t2) % mod;
w += t;
}
}
if (omega == ::omega_i)
for (int i = 0, v = -(mod - 1) / len; i < len; ++i)
A[i] = A[i] * v % mod;
}

void Conv(const VLL &A, const VLL &B, VLL &C) {
static LL _t1[N], _t2[N];
int n = A.size(), m = B.size();
std::copy(A.begin(), A.end(), _t1);
std::copy(B.begin(), B.end(), _t2);
InitNTT(n + m);
std::fill(_t1 + n, _t1 + len, 0);
std::fill(_t2 + m, _t2 + len, 0);
NTT(_t1, omega); NTT(_t2, omega);
for (int i = 0; i < len; ++i) _t1[i] = _t1[i] * _t2[i] % mod;
NTT(_t1, omega_i);
C.resize(n + m - 1);
for (int i = 0; i < n + m - 1; ++i) C[i] = _t1[i];
}

void PolyInv(const LL *A, int n, LL *B) {
if (n == 1) {
B[0] = pow_mod(A[0], mod - 2);
return;
}
static LL _t1[N], _t2[N];
int m = (n + 1) / 2;
PolyInv(A, m, B);
InitNTT((n + 1) * 2);
for (int i = 0; i < n; ++i) _t1[i] = A[i];
for (int i = n; i < len; ++i) _t1[i] = 0;
for (int i = 0; i < m; ++i) _t2[i] = B[i];
for (int i = m; i < len; ++i) _t2[i] = 0;
NTT(_t1, omega); NTT(_t2, omega);
for (int i = 0; i < len; ++i)
_t2[i] = (2 - _t1[i] * _t2[i]) % mod * _t2[i] % mod;
NTT(_t2, omega_i);
for (int i = 0; i < n; ++i) B[i] = _t2[i];
}

void PolyMod(const LL *A, int n, const LL *B, int m, LL *D) {
if (n < m) {
for (int i = 0; i < n; ++i) D[i] = A[i];
return;
}
static LL _C[N], _t1[N];
int t = n - m + 1, k = m - 1;
for (int i = 0; i < t && i < m; ++i) _t1[i] = B[m - i - 1];
for (int i = m; i < t; ++i) _t1[i] = 0;
PolyInv(_t1, t, _C);
InitNTT(2 * t);
for (int i = 0; i < t; ++i) _t1[i] = A[n - i - 1];
for (int i = t; i < len; ++i) _C[i] = _t1[i] = 0;
NTT(_t1, omega); NTT(_C, omega);
for (int i = 0; i < len; ++i) _C[i] = _C[i] * _t1[i] % mod;
NTT(_C, omega_i);
InitNTT(m + t);
for (int i = 0; i < t - i - 1; ++i) std::swap(_C[i], _C[t - i - 1]);
for (int i = t; i < len; ++i) _C[i] = 0;
for (int i = 0; i < m; ++i) _t1[i] = B[i];
for (int i = m; i < len; ++i) _t1[i] = 0;
NTT(_C, omega); NTT(_t1, omega);
for (int i = 0; i < len; ++i) _t1[i] = _t1[i] * _C[i] % mod;
NTT(_t1, omega_i);
for (int i = 0; i < k; ++i) D[i] = (A[i] - _t1[i]) % mod;
}

void Mod(const VLL &A, const VLL &B, VLL &D) {
static LL _t1[N], _t2[N], _t3[N];
int n = A.size(), m = B.size();
for (int i = 0; i < n; ++i) _t1[i] = A[i];
for (int i = 0; i < m; ++i) _t2[i] = B[i];
PolyMod(_t1, n, _t2, m, _t3);
int k = std::min(m - 1, n);
D.resize(k);
for (int i = 0; i < k; ++i) D[i] = _t3[i];
while (D.size() > 1 && D.back() == 0) D.pop_back();
}

int cnt;
LL F[N], x[N], y[N];
VLL A[N], B[N];

void Solve1(int T, int l, int r) {
if (l == r) {
A[T].clear();
A[T].push_back(-x[l] % mod);
A[T].push_back(1);
} else {
int mid = (l + r) >> 1, L, R;
Solve1(L = ++cnt, l, mid);
Solve1(R = ++cnt, mid + 1, r);
Conv(A[L], A[R], A[T]);
}
}

void Solve2(int T, int l, int r) {
if (l == r) {
y[l] = B[T][0];
} else {
int mid = (l + r) >> 1;
int L = ++cnt;
Mod(B[T], A[L], B[L]);
Solve2(L, l, mid);
int R = ++cnt;
Mod(B[T], A[R], B[R]);
Solve2(R, mid + 1, r);
}
}

void Solve(int n, int m) {
Solve1(cnt = 1, 0, m - 1);
B[1].resize(n);
for (int i = 0; i < n; ++i) B[1][i] = F[i];
Mod(B[1], A[1], B[1]);
Solve2(cnt = 1, 0, m - 1);
}

int main() {
int n, k;
scanf("%d", &n);
for (int i = 0; i < n; ++i) scanf("%lld", &F[i]);
scanf("%d", &k);
for (int i = 0; i < k; ++i) scanf("%lld", &x[i]);
Solve(n, k);
for (int i = 0; i < k; ++i)
printf("%lld\n", (y[i] + mod) % mod);
return 0;
}

多项式的快速插值

Description

给出一个 n+1 个点的集合

X=\left\{\left(x_{0},y_{0}\right),\left(x_{1},y_{1}\right),...,\left(x_{n},y_{n}\right)\right\}

求一个 n 次多项式 f(x) 使得其满足 \forall(x, y)\in X:f(x)=y

Method

仍然考虑使用分治来将问题规模减半.

将给定的点分为两部分:

\begin{aligned} X_{0}&=\left\{x_{0},x_{1},...,x_{\left\lfloor\frac{n}{2}\right\rfloor}\right\}\\ X_{1}&=\left\{x_{\left\lfloor\frac{n}{2}\right\rfloor+1},x_{\left\lfloor\frac{n}{2}\right\rfloor+2},...,x_{n}\right\} \end{aligned}

假设已经求出了 X_{0} 中的点插值出的多项式 f_{0}(x) ,考虑如何使其变为所求的 f(x)

构造多项式

g_{0}(x)=\prod_{x_{i}\in X_{0}}\left(x-x_{i}\right)

则有 \forall(x, y)\in X_{0}:g_{0}(x)=0
考虑将 f(x) 表示为 g_{0}(x)f_{1}(x)+f_{0}(x) 的形式。

由于 \forall(x, y)\in X_{0}:f(x)=g_{0}(x)f_{1}(x)+f_{0}(x)=f_{0}(x)=y ,故 X_{0} 中的点都在 f(x) 上.

考虑构造 f_{1}(x) 使得 X_{1} 中的点也在 f(x) 上,即:

\forall(x, y)\in X_{1}:f_{1}(x)g_{0}(x)+f_{0}(x)=y

变形可得:

\forall(x, y)\in X_{1}:f_{1}(x)=\frac{y-f_{0}(x)}{g_{0}(x)}

这样就得到了新的待插值点集合:

X'_{1}=\left\{\left(x,\frac{y-f_{0}(x)}{g_{0}(x)}\right):(x, y)\in X_{1}\right\}

递归对 X'_{1} 插值出 f_{1}(x) 即可。

由于每次都需要多点求值求出新的待插值点集合 X'_{1} ,时间复杂度为:

T(n)=2T\left(\frac{n}{2}\right)+O\left(n\log^{2}{n}\right)=O\left(n\log^{3}{n}\right)

Code

啥?你要代码?我多点求值都没过你问我要插值代码?

这次 _rqy 的代码也没有了,就啥都不放了 / 逃