fork(1) download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  4.  
  5. struct sum_kth_smallest {
  6.  
  7. struct Node {
  8. long long sum;
  9. int cnt;
  10. int lCh, rCh;//children, indexes into `tree`
  11. };
  12.  
  13. int mn, mx;
  14. vector<int> roots;
  15. deque<Node> tree;
  16.  
  17. sum_kth_smallest(const vector<int>& arr) : mn(INT_MAX), mx(INT_MIN), roots(arr.size() + 1, 0) {
  18. tree.push_back({0, 0, 0}); //acts as null
  19. for (int val : arr) mn = min(mn, val), mx = max(mx, val);
  20. for (int i = 0; i < (int)arr.size(); i++)
  21. roots[i + 1] = update(roots[i], -mx, mx, arr[i]);
  22. }
  23. int update(int v, int tl, int tr, int idx) {
  24. if (tl == tr) {
  25. tree.push_back({tree[v].sum + tl, tree[v].cnt + 1, 0, 0});
  26. return tree.size() - 1;
  27. }
  28. int tm = tl + (tr - tl) / 2;
  29. int lCh = tree[v].lCh;
  30. int rCh = tree[v].rCh;
  31. if (idx <= tm)
  32. lCh = update(lCh, tl, tm, idx);
  33. else
  34. rCh = update(rCh, tm + 1, tr, idx);
  35. tree.push_back({tree[lCh].sum + tree[rCh].sum, tree[lCh].cnt + tree[rCh].cnt, lCh, rCh});
  36. return tree.size() - 1;
  37. }
  38.  
  39.  
  40. /* find kth smallest number among arr[l], arr[l+1], ..., arr[r]
  41. * k is 1-based, so find_kth(l,r,1) returns the min
  42. */
  43. int query(int l, int r, int k) const {
  44. assert(1 <= k && k <= r - l + 1); //note this condition implies L <= R
  45. assert(0 <= l && r + 1 < (int)roots.size());
  46. return query(roots[l], roots[r + 1], -mx, mx, k);
  47. }
  48. int query(int vl, int vr, int tl, int tr, int k) const {
  49. if (tl == tr)
  50. return tl;
  51. int tm = tl + (tr - tl) / 2;
  52. int left_count = tree[tree[vr].lCh].cnt - tree[tree[vl].lCh].cnt;
  53. if (left_count >= k) return query(tree[vl].lCh, tree[vr].lCh, tl, tm, k);
  54. return query(tree[vl].rCh, tree[vr].rCh, tm + 1, tr, k - left_count);
  55. }
  56.  
  57. /* find **sum** of k smallest numbers among arr[l], arr[l+1], ..., arr[r]
  58. * k is 1-based, so find_kth(l,r,1) returns the min
  59. */
  60. long long query_sum(int l, int r, int k) const {
  61. assert(1 <= k && k <= r - l + 1); //note this condition implies L <= R
  62. assert(0 <= l && r + 1 < (int)roots.size());
  63. return query_sum(roots[l], roots[r + 1], -mx, mx, k);
  64. }
  65. long long query_sum(int vl, int vr, int tl, int tr, int k) const {
  66. if (tl == tr)
  67. return 1LL * tl * k;
  68. int tm = tl + (tr - tl) / 2;
  69. int left_count = tree[tree[vr].lCh].cnt - tree[tree[vl].lCh].cnt;
  70. long long left_sum = tree[tree[vr].lCh].sum - tree[tree[vl].lCh].sum;
  71. if (left_count >= k) return query_sum(tree[vl].lCh, tree[vr].lCh, tl, tm, k);
  72. return left_sum + query_sum(tree[vl].rCh, tree[vr].rCh, tm + 1, tr, k - left_count);
  73. }
  74. };
  75.  
  76.  
  77. //MUCH RANDOM!!!
  78. seed_seq seed{
  79. (uint32_t)chrono::duration_cast<chrono::nanoseconds>(chrono::high_resolution_clock::now().time_since_epoch()).count(),
  80. (uint32_t)random_device()(),
  81. (uint32_t)(uintptr_t)make_unique<char>().get(),
  82. (uint32_t)__builtin_ia32_rdtsc()
  83. };
  84. mt19937 rng(seed);
  85.  
  86. template<class T>
  87. inline T getRand(T l, T r) {
  88. assert(l <= r);
  89. return uniform_int_distribution<T>(l, r)(rng);
  90. }
  91.  
  92. int main() {
  93. while(true) {
  94. int n = getRand(1, 1000);
  95. cout << "start of new test. n = " << n << endl;
  96. vector<int> arr(n);
  97. for(int i = 0; i < n; i++) {
  98. arr[i] = getRand<int>(0, 1e5);
  99. }
  100. sum_kth_smallest st(arr);
  101. for(int queries = 1000; queries--;) {
  102. int L = getRand(0,n-1), R = getRand(0,n-1);
  103. if(L > R) swap(L,R);
  104. vector<int> subarr(R-L+1);
  105. copy(arr.begin()+L, arr.begin()+R+1, subarr.begin());
  106. sort(subarr.begin(), subarr.end());
  107. int numLess = 0;
  108. long long prefixSum = 0;
  109. for(int k = 1; k <= R-L+1; k++) {
  110. prefixSum += subarr[k-1];
  111. assert(st.query(L,R,k) == subarr[k-1]);
  112. assert(st.query_sum(L,R,k) == prefixSum);
  113. }
  114. }
  115. }
  116. return 0;
  117. }
  118.  
Time limit exceeded #stdin #stdout 5s 5444KB
stdin
Standard input is empty
stdout
start of new test. n = 345
start of new test. n = 143
start of new test. n = 310
start of new test. n = 385
start of new test. n = 838
start of new test. n = 972
start of new test. n = 723
start of new test. n = 586
start of new test. n = 845
start of new test. n = 337
start of new test. n = 771
start of new test. n = 30
start of new test. n = 928
start of new test. n = 209
start of new test. n = 100
start of new test. n = 453
start of new test. n = 49
start of new test. n = 550
start of new test. n = 983
start of new test. n = 185
start of new test. n = 908
start of new test. n = 832
start of new test. n = 536
start of new test. n = 196
start of new test. n = 938
start of new test. n = 422
start of new test. n = 815
start of new test. n = 235
start of new test. n = 182
start of new test. n = 675
start of new test. n = 341
start of new test. n = 164
start of new test. n = 525