题目描述
给定$2$个多项式$F(x), G(x)$,请求出$F(x) * G(x)$。
系数对$p$取模,且不保证$p$可以分解成$p = a \cdot 2^k + 1$之形式。
输入/输出
不关心
$1 \leq n \leq 10^5, 0 \leq a_i, b_i \leq 10^9, 2 \leq p \leq 10^9 + 9$
题解
主要是记录一下一次FFT同时对两个多项式进行DFT或IDFT这个常数技巧是怎么实现的。。。推式子什么的我也不会

现在这里有两个多项式$A(x),B(x)$,这两个多项式我们都需要做一次DFT
那么我们定义$P(x)=A(x)+i*B(x), Q(x)=A(x)-i*B(x)$
我们把$P(x)$经过DFT后的第$k$项记为$DFT(P_k)$ (其他的也同理),经过一些证明我们可以得到这样一个结论:$DFT(P_k)和DFT(Q_{n-k})$互为共轭复数,特殊的,$DFT(P_0)$和$DFT(Q_0)$为共轭复数
所以我们一次正向FFT之后求出所有的$DFT(P_k)$,就可以直接算出$DFT(Q_k)$
而此时$DFT(A_k)=\dfrac{DFT(P_k)+DFT(Q_k)}{2}, DFT(B_k)=-i*\dfrac{DFT(P_k)-DFT(Q_k)}{2}$也是直接算出来就好了
综上 我们只用了一次FFT就对$A(x)$和$B(x)$都进行了DFT 而不用两次
代码实现:
1 2 3 4 5 6 7 8 9 10 11 12 13
| void DFT(Complex *A, Complex *B) { for (int i = 0; i < lim; i++) { p[i] = A[i] + I * B[i]; } FFT(p, 1); for (int i = 0; i < lim; i++) { q[i] = conj(i ? p[lim-i] : p[0]); } for (int i = 0; i < lim; i++) { A[i] = (p[i] + q[i]) / 2; B[i] = (p[i] - q[i]) * I / -2; } }
|
IDFT也可以一次FFT处理两个多项式
对于两个点值表达的多项式$C(x), D(x)$
我们让$R(x)=C(x)+i*D(x)$
然后对$R$进行反向FFT
最后$IDFT(R_k)$的实部就是$IDFT(C_k)$,$IDFT(R_k)$的虚部就是$IDFT(D_k)$
这个就更好实现了
1 2 3 4 5 6 7 8 9 10
| void IDFT(Complex *C, Complex *D) { for (int i = 0; i < lim; i++) { r[i] = C[i] + I * D[i]; } FFT(r, -1); for (int i = 0; i < lim; i++) { C[i].x = r[i].x; D[i].x = r[i].y; } }
|
所以我们又只用了一次FFT就算出两个点值表达式IDFT后的结果
好像有人给这个取了个名字叫MTT
这个技巧的使用没有什么限制,只要有两个多项式都需要进行DFT/IDFT就可以用
不卡时限的题目还是不怎么需要用到这个技巧的。。。但是此题就要用
回到此题
当然你可以用三模数NTT 但是蒟蒻我并不会用
所以用FFT
如果直接FFT爆乘的话 肯定会爆出double的范围
但是 如果我们把F(x)拆成两个多项式$F(x)=A(x)*2^{15}+A_2(x)$,把$G(x)$拆成$G(x)=B(x)*2^{15}+B_2(x)$
然后计算$(A(x)*2^{15}+A_2(x))(B(x)*2^{15}+B_2(x))=A(x)*B(x)*2^{30}+(A(x)*B_2(x)+A_2(x)*B(x))*2^{15}+A_2(x)*B_2(x)$
这四个式子两两相乘是不会乘爆的
但是这样做需要做8次FFT(为了不乘爆最少也要7次) 时间上接受不了
用上面的那个技巧 我们可以把FFT的次数优化到4次
具体来说 就是$A(x)和A_2(x)$的DFT一起做$B(x)和B_2(x)$的DFT一起做
然后分别算出$A(x)*B(x),A(x)*B_2(x),A_2(x)*B(x),A_2(x)*B_2(x)$,同样两个一组做IDFT
一共4次 完美 时间上还过得去 只比NTT慢一倍左右吧
注意一定要用long double
。。。可能是乘积太大了导致精度不够用
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
| #include <bits/stdc++.h> using namespace std; typedef long long ll;
template<typename T> inline void read(T &num) { T x = 0, f = 1; char ch = getchar(); for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1; for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ '0'); num = x * f; }
struct Complex { long double x, y; Complex(long double xx = 0, long double yy = 0): x(xx), y(yy) {} };
inline Complex operator + (Complex p, Complex q) { return Complex(p.x+q.x , p.y+q.y); } inline Complex operator - (Complex p, Complex q) { return Complex(p.x-q.x , p.y-q.y); } inline Complex operator * (Complex p, Complex q) { return Complex(p.x*q.x-p.y*q.y , p.x*q.y+p.y*q.x); } inline Complex operator / (Complex p, long double q) { return Complex(p.x/q, p.y/q); } inline Complex conj(Complex p) { return Complex(p.x, -p.y); }
Complex I = Complex(0, 1), p[500005], q[500005], a[500005], a2[500005], b[500005], b2[500005]; ll n, m, mod, lim, l, rev[500005], ans[500005]; const long double pi = acos(-1.0);
void FFT(Complex *c, int tp) { for (int i = 0; i < lim; i++) { if (i < rev[i]) swap(c[i], c[rev[i]]); } for (int mid = 1; mid < lim; mid <<= 1) { Complex wn = Complex(cos(pi / mid), sin(pi / mid) * tp); for (int r = mid<<1, j = 0; j < lim; j += r) { Complex w = Complex(1, 0); for (int k = 0; k < mid; k++, w = w * wn) { Complex x = c[j+k], y = w * c[j+k+mid]; c[j+k] = x + y; c[j+k+mid] = x - y; } } } }
void DFT(Complex *A, Complex *B) { for (int i = 0; i < lim; i++) { p[i] = A[i] + I * B[i]; } FFT(p, 1); for (int i = 0; i < lim; i++) { q[i] = conj(i ? p[lim-i] : p[0]); } for (int i = 0; i < lim; i++) { A[i] = (p[i] + q[i]) / 2; B[i] = (p[i] - q[i]) * I / -2; } }
int main() { read(n), read(m), read(mod); lim = 1, l = 0; while (lim < n + m) { lim <<= 1; l++; } for (int i = 0; i < lim; i++) { rev[i] = (rev[i>>1]>>1)|((i&1)<<(l-1)); } for (int i = 0; i <= n; i++) { ll x; read(x); a[i].x = x >> 15; a2[i].x = x & 0x7fff; } for (int i = 0; i <= m; i++) { ll x; read(x); b[i].x = x >> 15; b2[i].x = x & 0x7fff; } DFT(a, a2); DFT(b, b2); for (int i = 0; i < lim; i++) { p[i] = a[i] * b[i] + a[i] * b2[i] * I; q[i] = a2[i] * b[i] + a2[i] * b2[i] * I; } FFT(p, -1); FFT(q, -1); for (int i = 0; i < lim; i++) { ll ab, ab2, a2b, a2b2; ab = (ll)(p[i].x / lim + 0.5) % mod; a2b = (ll)(q[i].x / lim + 0.5) % mod; ab2 = (ll)(p[i].y / lim + 0.5) % mod; a2b2 = (ll)(q[i].y / lim + 0.5) % mod; ans[i] = (((ab << 30) % mod) + ((((ab2 + a2b) % mod) << 15) % mod) + a2b2) % mod; } for (int i = 0; i <= n + m; i++) { printf("%lld ", ans[i]); } return 0; }
|