fork download
  #include <iostream> //allow input output
  #include <vector> //std::vector backs the dynamically sized work buffers
  2. #define ll long long //ll now means long long, thus we don't type long long each time
  3. using namespace std; //we don't need to use std::
  4. //blank line
  5. //another blank line because why not
  // Modular exponentiation by repeated squaring: returns a^b mod `mod`.
  //
  // a   - base; it is reduced into [0, mod) first, so large or negative
  //       values are accepted without overflowing the first multiplication.
  // b   - exponent; any b <= 0 is treated as 0 and yields 1 % mod.
  // mod - modulus, must be positive. Intermediate products are below mod^2,
  //       so mod^2 must fit in a signed 64-bit integer.
  //
  // The result is always in [0, mod).
  long long fastpow(long long a, long long b, long long mod){
    long long result = 1 % mod; // 1 % mod so that mod == 1 correctly gives 0
    long long square = a % mod; // holds a^(2^k) mod `mod` on the k-th iteration
    if(square < 0){
      square += mod; // C++ '%' keeps the sign of the dividend; normalise it
    }
    while(b > 0){
      if(b & 1){
        // this bit of the exponent is set: fold the current power in
        result = (result * square) % mod;
      }
      square = (square * square) % mod;
      b >>= 1;
    }
    return result;
  }
  20. //blank line
  21. //another blank line because why not
  // In-place iterative number-theoretic transform over 2^n entries.
  //
  // a   - coefficient array of 2^n entries; overwritten with the values of
  //       the polynomial at the points x[0], x[1], ..., x[2^n - 1].
  // x   - table of 2^n evaluation points, where x[i] is the i-th power of a
  //       2^n-th root of unity modulo `mod` (passing the inverse powers
  //       gives the unscaled inverse transform).
  // r   - bit-reversal table: r[i] is i with its n-bit binary form reversed.
  // n   - log2 of the transform length.
  // mod - modulus applied after every butterfly.
  void ntt(long long* a, long long* x, long long* r, long long n, long long mod){
    const int size = 1 << n;

    // Reorder the input into bit-reversed order; the idx < r[idx] guard makes
    // sure each pair is exchanged only once.
    for(int idx = 0; idx < size; ++idx){
      if(idx < r[idx]){
        const long long tmp = a[idx];
        a[idx] = a[r[idx]];
        a[r[idx]] = tmp;
      }
    }

    // Combine sub-transforms of length 2^(level-1) into ones of length
    // 2^level, using A(w) = Aeven(w^2) + w * Aodd(w^2).
    for(int level = 1; level <= n; ++level){
      const int block = 1 << level;      // length of each merged transform
      const int half = 1 << (level - 1); // length of each input half
      const int stride = size / block;   // step through x for block-th roots
      for(int start = 0; start < size; start += block){
        for(int off = 0; off < half; ++off){
          const int lo = start + off;
          const int hi = lo + half;
          // both outputs are computed before either slot is overwritten
          const long long first = a[lo] + x[stride * off] * a[hi];
          const long long second = a[lo] + x[stride * (off + half)] * a[hi];
          a[lo] = first % mod;
          a[hi] = second % mod;
        }
      }
    }
  }
  43. //blank line
  44. //another blank line because why not
  45. ll* mul(ll* a, ll* b, ll* c, ll n, ll g, ll mod){//This method multiplies a and b and puts the
  46. //result in c. n is a power of 2 larger than a and b's degree, g is a primitive root of mod, and
  47. // mod is the mod we are taking
  48. n++;//we increase n, since we actually need 1 more
  49. // cout << n << " " << endl;
  50. ll r [(1 << n)];//this stores the reverse binary of every number 4 -> 100 -> 001 -> 1
  51. //however, everything is padded so it has n digits
  52. r[0] = 0;//base case
  53. for(int i = 1; i < (1 << n); i++){//a for loop to compute r[x] for the rest of the nums
  54. if(i%2 != 0){//the last digit is a 1
  55. r[i] = (r[i/2]/2) + (1 << (n-1));//we add (1 << (n-1)) to the front
  56. }else{// the last digit is a 0
  57. r[i] = (r[i/2]/2);//we take r[i/2], and divide by 2 since there is an extra 0
  58. }//end if else
  59. }//end for loop
  60. ll x1 [(1 << n)];//the 2^n th roots of unity
  61. x1[0] = 1;//0th root of unity is always 0
  62. ll rt = fastpow(g, (mod-1)/(1 << n), mod);//we raise the primitive root to a power
  63. //to obtain the first 2^nth root of unity
  64. for(int i = 1; i < (1 << n); i++){//for loop
  65. x1[i] = (rt * x1[i-1])%mod;//we manually compute each root of unity
  66. }//end for loop
  67. ll x2 [(1 << n)];//the 2^n th roots of unity, but backwards (These are actually the -xth 2^nth root of unity)
  68. x2[0] = 1;//we set 0th root of unity to 0
  69. for(int i = 1; i < (1 << n); i++){//for loop
  70. x2[i] = x1[(1 << n) - i];//the -xth root of unity is equal to the n-xth root of unity
  71. }//end for loop
  72. ntt(a, x1, r, n, mod);// we do ntt on a to find a's point form
  73. ntt(b, x1, r, n, mod);// we do ntt on b to find b's point form
  74. for(int i = 0; i < (1 << n); i++){//for loop
  75. c[i] = (a[i] * b[i])%mod;//we find the point form of c
  76. }//end for loop
  77. ntt(c, x2, r, n, mod);//we do an inverse ntt on c to find c's coefficient form.
  78. //we use x2 for this since the inverse is basically the same except we use x^-n as coeffs
  79. ll inverse = fastpow((1 << n), mod-2, mod);//we need to divide by 2^n at the end
  80. //blank line
  81. for(int i = 0; i < (1 << n); i++){//for loop
  82. c[i] = (c[i] * inverse)%mod;//we divide the coeffs by 2^n since we are doing ifft
  83. }//end for loop
  84. return c;//the coefficients of the multiplied poly is returned
  85. }//end mul function
  86. //blank line
  87. int main() {//main function
  88. int n;//degree of n
  89. int m;//degree of m
  90. cin >> n >> m;//n and m
  91. int x = max(n,m)+1;//the highest number of terms in either n and m
  92. int p2 = 1;//this finds 1 + power of 2 higher than n and m
  93. while(x != 0){//while loop, pushes x until it gets to 0
  94. x = x/2;//we divide x by 2
  95. p2++;//we increase p2
  96. }//end while loop
  97. ll a [(1 << p2)];//polynomial 1
  98. ll b [(1 << p2)];//polynomial 2
  99. ll c [(1 << p2)];//answer
  100. for(int i = 0; i < (1 << p2); i++){//for loop
  101. a[i] = 0;//init poly1 to 0
  102. b[i] = 0;//init poly2 to 0s
  103. c[i] = 0;//init ans to 0s (probably not needed)
  104. }//end for loop
  105. for(int i = 0; i <= n; i++){//for loop
  106. cin >> a[i];//read in each elem of a
  107. }//end for loop
  108. for(int i = 0; i <= m; i++){//for loop
  109. cin >> b[i];//read in each elem of b
  110. }//end for loop
  111. mul(a,b,c,p2-1,3,998244353);//we multiply a and b
  112. for(int i = 0; i <= (n+m); i++){//for loop
  113. cout << c[i] << " ";//print the result
  114. }//end for loop
  115. cout << "\n";//an extra endl since why not
  116. }//end main
Success #stdin #stdout 0.01s 5300KB
stdin
Standard input is empty
stdout