fork(6) download
  1. #pragma GCC optimize("Ofast,unroll-loops")
  2. //#pragma GCC target("avx,avx2,fma")
  3. #pragma GCC target("avx")
  4.  
  5. #include <bits/stdc++.h>
  6.  
  /// Minimal millisecond stopwatch built on std::chrono::steady_clock.
  ///
  /// Usage: call start(), run the code under test, call finish(), then read the
  /// elapsed whole milliseconds through operator().
  class Timer {
      /// Moment recorded by the most recent start(); the clock's epoch until then.
      std::chrono::time_point<std::chrono::steady_clock> timePoint{};
      /// Milliseconds measured by the most recent finish(); 0 until finish() is called.
      size_t value = 0;
  public:
      /// Records the current time as the beginning of the measured interval.
      void start() { timePoint = std::chrono::steady_clock::now(); }

      /// Stores the time elapsed since the last start(), truncated to whole milliseconds.
      void finish() {
          const auto curr = std::chrono::steady_clock::now();
          const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(curr - timePoint);
          // steady_clock never runs backwards, so the count is non-negative here.
          value = static_cast<size_t>(elapsed.count());
      }

      /// Returns the milliseconds stored by the last finish() (0 if it never ran).
      size_t operator()() const { return value; }
  };
  19.  
  namespace {
  // Scratch buffers up to this many bytes per recursion level live in automatic
  // storage; larger ones are heap-allocated so that multiplying very long
  // polynomials cannot exhaust the thread's stack.
  constexpr std::size_t kMaxStackScratchBytes = 64 * 1024;

  // Scratch space for one Karatsuba level of size n: `sums` holds the two
  // half-sums (n elements in total), `product` holds their zero-initialized
  // 2*(n/2)-element product. This primary template keeps both on the stack.
  template<typename T, int n,
           bool onHeap = (2 * sizeof(T) * static_cast<std::size_t>(n) > kMaxStackScratchBytes)>
  struct KaratsubaScratch {
      alignas(64) T sums[n];
      alignas(64) T product[n];
      KaratsubaScratch() : product() {} // value-initialize (zero) the product only
  };

  // Heap-backed variant selected for large n; exposes the same two members.
  template<typename T, int n>
  struct KaratsubaScratch<T, n, true> {
      std::vector<T> storage; // 2*n value-initialized elements
      T *sums;
      T *product;
      KaratsubaScratch()
          : storage(2 * static_cast<std::size_t>(n)),
            sums(storage.data()),
            product(storage.data() + n) {}
      // The raw pointers refer to this object's own storage, so copying is disabled.
      KaratsubaScratch(const KaratsubaScratch &) = delete;
      KaratsubaScratch &operator=(const KaratsubaScratch &) = delete;
  };

  // Schoolbook O(n^2) multiplication. Chosen at compile time for short inputs
  // (where it beats Karatsuba) and for odd n, which cannot be split into two
  // equal halves.
  template<int n, typename T, bool schoolbook = (n <= 64 || n % 2 != 0)>
  struct PolyMul {
      static void run(const T *__restrict a, const T *__restrict b, T *__restrict res) {
          for (int i = 0; i < n; i++) {
              for (int j = 0; j < n; j++) {
                  res[i + j] += a[i] * b[j];
              }
          }
      }
  };

  // Karatsuba step for even n > 64: three half-size products instead of four.
  template<int n, typename T>
  struct PolyMul<n, T, false> {
      static void run(const T *__restrict a, const T *__restrict b, T *__restrict res) {
          constexpr int mid = n / 2;
          KaratsubaScratch<T, n> scratch;
          T *const bsum = scratch.sums;    // blow(x) + bhigh(x)
          T *const asum = bsum + mid;      // alow(x) + ahigh(x)
          T *const E = scratch.product;    // product of the two sums
          for (int i = 0; i < mid; i++) {
              asum[i] = a[i] + a[i + mid];
              bsum[i] = b[i] + b[i + mid];
          }
          PolyMul<mid, T>::run(asum, bsum, E);              // E(x) = (alow + ahigh) * (blow + bhigh)
          PolyMul<mid, T>::run(a, b, res);                  // rlow(x) = alow * blow   -> res[0, n)
          PolyMul<mid, T>::run(a + mid, b + mid, res + n);  // rhigh(x) = ahigh * bhigh -> res[n, 2n)
          // rmid(x) = E(x) - rlow(x) - rhigh(x) is added at offset mid. It overlaps the
          // upper half of rlow and the lower half of rhigh, so the rlow coefficient
          // about to be overwritten is saved first.
          for (int i = 0; i < mid; i++) {
              const T lowUpper = res[i + mid];
              res[i + mid] += E[i] - res[i] - res[i + 2 * mid];
              res[i + 2 * mid] += E[i + mid] - lowUpper - res[i + 3 * mid];
          }
      }
  };

  // Multiplies two polynomials with n coefficients each.
  //
  //   a, b - input coefficient arrays, n elements each
  //   res  - output array of 2*n elements; the product is ACCUMULATED into it,
  //          so it must be zero-filled by the caller. It must not overlap a or b.
  //
  // Karatsuba recursion is used while the size is even and above 64; any other
  // size falls back to the schoolbook loop, so every n >= 0 gives the exact product.
  template<int n, typename T>
  void mult(const T *__restrict a, const T *__restrict b, T *__restrict res) {
      PolyMul<n, T>::run(a, b, res);
  }
  }
  48.  
  // Size of a container as a signed int, convenient for `int` loop counters.
  #define isz(x) static_cast<int>((x).size())

  using namespace std;
  // Complex sample type used by the FFT code.
  using cd = std::complex<double>;
  // The constant pi, obtained as the arc cosine of -1.
  const double PI = std::acos(-1.0);
  54.  
  55. void fft(vector<cd> & a, bool invert) {
  56. int n = a.size();
  57.  
  58. for (int i = 1, j = 0; i < n; i++) {
  59. int bit = n >> 1;
  60. for (; j & bit; bit >>= 1)
  61. j ^= bit;
  62. j ^= bit;
  63.  
  64. if (i < j)
  65. swap(a[i], a[j]);
  66. }
  67.  
  68. for (int len = 2; len <= n; len <<= 1) {
  69. double ang = 2 * PI / len * (invert ? -1 : 1);
  70. cd wlen(cos(ang), sin(ang));
  71. for (int i = 0; i < n; i += len) {
  72. cd w(1);
  73. for (int j = 0; j < len / 2; j++) {
  74. cd u = a[i+j], v = a[i+j+len/2] * w;
  75. a[i+j] = u + v;
  76. a[i+j+len/2] = u - v;
  77. w *= wlen;
  78. }
  79. }
  80. }
  81.  
  82. if (invert) {
  83. for (cd & x : a)
  84. x /= n;
  85. }
  86. }
  87.  
  88. vector<int> multiply(vector<int> const& a, vector<int> const& b) {
  89. vector<cd> fa(a.begin(), a.end()), fb(b.begin(), b.end());
  90. int n = 1;
  91. while (n < a.size() + b.size())
  92. n <<= 1;
  93. fa.resize(n);
  94. fb.resize(n);
  95.  
  96. fft(fa, false);
  97. fft(fb, false);
  98. for (int i = 0; i < n; i++)
  99. fa[i] *= fb[i];
  100. fft(fa, true);
  101.  
  102. vector<int> result(n);
  103. for (int i = 0; i < n; i++)
  104. result[i] = round(fa[i].real());
  105. return result;
  106. }
  107.  
  108. template<int NMAX>
  109. void test() {
  110. Timer timer;
  111. timer.start();
  112. static int a[NMAX], b[NMAX], c[2*NMAX];
  113. for (int i = 0; i < NMAX; i++) a[i] = b[i] = 1;
  114. mult<NMAX>(a,b,c);
  115. timer.finish();
  116. std::cout << "NMAX = " << std::setw(10) << NMAX << ", karatsuba: " << std::setw(5) << timer() << "ms,";
  117. timer.start();
  118. vector<int> va(NMAX,1), vb(NMAX,1);
  119. auto vc = multiply(va, vb);
  120. timer.finish();
  121. std::cout << " FFT: " << std::setw(5) << timer() << "ms" << std::endl;
  122. for (int i = 0; i < isz(vc); i++) {
  123. if (vc[i] != c[i]) {
  124. std::cout << "i = " << i << ", vc[i] = " << vc[i] << ", c[i] " << c[i] << std::endl;
  125. std::exit(0);
  126. }
  127. assert(vc[i] == c[i]);
  128. }
  129. }
  130.  
  131. int main() {
  132. test<(1 << 19)>();
  133. test<(1 << 18)>();
  134. test<(1 << 17)>();
  135. test<(1 << 16)>();
  136. test<(1 << 15)>();
  137. test<(1 << 14)>();
  138. test<(1 << 13)>();
  139. test<(1 << 12)>();
  140. test<(1 << 11)>();
  141. test<(1 << 10)>();
  142. return 0;
  143. }
Success #stdin #stdout 2.39s 64908KB
stdin
Standard input is empty
stdout
NMAX =     524288, karatsuba:   985ms, FFT:   548ms
NMAX =     262144, karatsuba:   301ms, FFT:   198ms
NMAX =     131072, karatsuba:    90ms, FFT:    92ms
NMAX =      65536, karatsuba:    36ms, FFT:    68ms
NMAX =      32768, karatsuba:    15ms, FFT:    24ms
NMAX =      16384, karatsuba:     5ms, FFT:    13ms
NMAX =       8192, karatsuba:     1ms, FFT:     4ms
NMAX =       4096, karatsuba:     0ms, FFT:     1ms
NMAX =       2048, karatsuba:     0ms, FFT:     0ms
NMAX =       1024, karatsuba:     0ms, FFT:     0ms