菜单 学习猿地 - LMONKEY

VIP

开通学习猿地VIP

尊享10项VIP特权 持续新增

知识通关挑战

打卡带练!告别无效练习

接私单赚外快

VIP优先接,累计金额超百万

学习猿地私房课免费学

大厂实战课仅对VIP开放

你的一对一导师

每月可免费咨询大牛30次

领取更多软件工程师实用特权

入驻
25
0

Fast Fourier Transform

原创
05/13 14:22
阅读数 3962

FFT&NTT及其应用&多项式操作

之前写的博客

船新模板:

 1 #include <cstdio>
 2 #include <algorithm>
 3 #include <cstring>
 4 #include <cmath>
 5 
 6 const int N = 1000010;
 7 const double pi = 3.1415926535897932384626;
 8 
 9 struct CP {
10     double x, y;
11     CP(double X = 0, double Y = 0) {
12         x = X;
13         y = Y;
14     }
15     inline CP operator +(const CP &w) const {
16         return CP(x + w.x, y + w.y);
17     }
18     inline CP operator -(const CP &w) const {
19         return CP(x - w.x, y - w.y);
20     }
21     inline CP operator *(const CP &w) const {
22         return CP(x * w.x - y * w.y, x * w.y + y * w.x); // 注意第一个是 - 号,i² = -1 
23     }
24 }a[N << 2], b[N << 2]; // 数组开最高次数的两倍。
25 
26 int r[N << 2];
27 
28 inline void FFT(int n, CP *a, int f) {
29     for(int i = 0; i < n; i++) {
30         if(i < r[i]) {
31             std::swap(a[i], a[r[i]]); // 先交换
32         }
33     }
34 
35     for(int len = 1; len < n; len <<= 1) { // 这里三个 < 没有 <=
36         CP Wn(cos(pi / len), f * sin(pi / len)); // cos + i * sin
37         for(int i = 0; i < n; i += (len << 1)) { // 每次处理长度为 len << 1 的区间,len倍增,i += len << 1
38             CP w(1, 0);
39             for(int j = 0; j < len; j++) {
40                 CP t = a[i + len + j] * w; // 背诵这4句
41                 a[i + len + j] = a[i + j] - t;
42                 a[i + j] = a[i + j] + t;
43                 w = w * Wn;
44             }
45         }
46     }
47 
48     if(f == -1) {
49         for(int i = 0; i <= n; i++) {
50             a[i].x /= n; // 除以 n
51         }
52     }
53     return;
54 }
55 
56 int main() {
57     int n, m;
58     scanf("%d%d", &n, &m);
59     for(int i = 0, x; i <= n; i++) {
60         scanf("%d", &x);
61         a[i].x = x;
62     }
63     for(int i = 0, x; i <= m; i++) {
64         scanf("%d", &x);
65         b[i].x = x;
66     }
67 
68     int lm = 1, len = 2;
69     while(len <= n + m) { // error : <
70         lm++;
71         len <<= 1;
72     }
73     for(int i = 1; i <= len; i++) { // error : n
74         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
75     }
76 
77     FFT(len, a, 1);
78     FFT(len, b, 1);
79     for(int i = 0; i <= len; i++) { // error : n
80         a[i] = a[i] * b[i]; // 直接使用复数乘法
81     }
82     FFT(len, a, -1);
83 
84     for(int i = 0; i <= n + m; i++) {
85         printf("%d ", (int)(a[i].x + 0.5));
86     }
87 
88     return 0;
89 }
FFT

 NTT主要记得随时取模,然后单位根变一下。

 1 #include <cstdio>
 2 #include <algorithm>
 3 
 4 typedef long long LL;
 5 const int N = 1000010;
 6 const LL MO = 998244353, g = 3;
 7 
 8 LL a[N << 2], b[N << 2];
 9 int r[N << 2];
10 
11 inline LL qpow(LL a, LL b) {
12     LL ans = 1;
13     a %= MO;
14     while(b) {
15         if(b & 1) {
16             ans = ans * a % MO;
17         }
18         a = a * a % MO;
19         b = b >> 1;
20     }
21     return ans;
22 }
23 
24 inline void NTT(int n, LL *a, int f) {
25     for(int i = 0; i < n; i++) {
26         if(i < r[i]) {
27             std::swap(a[i], a[r[i]]);
28         }
29     }
30 
31     for(int len = 1; len < n; len <<= 1) {
32         LL Wn = qpow(g, (MO - 1) / (len << 1)); // here
33         if(f == -1) {
34             Wn = qpow(Wn, MO - 2); // here 
35         }
36         for(int i = 0; i < n; i += (len << 1)) {
37             LL w = 1;
38             for(int j = 0; j < len; j++) {
39                 LL t = a[i + len + j] * w % MO;
40                 a[i + len + j] = (a[i + j] - t + MO) % MO;
41                 a[i + j] = (a[i + j] + t) % MO;
42                 w = w * Wn % MO;
43             }
44         }
45     }
46 
47     if(f == -1) {
48         LL inv = qpow(n, MO - 2);
49         for(int i = 0; i <= n; i++) {
50             a[i] = a[i] * inv % MO;
51         }
52     }
53 
54     return;
55 }
56 
57 int main() {
58     int n, m;
59     scanf("%d%d", &n, &m);
60     for(int i = 0; i <= n; i++) {
61         scanf("%lld", &a[i]);
62     }
63     for(int i = 0; i <= m; i++) {
64         scanf("%lld", &b[i]);
65     }
66 
67     int lm = 1, len = 2;
68     while(len <=  n + m) {
69         len <<= 1;
70         lm++;
71     }
72     for(int i = 1; i <= len; i++) {
73         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
74     }
75 
76     NTT(len, a, 1);
77     NTT(len, b, 1);
78     for(int i = 0; i <= len; i++) {
79         a[i] = a[i] * b[i] % MO;
80     }
81     NTT(len, a, -1);
82 
83     for(int i = 0; i <= n + m; i++) {
84         printf("%lld ", (a[i] + MO) % MO);
85     }
86 
87     return 0;
88 }
NTT

接下来搞点例题来。

两个串  万径人踪灭    Fuzzy Search  快速傅里叶之二  序列统计  Triple  礼物 

分治法法塔:

给定 $g$ 和 $f_0$,求 $f_i = \sum_{j=0}^{i-1} f_j \, g_{i-j}$。

解:直接暴力是n³,暴力法法塔是n²logn,暴力分治也是n²logn。现在要介绍的算法是采用cdq分治思想来进行法法塔,时间复杂度nlog²n。

考虑一段区间[l, r],其中[l, mid]的f已经求出来了。现在要求[l, mid]对(mid, r]的贡献。

发现 $f_l$ 的贡献要乘上 $g_{mid-l+1} \sim g_{r-l}$,$f_{mid}$ 的贡献要乘上 $g_1 \sim g_{r-mid}$。那么我们把 $g$ 的 $[1, r-l]$ 提出来卷积即可。

卷出来的对应位置加到f的(mid, r]上即可。

注意最长的一次卷积是f[0, mid]和g[1, n],所以乘起来长度是1.5n,数组要开三倍。

  1 #include <cstdio>
  2 #include <algorithm>
  3 
  4 typedef long long LL;
  5 const int N = 100010;
  6 const LL MO = 998244353, G = 3;
  7 
  8 int r[N * 3];
  9 LL a[N * 3], b[N * 3], g[N], f[N];
 10 
 11 inline LL qpow(LL a, LL b) {
 12     LL ans = 1;
 13     while(b) {
 14         if(b & 1) {
 15             ans = ans * a % MO;
 16         }
 17         a = a * a % MO;
 18         b = b >> 1;
 19     }
 20     return ans;
 21 }
 22 
 23 inline void NTT(int n, LL *a, int f) {
 24     for(int i = 0; i < n; i++) {
 25         if(i < r[i]) {
 26             std::swap(a[i], a[r[i]]);
 27         }
 28     }
 29     for(int len = 1; len < n; len <<= 1) {
 30         LL Wn = qpow(G, (MO - 1) / (len << 1));
 31         if(f == -1) {
 32             Wn = qpow(Wn, MO - 2);
 33         }
 34         for(int i = 0; i < n; i += (len << 1)) {
 35             LL w = 1;
 36             for(int j = 0; j < len; j++) {
 37                 LL t = a[i + len + j] * w % MO;
 38                 a[i + len + j] = (a[i + j] - t + MO) % MO;
 39                 a[i + j] = (a[i + j] + t) % MO;
 40                 w = w * Wn % MO;
 41             }
 42         }
 43     }
 44     if(f == -1) {
 45         LL inv = qpow(n, MO - 2);
 46         for(int i = 0; i <= n; i++) {
 47             a[i] = a[i] * inv % MO;
 48         }
 49     }
 50     return;
 51 }
 52 
 53 inline int prework(int n) {
 54     int len = 2, lm = 1;
 55     while(len <= n) {
 56         len <<= 1;
 57         lm++;
 58     }
 59     for(int i = 1; i <= len; i++) {
 60         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
 61     }
 62     return len;
 63 }
 64 
 65 inline void solve(int l, int r) {
 66     if(l == r) {
 67         return;
 68     }
 69     int mid = (l + r) >> 1;
 70     solve(l, mid);
 71 
 72     int len = prework(r - l + mid - l);
 73     for(int i = l; i <= mid; i++) {
 74         a[i - l] = f[i];
 75     }
 76     for(int i = mid - l + 1; i <= len; i++) {
 77         a[i] = 0;
 78     }
 79     for(int i = 1; i <= r - l; i++) {
 80         b[i] = g[i];
 81     }
 82     b[0] = 0;
 83     for(int i = r - l + 1; i <= len; i++) {
 84         b[i] = 0;
 85     }
 86     NTT(len, a, 1);
 87     NTT(len, b, 1);
 88     for(int i = 0; i <= len; i++) {
 89         a[i] = a[i] * b[i] % MO;
 90     }
 91     NTT(len, a, -1);
 92     for(int i = mid + 1; i <= r; i++) {
 93         f[i] = (f[i] + a[i - l]) % MO;
 94     }
 95 
 96     solve(mid + 1, r);
 97     return;
 98 }
 99 
100 int main() {
101 
102     int n;
103     scanf("%d", &n);
104     n--;
105     for(int i = 1; i <= n; i++) {
106         scanf("%lld", &g[i]);
107     }
108     f[0] = 1;
109 
110     solve(0, n);
111 
112     for(int i = 0; i <= n; i++) {
113         printf("%lld ", f[i]);
114     }
115     return 0;
116 }
AC代码

本题也可用多项式求逆解决。

多项式求逆:

参考此博客。利用倍增实现。注意求逆不能一直用点值,因为不同的长度,同一系数会化出不同的点值。

求逆中间的乘法,长度为当前长度的2倍。

  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <cstring>
  4 
  5 typedef long long LL;
  6 const int N = 100010;
  7 const LL MO = 998244353, G = 3;
  8 
  9 int r[N << 2];
 10 LL a[N << 2], b[N << 2], A[N << 2], B[N << 2], c[N << 2];
 11 
 12 inline LL qpow(LL a, LL b) {
 13     LL ans = 1; a %= MO;
 14     while(b) {
 15         if(b & 1) ans = ans * a % MO;
 16         a = a * a % MO;
 17         b = b >> 1;
 18     }
 19     return ans;
 20 }
 21 
 22 inline void prework(int n) {
 23     static int R = 0;
 24     if(R == n) return;
 25     int lm = 1;
 26     while((1 << lm) < n) lm++;
 27     for(int i = 1; i < n; i++) {
 28         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
 29     }
 30     R = n;
 31     return;
 32 }
 33 
 34 inline void NTT(LL *a, int n, int f) {
 35     prework(n);
 36     for(int i = 0; i < n; i++) {
 37         if(i < r[i]) {
 38             std::swap(a[i], a[r[i]]);
 39         }
 40     }
 41     for(int len = 1; len < n; len <<= 1) {
 42         LL Wn = qpow(G, (MO - 1) / (len << 1));
 43         if(f == -1) Wn = qpow(Wn, MO - 2);
 44         for(int i = 0; i < n; i += (len << 1)) {
 45             LL w = 1;
 46             for(int j = 0; j < len; j++) {
 47                 LL t = a[i + len + j] * w % MO;
 48                 a[i + len + j] = (a[i + j] - t) % MO;
 49                 a[i + j] = (a[i + j] + t) % MO;
 50                 w = w * Wn % MO;
 51             }
 52         }
 53     }
 54     if(f == -1) {
 55         LL inv = qpow(n, MO - 2);
 56         for(int i = 0; i < n; i++) {
 57             a[i] = a[i] * inv % MO;
 58         }
 59     }
 60     return;
 61 }
 62 
 63 inline void mul(LL *a, LL *b, LL *c, int n) {
 64     memcpy(A, a, n * sizeof(LL));
 65     memcpy(B, b, n * sizeof(LL));
 66     NTT(A, n, 1); NTT(B, n, 1);
 67     for(int i = 0; i < n; i++) c[i] = A[i] * B[i] % MO;
 68     NTT(c, n, -1);
 69     return;
 70 }
 71 
 72 void inv(LL *a, LL *ans, int n) {
 73     if(n == 1) {
 74         ans[0] = qpow(a[0], MO - 2);
 75         ans[1] = 0;
 76         return;
 77     }
 78     inv(a, ans, n >> 1);
 79     /// temp = 2 ans - a ans ans
 80     memcpy(A, a, n * sizeof(LL)); memset(A + n, 0, n * sizeof(LL));
 81     memcpy(B, ans, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL));
 82     NTT(A, n * 2, 1); NTT(B, n * 2, 1);
 83     for(int i = 0; i < n * 2; i++) ans[i] = (2 - A[i] * B[i] % MO) * B[i] % MO;
 84     NTT(ans, n * 2, -1);
 85     memset(ans + n, 0, n * sizeof(LL));
 86     return;
 87 }
 88 
 89 int main() {
 90     int n;
 91     scanf("%d", &n);
 92     for(int i = 0; i < n; i++) {
 93         scanf("%lld", &a[i]);
 94     }
 95     int len = 1;
 96     while(len < n) {
 97         len <<= 1;
 98     }
 99     inv(a, b, len);
100     for(int i = 0; i < n; i++) {
101         printf("%lld ", (b[i] + MO) % MO);
102     }
103     puts("");
104     return 0;
105 }
多项式求逆

多项式除法&取模:

参考资料同上。转置之后求逆然后乘回去。

注意B的逆元求出来长度是n - m + 1.....不是m

memset的时候注意长度要为正才能赋值。

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
  4 
  5 typedef long long LL;
  6 const int N = 100010;
  7 const LL MO = 998244353, G = 3;
  8 typedef LL arr[N << 2];
  9 
 10 int r[N << 2];
 11 arr a, b, c, d, temp, A, B;
 12 
 13 inline LL qpow(LL a, LL b) {
 14     LL ans = 1; a %= MO;
 15     while(b) {
 16         if(b & 1) ans = ans * a % MO;
 17         a = a * a % MO;
 18         b = b >> 1;
 19     }
 20     return ans;
 21 }
 22 
 23 inline void prework(int n) {
 24     static int R = 0;
 25     if(R == n) return;
 26     int lm = 1;
 27     while((1 << lm) < n) lm++;
 28     for(int i = 1; i < n; i++) {
 29         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
 30     }
 31     R = n;
 32     return;
 33 }
 34 
 35 inline void NTT(LL *a, int n, int f) {
 36     prework(n);
 37     for(int i = 0; i < n; i++) {
 38         if(i < r[i]) {
 39             std::swap(a[i], a[r[i]]);
 40         }
 41     }
 42     for(int len = 1; len < n; len <<= 1) {
 43         LL Wn = qpow(G, (MO - 1) / (len << 1));
 44         if(f == -1) Wn = qpow(Wn, MO - 2);
 45         for(int i = 0; i < n; i += (len << 1)) {
 46             LL w = 1;
 47             for(int j = 0; j < len; j++) {
 48                 LL t = a[i + len + j] * w % MO;
 49                 a[i + len + j] = (a[i + j] - t) % MO;
 50                 a[i + j] = (a[i + j] + t) % MO;
 51                 w = w * Wn % MO;
 52             }
 53         }
 54     }
 55     if(f == -1) {
 56         LL inv = qpow(n, MO - 2);
 57         for(int i = 0; i < n; i++) {
 58             a[i] = a[i] * inv % MO;
 59         }
 60     }
 61     return;
 62 }
 63 
 64 inline void mul(const LL *a, const LL *b, LL *c, int n) {
 65     memcpy(A, a, n * sizeof(LL));
 66     memcpy(B, b, n * sizeof(LL));
 67     NTT(A, n, 1); NTT(B, n, 1);
 68     for(int i = 0; i < n; i++) c[i] = A[i] * B[i] % MO;
 69     NTT(c, n, -1);
 70     return;
 71 }
 72 
 73 void inv(const LL *a, LL *ans, int n) {
 74     if(n == 1) {
 75         ans[0] = qpow(a[0], MO - 2);
 76         ans[1] = 0;
 77         return;
 78     }
 79     inv(a, ans, n >> 1);
 80     /// temp = 2 ans - a ans ans
 81     memcpy(A, a, n * sizeof(LL)); memset(A + n, 0, n * sizeof(LL));
 82     memcpy(B, ans, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL));
 83     NTT(A, n * 2, 1); NTT(B, n * 2, 1);
 84     for(int i = 0; i < n * 2; i++) ans[i] = (2 - A[i] * B[i] % MO) * B[i] % MO;
 85     NTT(ans, n * 2, -1);
 86     memset(ans + n, 0, n * sizeof(LL));
 87     return;
 88 }
 89 
 90 inline void div(LL *a, LL *b, LL *c, int na, int nb) {
 91     int nc = na - nb;
 92     na++; nb++; nc++;
 93     std::reverse(a, a + na);
 94     std::reverse(b, b + nb);
 95     int len = 1;
 96     while(len < nc) len <<= 1;
 97     if(nb < len) memset(b + nb, 0, (len - nb) * sizeof(LL));
 98     inv(b, c, len);
 99     memset(c + nc, 0, (len - nc) * sizeof(LL));
100     while(len < na + nc) len <<= 1;
101     memset(a + na, 0, (len - na) * sizeof(LL));
102     memset(c + nc, 0, (len - nc) * sizeof(LL));
103     mul(a, c, c, len);
104     std::reverse(a, a + na);
105     std::reverse(b, b + nb);
106     std::reverse(c, c + nc); memset(c + nc, 0, (len - nc) * sizeof(LL));
107     return;
108 }
109 
110 inline void mod(LL *a, LL *b, LL *c, LL *d, int na, int nb) {
111     int nc = na - nb;
112     na++; nb++; nc++;
113     int len = 1;
114     while(len < na) len <<= 1;
115     memset(b + nb, 0, (len - nb) * sizeof(LL));
116     memset(c + nc, 0, (len - nc) * sizeof(LL));
117     mul(b, c, temp, len);
118     for(int i = 0; i < nb - 1; i++) {
119         d[i] = (a[i] - temp[i]) % MO;
120     }
121     return;
122 }
123 
124 int main() {
125     for(int i = 0; i < N * 4; i++) { // test
126         temp[i] = rand();
127         a[i] = rand();
128         b[i] = rand();
129         c[i] = rand();
130         d[i] = rand();
131         A[i] = rand();
132         B[i] = rand();
133     }
134     int n, m;
135     scanf("%d%d", &n, &m);
136     for(int i = 0; i <= n; i++) {
137         scanf("%lld", &a[i]);
138     }
139     for(int i = 0; i <= m; i++) {
140         scanf("%lld", &b[i]);
141     }
142     div(a, b, c, n, m);
143     mod(a, b, c, d, n, m);
144     for(int i = 0; i <= n - m; i++) {
145         printf("%lld ", (c[i] + MO) % MO);
146     }
147     printf("\n");
148     for(int i = 0; i < m; i++) {
149         printf("%lld ", (d[i] + MO) % MO);
150     }
151     printf("\n");
152     return 0;
153 }
多项式除法&取模

多项式取对数&求导&积分:

这个比指数友善些......参见洛咕题解

求导和积分就很友善。积分可以预处理逆元来做到O(n),但是我由于太懒选择了nlogn...

求ln,就是把这个ln外面套积分,里面套求导。然后lnx的导数是1/x就很友善了,求逆即可。

  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <cstring>
  4 
  5 typedef long long LL;
  6 const int N = 100010;
  7 const LL MO = 998244353, G = 3;
  8 typedef LL arr[N << 2];
  9 
 10 int r[N << 2];
 11 arr a, b, c, d, temp, A, B;
 12 
 13 inline LL qpow(LL a, LL b) {
 14     LL ans = 1; a %= MO;
 15     while(b) {
 16         if(b & 1) ans = ans * a % MO;
 17         a = a * a % MO;
 18         b = b >> 1;
 19     }
 20     return ans;
 21 }
 22 
 23 inline void prework(int n) {
 24     static int R = 0;
 25     if(R == n) return;
 26     int lm = 1;
 27     while((1 << lm) < n) lm++;
 28     for(int i = 1; i < n; i++) {
 29         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
 30     }
 31     R = n;
 32     return;
 33 }
 34 
 35 inline void NTT(LL *a, int n, int f) {
 36     prework(n);
 37     for(int i = 0; i < n; i++) {
 38         if(i < r[i]) {
 39             std::swap(a[i], a[r[i]]);
 40         }
 41     }
 42     for(int len = 1; len < n; len <<= 1) {
 43         LL Wn = qpow(G, (MO - 1) / (len << 1));
 44         if(f == -1) Wn = qpow(Wn, MO - 2);
 45         for(int i = 0; i < n; i += (len << 1)) {
 46             LL w = 1;
 47             for(int j = 0; j < len; j++) {
 48                 LL t = a[i + len + j] * w % MO;
 49                 a[i + len + j] = (a[i + j] - t) % MO;
 50                 a[i + j] = (a[i + j] + t) % MO;
 51                 w = w * Wn % MO;
 52             }
 53         }
 54     }
 55     if(f == -1) {
 56         LL inv = qpow(n, MO - 2);
 57         for(int i = 0; i < n; i++) {
 58             a[i] = a[i] * inv % MO;
 59         }
 60     }
 61     return;
 62 }
 63 
 64 inline void mul(const LL *a, const LL *b, LL *c, int n) {
 65     memcpy(A, a, n * sizeof(LL));
 66     memcpy(B, b, n * sizeof(LL));
 67     NTT(A, n, 1); NTT(B, n, 1);
 68     for(int i = 0; i < n; i++) c[i] = A[i] * B[i] % MO;
 69     NTT(c, n, -1);
 70     return;
 71 }
 72 
 73 void Inv(const LL *a, LL *ans, int n) {
 74     if(n == 1) {
 75         ans[0] = qpow(a[0], MO - 2);
 76         ans[1] = 0;
 77         return;
 78     }
 79     Inv(a, ans, n >> 1);
 80     /// temp = 2 ans - a ans ans
 81     memcpy(A, a, n * sizeof(LL)); memset(A + n, 0, n * sizeof(LL));
 82     memcpy(B, ans, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL));
 83     NTT(A, n * 2, 1); NTT(B, n * 2, 1);
 84     for(int i = 0; i < n * 2; i++) ans[i] = (2 - A[i] * B[i] % MO) * B[i] % MO;
 85     NTT(ans, n * 2, -1);
 86     memset(ans + n, 0, n * sizeof(LL));
 87     return;
 88 }
 89 
 90 inline void getInv(const LL *a, LL *ans, int n) {
 91     memcpy(temp, a, n * sizeof(LL));
 92     int len = 1;
 93     while(len < n) len <<= 1;
 94     memset(temp + n, 0, (len - n) * sizeof(LL));
 95     Inv(temp, ans, len);
 96     memset(ans + n, 0, sizeof(LL));
 97     return;
 98 }
 99 
100 inline void div(LL *a, LL *b, LL *c, int na, int nb) {
101     int nc = na - nb;
102     na++; nb++; nc++;
103     std::reverse(a, a + na);
104     std::reverse(b, b + nb);
105     getInv(b, c, nc);
106     int len = 1;
107     while(len < na + nc) len <<= 1;
108     memset(a + na, 0, (len - na) * sizeof(LL));
109     memset(c + nc, 0, (len - nc) * sizeof(LL));
110     mul(a, c, c, len);
111     std::reverse(a, a + na);
112     std::reverse(b, b + nb);
113     std::reverse(c, c + nc); memset(c + nc, 0, (len - nc) * sizeof(LL));
114     return;
115 }
116 
117 inline void mod(LL *a, LL *b, LL *c, LL *d, int na, int nb) {
118     int nc = na - nb;
119     na++; nb++; nc++;
120     int len = 1;
121     while(len < na) len <<= 1;
122     memset(b + nb, 0, (len - nb) * sizeof(LL));
123     memset(c + nc, 0, (len - nc) * sizeof(LL));
124     mul(b, c, temp, len);
125     for(int i = 0; i < nb - 1; i++) {
126         d[i] = (a[i] - temp[i]) % MO;
127     }
128     return;
129 }
130 
131 inline void der(const LL *a, LL *b, int n) { /// derivation   qiu dao
132     for(int i = 0; i < n - 1; i++) {
133         b[i] = a[i + 1] * (i + 1) % MO;
134     }
135     b[n - 1] = 0;
136     return;
137 }
138 
139 inline void ter(const LL *a, LL *b, int n) { /// quadrature   ji fen
140     for(int i = n - 1; i >= 1; i--) {
141         b[i] = a[i - 1] * qpow(i, MO - 2) % MO;
142     }
143     b[0] = 0;
144     return;
145 }
146 
147 inline void getLn(const LL *a, LL *b, int n) {
148     getInv(a, c, n);
149     der(a, d, n);
150     int len = 1;
151     while(len < n * 2) len <<= 1;
152     memset(c + n, 0, (len - n) * sizeof(len));
153     memset(d + n, 0, (len - n) * sizeof(len));
154     mul(c, d, b, len);
155     memset(b + n, 0, (len - n) * sizeof(LL));
156     ter(b, b, n);
157     return;
158 }
159 
160 int main() {
161     int n;
162     scanf("%d", &n);
163     for(int i = 0; i < n; i++) {
164         scanf("%lld", &a[i]);
165     }
166     getLn(a, b, n);
167     for(int i = 0; i < n; i++) {
168         printf("%lld ", (b[i] + MO) % MO);
169     }
170     printf("\n");
171     return 0;
172 }
多项式ln

多项式取指数&牛顿迭代:

还是洛咕题解

首先背诵牛顿迭代(在 $x_0$ 处做一阶泰勒展开):$f(x) \approx f(x_0) + f'(x_0)(x - x_0)$

由 $f(x) = 0$ 得 $x = x_0 - \dfrac{f(x_0)}{f'(x_0)}$

然后设G(f(x)) = ln(f(x)) - A(x),求其零点。

G'(f(x)) = 1 / f(x)

然后倍增,常数项是 $e^{a_0}$(下面的代码直接取 1,即默认 $a_0 = 0$),注意倍增式系数运算时那个 + 1 只加在常数项上。

多项式exp居然要用到ln...ln又要求逆...我死了。

  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <cstring>
  4 
  5 typedef long long LL;
  6 const int N = 100010;
  7 const LL MO = 998244353, G = 3;
  8 typedef LL arr[N << 2];
  9 
 10 int r[N << 2];
 11 arr a, b, inv_temp, ln_t1, ln_t2, exp_t1, exp_t2, mod_temp, A, B;
 12 
 13 inline LL qpow(LL a, LL b) {
 14     LL ans = 1; a %= MO;
 15     while(b) {
 16         if(b & 1) ans = ans * a % MO;
 17         a = a * a % MO;
 18         b = b >> 1;
 19     }
 20     return ans;
 21 }
 22 
 23 inline void prework(int n) {
 24     static int R = 0;
 25     if(R == n) return;
 26     int lm = 1;
 27     while((1 << lm) < n) lm++;
 28     for(int i = 1; i < n; i++) {
 29         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
 30     }
 31     R = n;
 32     return;
 33 }
 34 
 35 inline void NTT(LL *a, int n, int f) {
 36     prework(n);
 37     for(int i = 0; i < n; i++) {
 38         if(i < r[i]) {
 39             std::swap(a[i], a[r[i]]);
 40         }
 41     }
 42     for(int len = 1; len < n; len <<= 1) {
 43         LL Wn = qpow(G, (MO - 1) / (len << 1));
 44         if(f == -1) Wn = qpow(Wn, MO - 2);
 45         for(int i = 0; i < n; i += (len << 1)) {
 46             LL w = 1;
 47             for(int j = 0; j < len; j++) {
 48                 LL t = a[i + len + j] * w % MO;
 49                 a[i + len + j] = (a[i + j] - t) % MO;
 50                 a[i + j] = (a[i + j] + t) % MO;
 51                 w = w * Wn % MO;
 52             }
 53         }
 54     }
 55     if(f == -1) {
 56         LL inv = qpow(n, MO - 2);
 57         for(int i = 0; i < n; i++) {
 58             a[i] = a[i] * inv % MO;
 59         }
 60     }
 61     return;
 62 }
 63 
 64 inline void mul(const LL *a, const LL *b, LL *c, int n) {
 65     memcpy(A, a, n * sizeof(LL));
 66     memcpy(B, b, n * sizeof(LL));
 67     NTT(A, n, 1); NTT(B, n, 1);
 68     for(int i = 0; i < n; i++) c[i] = A[i] * B[i] % MO;
 69     NTT(c, n, -1);
 70     return;
 71 }
 72 
 73 void Inv(const LL *a, LL *ans, int n) {
 74     if(n == 1) {
 75         ans[0] = qpow(a[0], MO - 2);
 76         ans[1] = 0;
 77         return;
 78     }
 79     Inv(a, ans, n >> 1);
 80     /// temp = 2 ans - a ans ans
 81     memcpy(A, a, n * sizeof(LL)); memset(A + n, 0, n * sizeof(LL));
 82     memcpy(B, ans, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL));
 83     NTT(A, n * 2, 1); NTT(B, n * 2, 1);
 84     for(int i = 0; i < n * 2; i++) ans[i] = (2 - A[i] * B[i] % MO) * B[i] % MO;
 85     NTT(ans, n * 2, -1);
 86     memset(ans + n, 0, n * sizeof(LL));
 87     return;
 88 }
 89 
 90 inline void getInv(const LL *a, LL *ans, int n) {
 91     memcpy(inv_temp, a, n * sizeof(LL));
 92     int len = 1;
 93     while(len < n) len <<= 1;
 94     memset(inv_temp + n, 0, (len - n) * sizeof(LL));
 95     Inv(inv_temp, ans, len);
 96     memset(ans + n, 0, (len - n) * sizeof(LL));
 97     return;
 98 }
 99 
100 inline void div(LL *a, LL *b, LL *c, int na, int nb) {
101     int nc = na - nb;
102     na++; nb++; nc++;
103     std::reverse(a, a + na);
104     std::reverse(b, b + nb);
105     getInv(b, c, nc);
106     int len = 1;
107     while(len < na + nc) len <<= 1;
108     memset(a + na, 0, (len - na) * sizeof(LL));
109     memset(c + nc, 0, (len - nc) * sizeof(LL));
110     mul(a, c, c, len);
111     std::reverse(a, a + na);
112     std::reverse(b, b + nb);
113     std::reverse(c, c + nc); memset(c + nc, 0, (len - nc) * sizeof(LL));
114     return;
115 }
116 
117 inline void mod(LL *a, LL *b, LL *c, LL *d, int na, int nb) {
118     int nc = na - nb;
119     na++; nb++; nc++;
120     int len = 1;
121     while(len < na) len <<= 1;
122     memset(b + nb, 0, (len - nb) * sizeof(LL));
123     memset(c + nc, 0, (len - nc) * sizeof(LL));
124     mul(b, c, mod_temp, len);
125     for(int i = 0; i < nb - 1; i++) {
126         d[i] = (a[i] - mod_temp[i]) % MO;
127     }
128     return;
129 }
130 
131 inline void der(const LL *a, LL *b, int n) { /// derivation   qiu dao
132     for(int i = 0; i < n - 1; i++) {
133         b[i] = a[i + 1] * (i + 1) % MO;
134     }
135     b[n - 1] = 0;
136     return;
137 }
138 
139 inline void ter(const LL *a, LL *b, int n) { /// quadrature   ji fen
140     for(int i = n - 1; i >= 1; i--) {
141         b[i] = a[i - 1] * qpow(i, MO - 2) % MO;
142     }
143     b[0] = 0;
144     return;
145 }
146 
147 inline void getLn(const LL *a, LL *b, int n) {
148     getInv(a, ln_t1, n);
149     der(a, ln_t2, n);
150     int len = 1;
151     while(len < n * 2) len <<= 1;
152     memset(ln_t1 + n, 0, (len - n) * sizeof(len));
153     memset(ln_t2 + n, 0, (len - n) * sizeof(len));
154     mul(ln_t1, ln_t2, b, len);
155     memset(b + n, 0, (len - n) * sizeof(LL));
156     ter(b, b, n);
157     return;
158 }
159 
160 void Exp(const LL *a, LL *ans, int n) {
161     if(n == 1) {
162         ans[0] = 1; ans[1] = 0;
163         return;
164     }
165     Exp(a, ans, n >> 1);
166     getLn(ans, exp_t1, n);
167     for(int i = 0; i < n; i++) {
168         exp_t1[i] = (a[i] - exp_t1[i]) % MO;
169     }
170     exp_t1[0] = (exp_t1[0] + 1) % MO;
171     memcpy(exp_t2, ans, n * sizeof(LL));
172     memset(exp_t2 + n, 0, n * sizeof(LL));
173     mul(exp_t1, exp_t2, ans, n * 2);
174     memset(ans + n, 0, n * sizeof(LL));
175     return;
176 }
177 
178 inline void getExp(const LL *a, LL *b, int n) {
179     int len = 1;
180     while(len < n) len <<= 1;
181     Exp(a, b, len);
182     memset(b + n, 0, (len - n) * sizeof(LL));
183     return;
184 }
185 
186 int main() {
187     int n;
188     scanf("%d", &n);
189     for(int i = 0; i < n; i++) {
190         scanf("%lld", &a[i]);
191     }
192     getExp(a, b, n);
193     for(int i = 0; i < n; i++) {
194         printf("%lld ", (b[i] + MO) % MO);
195     }
196     printf("\n");
197     return 0;
198 }
多项式exp

多项式开根:

牛顿迭代的简单应用......注意内层乘的长度是2n,还有自己跟自己卷积的时候不能调用mul,因为会做两次DFT...

  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <cstring>
  4 
  5 typedef long long LL;
  6 const int N = 100010;
  7 const LL MO = 998244353, G = 3;
  8 typedef LL arr[N << 2];
  9 
 10 int r[N << 2];
 11 arr a, b, inv_t, ln_t1, ln_t2, exp_t1, exp_t2, mod_temp, sqrt_t, A, B;
 12 LL inv2;
 13 
 14 inline LL qpow(LL a, LL b) {
 15     LL ans = 1; a %= MO;
 16     while(b) {
 17         if(b & 1) ans = ans * a % MO;
 18         a = a * a % MO;
 19         b = b >> 1;
 20     }
 21     return ans;
 22 }
 23 
 24 inline void prework(int n) {
 25     static int R = 0;
 26     if(R == n) return;
 27     int lm = 1;
 28     while((1 << lm) < n) lm++;
 29     for(int i = 1; i < n; i++) {
 30         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
 31     }
 32     R = n;
 33     return;
 34 }
 35 
 36 inline void NTT(LL *a, int n, int f) {
 37     prework(n);
 38     for(int i = 0; i < n; i++) {
 39         if(i < r[i]) {
 40             std::swap(a[i], a[r[i]]);
 41         }
 42     }
 43     for(int len = 1; len < n; len <<= 1) {
 44         LL Wn = qpow(G, (MO - 1) / (len << 1));
 45         if(f == -1) Wn = qpow(Wn, MO - 2);
 46         for(int i = 0; i < n; i += (len << 1)) {
 47             LL w = 1;
 48             for(int j = 0; j < len; j++) {
 49                 LL t = a[i + len + j] * w % MO;
 50                 a[i + len + j] = (a[i + j] - t) % MO;
 51                 a[i + j] = (a[i + j] + t) % MO;
 52                 w = w * Wn % MO;
 53             }
 54         }
 55     }
 56     if(f == -1) {
 57         LL inv = qpow(n, MO - 2);
 58         for(int i = 0; i < n; i++) {
 59             a[i] = a[i] * inv % MO;
 60         }
 61     }
 62     return;
 63 }
 64 
 65 inline void mul(const LL *a, const LL *b, LL *c, int n) {
 66     memcpy(A, a, n * sizeof(LL));
 67     memcpy(B, b, n * sizeof(LL));
 68     NTT(A, n, 1); NTT(B, n, 1);
 69     for(int i = 0; i < n; i++) c[i] = A[i] * B[i] % MO;
 70     NTT(c, n, -1);
 71     return;
 72 }
 73 
 74 void Inv(const LL *a, LL *ans, int n) {
 75     if(n == 1) {
 76         ans[0] = qpow(a[0], MO - 2);
 77         ans[1] = 0;
 78         return;
 79     }
 80     Inv(a, ans, n >> 1);
 81     /// temp = 2 ans - a ans ans
 82     memcpy(A, a, n * sizeof(LL)); memset(A + n, 0, n * sizeof(LL));
 83     memcpy(B, ans, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL));
 84     NTT(A, n * 2, 1); NTT(B, n * 2, 1);
 85     for(int i = 0; i < n * 2; i++) ans[i] = (2 - A[i] * B[i] % MO) * B[i] % MO;
 86     NTT(ans, n * 2, -1);
 87     memset(ans + n, 0, n * sizeof(LL));
 88     return;
 89 }
 90 
 91 inline void getInv(const LL *a, LL *ans, int n) {
 92     memcpy(inv_t, a, n * sizeof(LL));
 93     int len = 1;
 94     while(len < n) len <<= 1;
 95     memset(inv_t + n, 0, (len - n) * sizeof(LL));
 96     Inv(inv_t, ans, len);
 97     memset(ans + n, 0, (len - n) * sizeof(LL));
 98     return;
 99 }
100 
101 inline void div(LL *a, LL *b, LL *c, int na, int nb) {
102     int nc = na - nb;
103     na++; nb++; nc++;
104     std::reverse(a, a + na);
105     std::reverse(b, b + nb);
106     getInv(b, c, nc);
107     int len = 1;
108     while(len < na + nc) len <<= 1;
109     memset(a + na, 0, (len - na) * sizeof(LL));
110     memset(c + nc, 0, (len - nc) * sizeof(LL));
111     mul(a, c, c, len);
112     std::reverse(a, a + na);
113     std::reverse(b, b + nb);
114     std::reverse(c, c + nc); memset(c + nc, 0, (len - nc) * sizeof(LL));
115     return;
116 }
117 
118 inline void mod(LL *a, LL *b, LL *c, LL *d, int na, int nb) {
119     int nc = na - nb;
120     na++; nb++; nc++;
121     int len = 1;
122     while(len < na) len <<= 1;
123     memset(b + nb, 0, (len - nb) * sizeof(LL));
124     memset(c + nc, 0, (len - nc) * sizeof(LL));
125     mul(b, c, mod_temp, len);
126     for(int i = 0; i < nb - 1; i++) {
127         d[i] = (a[i] - mod_temp[i]) % MO;
128     }
129     return;
130 }
131 
132 inline void der(const LL *a, LL *b, int n) { /// derivation   qiu dao
133     for(int i = 0; i < n - 1; i++) {
134         b[i] = a[i + 1] * (i + 1) % MO;
135     }
136     b[n - 1] = 0;
137     return;
138 }
139 
140 inline void ter(const LL *a, LL *b, int n) { /// quadrature   ji fen
141     for(int i = n - 1; i >= 1; i--) {
142         b[i] = a[i - 1] * qpow(i, MO - 2) % MO;
143     }
144     b[0] = 0;
145     return;
146 }
147 
148 inline void getLn(const LL *a, LL *b, int n) {
149     getInv(a, ln_t1, n);
150     der(a, ln_t2, n);
151     int len = 1;
152     while(len < n * 2) len <<= 1;
153     memset(ln_t1 + n, 0, (len - n) * sizeof(len));
154     memset(ln_t2 + n, 0, (len - n) * sizeof(len));
155     mul(ln_t1, ln_t2, b, len);
156     memset(b + n, 0, (len - n) * sizeof(LL));
157     ter(b, b, n);
158     return;
159 }
160 
161 void Exp(const LL *a, LL *ans, int n) {
162     if(n == 1) {
163         ans[0] = 1; ans[1] = 0;
164         return;
165     }
166     Exp(a, ans, n >> 1);
167     getLn(ans, exp_t1, n);
168     for(int i = 0; i < n; i++) {
169         exp_t1[i] = (a[i] - exp_t1[i]) % MO;
170     }
171     exp_t1[0] = (exp_t1[0] + 1) % MO;
172     memcpy(exp_t2, ans, n * sizeof(LL));
173     memset(exp_t2 + n, 0, n * sizeof(LL));
174     mul(exp_t1, exp_t2, ans, n * 2);
175     memset(ans + n, 0, n * sizeof(LL));
176     return;
177 }
178 
179 inline void getExp(const LL *a, LL *b, int n) {
180     int len = 1;
181     while(len < n) len <<= 1;
182     Exp(a, b, len);
183     memset(b + n, 0, (len - n) * sizeof(LL));
184     return;
185 }
186 
187 void Sqrt(const LL *a, LL *b, int n) {
188     if(n == 1) {
189         b[0] = 1;
190         b[1] = 0;
191         return;
192     }
193     Sqrt(a, b, n >> 1);
194     getInv(b, B, n);
195     memset(B + n, 0, n * sizeof(LL));
196     memcpy(A, a, n * sizeof(LL));
197     memset(A + n, 0, n * sizeof(LL));
198     NTT(A, n << 1, 1); NTT(B, n << 1, 1);
199     for(int i = 0; i < (n << 1); i++) B[i] = B[i] * A[i] % MO;
200     NTT(B, n << 1, -1);
201     for(int i = 0; i < n; i++) b[i] = (B[i] + b[i]) % MO * inv2 % MO;
202     memset(b + n, 0, n * sizeof(LL));
203     return;
204 }
205 
206 inline void getSqrt(const LL *a, LL *b, int n) {
207     int len = 1;
208     while(len < n) len <<= 1;
209     memcpy(sqrt_t, a, n * sizeof(LL));
210     memset(sqrt_t + n, 0, (len - n) * sizeof(LL));
211     Sqrt(sqrt_t, b, len);
212     memset(b + n, 0, (len - n) * sizeof(LL));
213     return;
214 }
215 
216 int main() {
217     inv2 = (MO + 1) / 2;
218     int n;
219     scanf("%d", &n);
220     for(int i = 0; i < n; i++) {
221         scanf("%lld", &a[i]);
222     }
223     getSqrt(a, b, n);
224     for(int i = 0; i < n; i++) {
225         printf("%lld ", (b[i] + MO) % MO);
226     }
227     printf("\n");
228     return 0;
229 }
AC代码

三模数NTT

任意模数 NTT:分别在三个 1e9 级别的 NTT 模数下做卷积,再用 CRT(代码中通过 exgcd 求逆元)把三个结果合并起来。

首先背诵模数:

469762049, 998244353, 1004535809

CRT 合并时要注意:前两个模数合并后的结果已经达到约 4.7e17,再与第三个模数合并会超出 long long 的范围,因此最后一次合并不求出完整的值,而是直接对题目给定的模数取模。参考资料

  1 #include <bits/stdc++.h>
  2 
  // Modulus of the pass currently being computed; follows the global `turn`.
  #define MO MOD[turn]

  typedef long long LL;

  const int N = 400010;                                   // capacity of every working array
  const LL MOD[] = {469762049, 998244353, 1004535809};    // the three NTT moduli

  int turn;      // index into MOD of the modulus in use
  int r[N];      // bit-reversal permutation filled by prework()

  LL a[N];       // first input polynomial
  LL b[N];       // second input polynomial
  LL A[N];       // working copy of a, reduced modulo MO
  LL B[N];       // working copy of b, reduced modulo MO
  LL c[3][N];    // product under each modulus; c[2] later holds the merged answer
  LL mod;        // target modulus read from the input
 11 
 12 LL exgcd(LL a, LL &x, LL b, LL &y) {
 13     if(!b) {
 14         x = 1;
 15         y = 0;
 16         return a;
 17     }
 18     LL g = exgcd(b, x, a % b, y);
 19     std::swap(x, y);
 20     y -= x * (a / b);
 21     return g;
 22 }
 23 
 24 LL gcd(LL a, LL b) {
 25     if(!b) return a;
 26     return gcd(b, a % b);
 27 }
 28 
 29 inline LL lcm(LL a, LL b) {
 30     return a / gcd(a, b) * b;
 31 }
 32 
 33 inline LL mul(LL a, LL b, LL c) {
 34     LL ans = 0;
 35     a = (a % c + c) % c;
 36     b = (b % c + c) % c;
 37     while(b) {
 38         if(b & 1) ans = (ans + a) % c;
 39         a = (a + a) % c;
 40         b = b >> 1;
 41     }
 42     return ans;
 43 }
 44 
 45 inline LL qpow(LL a, LL b) {
 46     LL ans = 1;
 47     a %= MO;
 48     while(b) {
 49         if(b & 1) ans = ans * a % MO;
 50         a = a * a % MO;
 51         b = b >> 1;
 52     }
 53     return ans;
 54 }
 55 
 56 inline void prework(int n) {
 57     static int R = 0;
 58     if(R == n) return;
 59     R = n;
 60     int lm = 1;
 61     while((1 << lm) < n) lm++;
 62     for(int i = 0; i < n; i++) {
 63         r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
 64     }
 65     return;
 66 }
 67 
 68 inline void NTT(LL *a, int n, int f) {
 69     prework(n);
 70     for(int i = 0; i < n; i++) {
 71         if(i < r[i]) std::swap(a[i], a[r[i]]);
 72     }
 73     for(int len = 1; len < n; len <<= 1) {
 74         LL Wn = qpow(3, (MO - 1) / (len << 1));
 75         if(f == -1) Wn = qpow(Wn, MO - 2);
 76         for(int i = 0; i < n; i += (len << 1)) {
 77             LL w = 1;
 78             for(int j = 0; j < len; j++) {
 79                 LL t = w * a[i + len + j] % MO;
 80                 a[i + len + j] = (a[i + j] - t + MO) % MO;
 81                 a[i + j] = (a[i + j] + t) % MO;
 82                 w = w * Wn % MO;
 83             }
 84         }
 85     }
 86     if(f == -1) {
 87         LL inv = qpow(n, MO - 2);
 88         for(int i = 0; i < n; i++) {
 89             a[i] = a[i] * inv % MO;
 90         }
 91     }
 92     return;
 93 }
 94 
 95 inline void merge(int n) {
 96     LL p = lcm(MOD[0], MOD[1]), x, y;
 97     for(int i = 0; i < n; i++) {
 98         LL C = (c[1][i] - c[0][i] + MOD[1]) % MOD[1];
 99         LL g = exgcd(MOD[0], x, MOD[1], y);
100         x = mul(x, C / g, p);
101         LL a = (c[0][i] + mul(x, MOD[0], p)) % p;
102 
103         C = ((c[2][i] - a) % MOD[2] + MOD[2]) % MOD[2];
104         g = exgcd(p, x, MOD[2], y);
105         /// ERROR    x = mul(x, C / g, mod);
106         x = mul(x, C / g, MOD[2]);
107         c[2][i] = (a + mul(x, p, mod)) % mod;
108     }
109     return;
110 }
111 
112 int main() {
113     int n, m;
114     scanf("%d%d%lld", &n, &m, &mod);
115     for(int i = 0; i <= n; i++) {
116         scanf("%lld", &a[i]);
117     }
118     for(int i = 0; i <= m; i++) {
119         scanf("%lld", &b[i]);
120     }
121 
122     int len = 1;
123     while(len <= n + m) {
124         len <<= 1;
125     }
126     for(turn = 0; turn < 3; turn++) {
127         for(int i = 0; i < len; i++) {
128             A[i] = a[i] % MO;
129             B[i] = b[i] % MO;
130         }
131         NTT(A, len, 1); NTT(B, len, 1);
132         for(int i = 0; i < len; i++) A[i] = A[i] * B[i] % MO;
133         NTT(A, len, -1);
134         for(int i = 0; i < len; i++) {
135             c[turn][i] = A[i];
136         }
137     }
138 
139     merge(len);
140 
141     for(int i = 0; i <= n + m; i++) {
142         printf("%lld ", (c[2][i] % mod + mod) % mod);
143     }
144     return 0;
145 }
AC代码

 

发表评论

0/200
25 点赞
0 评论
收藏