fork(2) download
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <ctime>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

#include <immintrin.h>
  7.  
  8. const int Q = 30000000;
  9. const int N = 1 << 24;
  10.  
  11. using T = uint32_t;
  12. T a[2 * N];
  13.  
  14. const T identity_element = 0;
  15. T reduce(T u, T v) { return u + v; }
  16. __attribute__((target("sse4.1"))) __m128i reduce(__m128i u, __m128i v) { return _mm_add_epi32(u, v); }
  17. __attribute__((target("avx2"))) __m256i reduce(__m256i u, __m256i v) { return _mm256_add_epi32(u, v); }
  18.  
  19. static_assert(sizeof(T) == 4, "Segment tree elements must be 32-bit");
  20.  
  21. T query_recursive_inner(int v, int vl, int vr, int l, int r)
  22. {
  23. if(l >= r) return identity_element;
  24. if(l <= vl && vr <= r) return a[v];
  25. int vm = (vl + vr) >> 1;
  26. return reduce(query_recursive_inner(v << 1, vl, vm, l, std::min(r, vm)), query_recursive_inner((v << 1) | 1, vm, vr, std::max(l, vm), r));
  27. }
  28.  
  29. T query_recursive_inner(int l, int r)
  30. {
  31. return query_recursive_inner(1, 0, N, l, r + 1);
  32. }
  33.  
  34. T query_recursive_outer(int v, int vl, int vr, int l, int r)
  35. {
  36. if(vl == l && vr == r) return a[v];
  37.  
  38. int vm = (vl + vr) >> 1;
  39. if(r <= vm) return query_recursive_outer(v << 1, vl, vm, l, r);
  40. if(l >= vm) return query_recursive_outer((v << 1) | 1, vm, vr, l, r);
  41. return reduce(query_recursive_outer(v << 1, vl, vm, l, vm), query_recursive_outer((v << 1) | 1, vm, vr, vm, r));
  42. }
  43.  
  44. T query_recursive_outer(int l, int r)
  45. {
  46. return query_recursive_outer(1, 0, N, l, r + 1);
  47. }
  48.  
  49. T query_bottom_up(int l, int r)
  50. {
  51. l += N;
  52. r += N;
  53. T ans = identity_element;
  54. while(l <= r)
  55. {
  56. if(l & 1)
  57. {
  58. ans = reduce(ans, a[l]);
  59. l++;
  60. }
  61. if(!(r & 1))
  62. {
  63. ans = reduce(ans, a[r]);
  64. r--;
  65. }
  66. l >>= 1;
  67. r >>= 1;
  68. }
  69. return ans;
  70. }
  71.  
  72.  
  73. int ffs(unsigned int x) { return sizeof(unsigned int) * 8 - 1 - __builtin_clz(x); }
  74.  
  75. __attribute__((target("avx2"))) T query_parallel(int l, int r)
  76. {
  77. if(l == r) return a[l + N];
  78.  
  79. int mbit = ffs(l ^ r);
  80. int reset = ((1 << mbit) - 1);
  81. int m = r & ~reset;
  82.  
  83. using vecint = T __attribute__((vector_size(32)));
  84. __m256i identity_vec = _mm256_set1_epi32(identity_element);
  85. vecint vec_ans = (vecint)identity_vec;
  86. __m256i indexes = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
  87.  
  88. if((l & reset) != 0) {
  89. int ll = l - 1 + N;
  90. int rr = m - 1 + N;
  91.  
  92. int modbit = 0;
  93. int maxmodbit = ffs(ll ^ rr) + 1;
  94.  
  95. vecint ll_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(ll), indexes);
  96.  
  97. #define LOOP(content) if(modbit + 8 <= maxmodbit) { \
  98. vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4)); \
  99. ll_vec >>= 8; \
  100. modbit += 8; \
  101. content \
  102. }
  103. LOOP(LOOP(LOOP(LOOP())))
  104. #undef LOOP
  105.  
  106. __m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(((ll_vec & 1) - 1) & (ll_vec | 1)), 4);
  107. __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);
  108. vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));
  109. }
  110. else
  111. vec_ans[0] = reduce(vec_ans[0], a[(l + N) >> mbit]);
  112.  
  113.  
  114. if((r & reset) != reset)
  115. {
  116. int ll = m + N;
  117. int rr = r + 1 + N;
  118.  
  119. int modbit = 0;
  120. int maxmodbit = ffs(ll ^ rr) + 1;
  121.  
  122. vecint rr_vec = (vecint)_mm256_srav_epi32(_mm256_set1_epi32(rr), indexes);
  123.  
  124. #define LOOP(content) if(modbit + 8 <= maxmodbit) { \
  125. vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4)); \
  126. rr_vec >>= 8; \
  127. modbit += 8; \
  128. content \
  129. }
  130. LOOP(LOOP(LOOP(LOOP())))
  131. #undef LOOP
  132.  
  133. __m256i tmp = _mm256_i32gather_epi32((int*)a, (__m256i)(~((rr_vec & 1) - 1) & (rr_vec - 1)), 4);
  134. __m256i mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(maxmodbit & 7), indexes);
  135. vec_ans = (vecint)reduce((__m256i)vec_ans, _mm256_blendv_epi8(identity_vec, tmp, mask));
  136. }
  137. else
  138. vec_ans[0] = reduce(vec_ans[0], a[(r + N) >> mbit]);
  139.  
  140. // vec_ans = 7 6 5 4 3 2 1 0
  141. __m128i low128 = _mm256_castsi256_si128((__m256i)vec_ans); // 3 2 1 0
  142. __m128i high128 = _mm256_extractf128_si256((__m256i)vec_ans, 1); // 7 6 5 4
  143. __m128i ans128 = reduce(low128, high128); // 7+3 6+2 5+1 4+0
  144. T ans = identity_element;
  145. for(int i = 0; i < 4; i++) ans = reduce(ans, ((T __attribute__((vector_size(16))))ans128)[i]);
  146. return ans;
  147. }
  148.  
  149. T pref[N];
  150. T query_prefix(int l, int r) { return pref[r] - (l == 0 ? 0 : pref[l - 1]); }
  151.  
  152. T fenwick[N];
  153.  
  154. T query_fenwick(int r)
  155. {
  156. T ans = 0;
  157. for (; r >= 0; r = (r & (r + 1)) - 1) ans += fenwick[r];
  158. return ans;
  159. }
  160.  
  161. T query_fenwick(int l, int r)
  162. {
  163. return query_fenwick(r) - (l == 0 ? 0 : query_fenwick(l - 1));
  164. }
  165.  
  166. int main()
  167. {
  168. //std::mt19937 rng(1447);
  169. std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());
  170.  
  171. std::pair<int, int>* queries = new std::pair<int, int>[Q];
  172. for(int i = 0; i < Q; i++)
  173. {
  174. int l = rng() % N;
  175. int r = rng() % N;
  176. if(l > r) std::swap(l, r);
  177. queries[i] = {l, r};
  178. }
  179. for(int i = 0; i < N; i++) a[N + i] = rng() % (32768);
  180. for(int i = N - 1; i >= 1; i--) a[i] = reduce(a[i * 2], a[i * 2 + 1]);
  181. a[0] = identity_element;
  182.  
  183. for(int i = 0; i < N; i++) pref[i] = (i == 0 ? 0 : pref[i - 1]) + a[N + i];
  184.  
  185. for(int i = 0; i < N; i++)
  186. for (int j = i; j < N; j = (j | (j + 1)))
  187. fenwick[j] += a[N + i];
  188.  
  189. #define CHECK(func) { \
  190. auto clock_start = clock(); \
  191. T checksum = 0; \
  192. for(int i = 0; i < Q; i++) { \
  193. checksum += func(queries[i].first, queries[i].second); \
  194. } \
  195. std::cout << #func << ": " << (double)(clock() - clock_start) / CLOCKS_PER_SEC << " seconds (checksum: " << checksum << ")" << std::endl; \
  196. }
  197.  
  198. CHECK(query_recursive_inner)
  199. CHECK(query_recursive_outer)
  200. CHECK(query_bottom_up)
  201. CHECK(query_parallel)
  202. CHECK(query_fenwick)
  203. CHECK(query_prefix)
  204.  
  205. delete [] queries;
  206.  
  207. return 0;
  208. }
Time limit exceeded #stdin #stdout 5s 499700KB
stdin
Standard input is empty
stdout
Standard output is empty