fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  /*
   * Problem and transform limits.
   *
   * N  - maximum number of coefficients per input polynomial (degree <= N - 1).
   * L  - log2 of the largest transform size.  A product of two such polynomials
   *      has up to 2N - 1 coefficients, so L = ceil(log2(2N - 1)).
   * N_ - largest transform size, 2^L.
   * MD - prime modulus 469762049 = 7 * 2^26 + 1 (equivalently 56 * 2^23 + 1).
   *      Since 2^26 divides MD - 1, roots of unity of every order 2^k with
   *      k <= 26 exist mod MD, so transforms up to 2^26 points are possible.
   *
   * A single enum (rather than macros) keeps the constants scoped and
   * type-checked; every value fits in int.
   */
  enum { N = 100001, L = 18, N_ = 1 << L, MD = 469762049 };

  /* The buffers sized N_ must hold a full product of two N-coefficient inputs. */
  static_assert(N_ >= 2 * N - 1, "N_ must hold 2N - 1 product coefficients");
  /* A primitive 2^L-th root of unity mod MD exists only if 2^L divides MD - 1. */
  static_assert((MD - 1) % N_ == 0, "2^L must divide MD - 1");

  /* wu[l] / wv[l]: powers of the forward / inverse 2^l-th root of unity (filled by init()). */
  int *wu[L + 1], *wv[L + 1];
  10.  
  11. int power(int a, int k) {
  12. long long b = a, p = 1;
  13.  
  14. while (k) {
  15. if (k & 1)
  16. p = p * b % MD;
  17. b = b * b % MD;
  18. k >>= 1;
  19. }
  20. return p;
  21. }
  22.  
  23. void init() {
  24. int l, i, u, v;
  25.  
  26. /* u ^ n must be equal to 1 for FFT */
  27. /* since x ^ (MD - 1) = 1 (mod MD) by fermat's little theorem */
  28. /* so MD - 1 must be power of 2 as n is a power of 2 */
  29. u = power(3, (MD - 1) >> L);
  30.  
  31. /* v is u ^ -1 = u ^ (MD - 2) by fermat's little theorem */
  32. /* this is for interpolation (inverse FFT) */
  33. v = power(u, MD - 2);
  34.  
  35. /* find the powers of u and v for each possible n */
  36. for (l = L; l > 0; l--) {
  37. int n = 1 << (l - 1);
  38.  
  39. wu[l] = (int *) malloc(n * sizeof *wu[l]);
  40. wv[l] = (int *) malloc(n * sizeof *wv[l]);
  41.  
  42. wu[l][0] = wv[l][0] = 1;
  43. for (i = 1; i < n; i++) {
  44. wu[l][i] = (long long) wu[l][i - 1] * u % MD;
  45. wv[l][i] = (long long) wv[l][i - 1] * v % MD;
  46. }
  47.  
  48. /* change u and for next iteration of n, which becomes n / 2 */
  49. /* u ^ n = 1 => (u ^ (n / 2)) ^ 2 = 1 */
  50. /* same goes for v */
  51. u = (long long) u * u % MD, v = (long long) v * v % MD;
  52. }
  53.  
  54. }
  55.  
  56. void ntt_(int *aa, int l, int inverse) {
  57. if (l > 0) {
  58. int n = 1 << l;
  59. int m = n >> 1;
  60. int *ww = inverse ? wv[l] : wu[l];
  61. int i, j;
  62.  
  63. /* solve for even and odd degrees */
  64. /* P(x) = P_even(x) + x * P_odd(x) */
  65. /* where P_even(x) is the terms with even degree and */
  66. /* where P_odd(x) is the terms with odd degree and */
  67. ntt_(aa, l - 1, inverse);
  68. ntt_(aa + m, l - 1, inverse);
  69.  
  70. /* now we make the point value form */
  71. for (i = 0; (j = i + m) < n; i++) {
  72.  
  73. /* a is even degree */
  74. /* b is odd degree, multiply by root^i */
  75. /* because we need to multiply back the part from */
  76. /* splitting P(x) into P_even(x) + x * P_odd(x) with x = root^i */
  77. int a = aa[i];
  78. int b = (long long) aa[j] * ww[i] % MD;
  79.  
  80. /* even part is all positive for all x of x ^ y */
  81. /* odd part is positive for x ^ y such that x is positive (first case, even index) */
  82. /* otherwise negative for negative x (second case, odd index) */
  83. if ((aa[i] = a + b) >= MD)
  84. aa[i] -= MD;
  85. if ((aa[j] = a - b) < 0)
  86. aa[j] += MD;
  87. }
  88. }
  89. }
  90.  
  91. void ntt(int *aa, int l, int inverse) {
  92. int n_ = 1 << l, i, j;
  93.  
  94. /* reverse the bits for each element so that */
  95. /* in the FFT we don't have to split the even/odd indexes */
  96. /* as it becomes into two ranges: [l, m) and [m, r) */
  97. for (i = 0, j = 1; j < n_; j++) {
  98. int b;
  99. int tmp;
  100.  
  101. for (b = n_ >> 1; (i ^= b) < b; b >>= 1)
  102. ;
  103. if (i < j)
  104. tmp = aa[i], aa[i] = aa[j], aa[j] = tmp;
  105. }
  106. ntt_(aa, l, inverse);
  107. }
  108.  
  109. void mult(int *aa, int n, int *bb, int m, int *out) {
  110. static int aa_[N_], bb_[N_];
  111. int l, n_, i, v;
  112.  
  113. /* enlarge size so that it's a power of 2 */
  114. l = 0;
  115. while (1 << l <= n - 1 + m - 1)
  116. l++;
  117. n_ = 1 << l;
  118. memcpy(aa_, aa, n * sizeof *aa), memset(aa_ + n, 0, (n_ - n) * sizeof *aa_);
  119. memcpy(bb_, bb, m * sizeof *bb), memset(bb_ + m, 0, (n_ - m) * sizeof *bb_);
  120.  
  121. /* FFT: convert coefficient form to point value form */
  122. ntt(aa_, l, 0), ntt(bb_, l, 0);
  123.  
  124. /* combine point value forms */
  125. for (i = 0; i < n_; i++)
  126. out[i] = (long long) aa_[i] * bb_[i] % MD;
  127.  
  128. /* Inverse FFT: convert point value form to coefficient form */
  129. ntt(out, l, 1);
  130. v = power(n_, MD - 2);
  131. for (i = 0; i < n_; i++)
  132. out[i] = (long long) out[i] * v % MD;
  133. }
  134.  
  135. int main() {
  136. static int aa[N], bb[N], out[N_];
  137. int n, m, i;
  138.  
  139. init();
  140. scanf("%d%d", &n, &m), n++, m++;
  141. for (i = 0; i < n; i++)
  142. scanf("%d", &aa[i]);
  143. for (i = 0; i < m; i++)
  144. scanf("%d", &bb[i]);
  145. mult(aa, n, bb, m, out);
  146. for (i = 0; i < n + m - 1; i++)
  147. printf("%d ", out[i]);
  148. printf("\n");
  149. return 0;
  150. }
Runtime error #stdin #stdout #stderr 0.01s 5548KB
stdin
Standard input is empty
stdout
Standard output is empty
stderr
*** buffer overflow detected ***: ./prog terminated