fork(2) download
  1. // Adapted from https://w...content-available-to-author-only...s.org/wavelet-trees-introduction
  2.  
  3. #include <iostream>
  4. #include <vector>
  5. #include <map>
  6. #include <algorithm>
  7. #include <climits>
  8. using namespace std;
  9.  
  10. // wavelet tree class
  /// Wavelet tree over an integer array whose values lie in [x, y].
  ///
  /// Each node covers an inclusive value range [low, high]. An internal node
  /// splits that range at `mid`: values <= mid go to the left child, the rest to
  /// the right child. Queries use 1-based, inclusive positions.
  ///
  /// The node owns its children (allocated with `new`) and frees them in the
  /// destructor, so copying is disabled to avoid double deletion.
  class wavelet_tree {
  public:
      // Inclusive range of values handled by this node.
      int low, high;

      // Left and right children. Both stay nullptr for leaves and empty nodes.
      wavelet_tree* l = nullptr;
      wavelet_tree* r = nullptr;

      // Prefix counts: freq[i] is how many of the first i elements seen by this
      // node were routed to the left child (for a leaf, simply i).
      std::vector<int> freq;

      /// Builds the (sub)tree for the elements in [from, to).
      ///
      /// @param from  first element of the range; the range is REORDERED in
      ///              place by std::stable_partition while building.
      /// @param to    one past the last element of the range.
      /// @param x     smallest value this node is responsible for.
      /// @param y     largest value this node is responsible for.
      wavelet_tree(int* from, int* to, int x, int y) : low(x), high(y)
      {
          // Nothing to index: empty input, or an inverted value range that
          // could otherwise never bottom out.
          if (from >= to || low > high)
              return;

          freq.reserve(to - from + 1);
          freq.push_back(0);

          // Leaf: every element carries the same value, so the prefix count
          // just grows by one per element and no children are needed.
          if (high == low) {
              for (auto it = from; it != to; ++it)
                  freq.push_back(freq.back() + 1);
              return;
          }

          // Midpoint rounded towards `low`. (low + high) / 2 truncates towards
          // zero, which for ranges such as [-1, 0] yields mid == high and makes
          // the left child identical to its parent (unbounded recursion).
          const int mid = low + (high - low) / 2;

          const auto goesLeft = [mid](int value) { return value <= mid; };

          // Record, for every prefix, how many elements go to the left child.
          for (auto it = from; it != to; ++it)
              freq.push_back(freq.back() + (goesLeft(*it) ? 1 : 0));

          // Stable partition keeps the relative order inside each half, which
          // is what lets the children be indexed through `freq`.
          int* const pivot = std::stable_partition(from, to, goesLeft);

          l = new wavelet_tree(from, pivot, low, mid);
          try {
              r = new wavelet_tree(pivot, to, mid + 1, high);
          } catch (...) {
              // The destructor does not run for a partially built object.
              delete l;
              throw;
          }
      }

      // Releases both subtrees (deleting nullptr is a no-op).
      ~wavelet_tree()
      {
          delete l;
          delete r;
      }

      // The node owns raw child pointers; a shallow copy would double-free.
      wavelet_tree(const wavelet_tree&) = delete;
      wavelet_tree& operator=(const wavelet_tree&) = delete;

      /// Counts the elements at positions [l..r] whose value is <= k.
      ///
      /// @param l  first position, 1-based, relative to this node's sequence.
      /// @param r  last position, 1-based, inclusive.
      /// @param k  upper bound on the value (inclusive).
      /// @return   number of matching elements; 0 for an empty range.
      /// Positions must satisfy 1 <= l and r <= number of elements in the node.
      int kOrLess(int l, int r, int k) const
      {
          // Empty range, or every value in this node is greater than k.
          if (l > r || k < low)
              return 0;

          // Every value in this node is <= k.
          if (high <= k)
              return r - l + 1;

          // Node built from an empty range: nothing to count.
          if (this->l == nullptr || this->r == nullptr)
              return 0;

          // Elements routed left before position l, and up to position r.
          const int LtCount = freq[l - 1];
          const int RtCount = freq[r];

          // Map [l..r] onto each child's own positions and add both answers.
          return this->l->kOrLess(LtCount + 1, RtCount, k) +
                 this->r->kOrLess(l - LtCount, r - RtCount, k);
      }

      /// Counts the elements at positions [l..r] whose value is >= k.
      ///
      /// @param l  first position, 1-based, relative to this node's sequence.
      /// @param r  last position, 1-based, inclusive.
      /// @param k  lower bound on the value (inclusive).
      /// @return   number of matching elements; 0 for an empty range.
      /// Positions must satisfy 1 <= l and r <= number of elements in the node.
      int kOrMore(int l, int r, int k) const
      {
          // Empty range, or every value in this node is smaller than k.
          if (l > r || k > high)
              return 0;

          // Every value in this node is >= k.
          if (low >= k)
              return r - l + 1;

          // Node built from an empty range: nothing to count.
          if (this->l == nullptr || this->r == nullptr)
              return 0;

          // Elements routed left before position l, and up to position r.
          const int LtCount = freq[l - 1];
          const int RtCount = freq[r];

          // Map [l..r] onto each child's own positions and add both answers.
          return this->l->kOrMore(LtCount + 1, RtCount, k) +
                 this->r->kOrMore(l - LtCount, r - RtCount, k);
      }
  };
  130.  
  131. // Driver code
  132. int main()
  133. {
  134. int size = 7, high = INT_MIN;
  135. // 1 2 3 4 5 6 7
  136. int arr[] = {1, 2, 3, 2, 4, 3, 1};
  137. int next[size];
  138. std::map<int, int> next_idx;
  139.  
  140. for (int i=size-1; i>=0; i--){
  141. if (next_idx.find(arr[i]) == next_idx.end())
  142. next[i] = size + 1;
  143. else
  144. next[i] = next_idx[arr[i]];
  145. next_idx[arr[i]] = i + 1;
  146. high = max(high, next[i]);
  147. }
  148.  
  149. // Object of class wavelet tree
  150. wavelet_tree obj(next, next + size, 1, high);
  151.  
  152. // Queries are NON-zero-based
  153. //
  154. // 1 2 3 4 5 6 7
  155. // {1, 2, 3, 2, 4, 3, 1};
  156. // query([3, 6]) = 3;
  157. cout << obj.kOrMore(3, 6, 7) << '\n';
  158. // query([1, 4]) = 3;
  159. cout << obj.kOrMore(1, 4, 5) << '\n';
  160. // query([1, 7]) = 4;
  161. cout << obj.kOrMore(1, 7, 8) << '\n';
  162.  
  163. return 0;
  164. }
  165.  
Success #stdin #stdout 0s 4332KB
stdin
Standard input is empty
stdout
3
3
4