fork(2) download
  #include <assert.h>

  #include <algorithm>
  #include <climits>
  #include <cmath>
  #include <iostream>
  #include <set>
  #include <string>
  #include <vector>

  using namespace std;
  10.  
  using ll = long long;

  /**
   * A splay-tree node.
   *   key    - the stored value
   *   sum    - sum of all keys in the subtree rooted here (kept current by ST::update)
   *   left / right / parent - tree links, nullptr when absent
   */
  struct node {
      int key;
      ll sum = 0;

      node *left = nullptr;
      node *right = nullptr;
      node *parent = nullptr;

      node(int key, ll sum, node *left, node *right, node *parent)
          : key(key), sum(sum), left(left), right(right), parent(parent) {}
  };

  /**
   * Splay tree over distinct int keys supporting insert, erase, membership
   * tests and the sum of all keys in a closed range [l, r].
   *
   * Ownership: nodes are allocated by insert() and freed by erase(). The
   * destructor frees nothing, because ST objects are also constructed as
   * temporary views over subtrees inside merge() and sum(); freeing nodes
   * there would destroy the caller's tree. Nodes still in a tree when it is
   * destroyed are therefore leaked.
   */
  class ST {
  private:
      /* ---- node related helpers ---- */

      /** Sum of the subtree rooted at n; 0 for an empty subtree. */
      ll get_sum(node *n) {
          return n == nullptr ? 0ll : n->sum;
      }

      /** Points n's parent link at p; no-op when n is null. */
      void set_parent(node *n, node *p) {
          if (n != nullptr) n->parent = p;
      }

      /** Recomputes n->sum from its children and makes both children point back to n. */
      void update(node *n) {
          if (n == nullptr) return;
          n->sum = n->key + get_sum(n->left) + get_sum(n->right);
          set_parent(n->left, n);
          set_parent(n->right, n);
      }

      bool has_parent(node *n) {
          return n != nullptr && n->parent != nullptr;
      }

      bool has_grand_parent(node *n) {
          return has_parent(n) && n->parent->parent != nullptr;
      }

      bool is_left_child(node *n) {
          return has_parent(n) && n->parent->left == n;
      }

      bool is_right_child(node *n) {
          return has_parent(n) && n->parent->right == n;
      }

      /* ---- splay tree helpers ---- */

      node *root = nullptr;

      /** Rotates n one level up over its parent, fixing sums and links. */
      void small_rotation(node *n) {
          if (!has_parent(n)) return;
          node *parent = n->parent;
          node *grand_parent = parent->parent;
          if (is_left_child(n)) {
              // n's right subtree moves across to become parent's left subtree.
              node *nr = n->right;
              n->right = parent;
              parent->left = nr;
          } else {
              node *nl = n->left;
              n->left = parent;
              parent->right = nl;
          }
          update(parent);  // parent is now n's child: recompute it first
          update(n);       // also sets parent->parent = n
          n->parent = grand_parent;
          if (grand_parent != nullptr) {
              if (grand_parent->left == parent) grand_parent->left = n;
              else grand_parent->right = n;
          }
      }

      /**
       * Moves n two levels up: zig-zig (rotate the parent first) when n and its
       * parent are children on the same side, zig-zag (rotate n twice) otherwise.
       */
      void big_rotation(node *n) {
          if (!has_grand_parent(n)) return;
          const bool zig_zig = (is_left_child(n) && is_left_child(n->parent)) ||
                               (is_right_child(n) && is_right_child(n->parent));
          if (zig_zig) {
              small_rotation(n->parent);
              small_rotation(n);
          } else {
              small_rotation(n);
              small_rotation(n);
          }
      }

      /** Rotates n up until it has no parent and records it as the root; no-op for null. */
      void splay(node *n) {
          if (n == nullptr) return;
          while (has_parent(n)) {
              if (has_grand_parent(n)) {
                  big_rotation(n);
              } else {
                  small_rotation(n);
                  break;
              }
          }
          root = n;
      }

  public:
      ST() {}
      /** Wraps an existing tree rooted at root; the object does not take ownership. */
      ST(node *root) : root(root) {}
      ~ST() {}

      /**
       * Returns the node with the smallest key >= key, or nullptr when every key
       * is smaller. The deepest node visited by the search is splayed to the root.
       * The returned pointer becomes invalid once that key is erased.
       */
      node *find(int key) {
          node *v = root;
          node *last = root;
          node *next = nullptr;
          while (v != nullptr) {
              if (v->key >= key && (next == nullptr || v->key < next->key)) {
                  next = v;
              }
              last = v;
              if (v->key == key) break;
              v = (v->key > key) ? v->left : v->right;
          }
          splay(last);
          return next;
      }

      /** True when key is stored in the tree. */
      bool exists(int key) {
          node *c = find(key);
          return c != nullptr && c->key == key;
      }

      /**
       * Splits the tree into left (keys < key) and right (keys >= key, so right
       * holds key if it exists). Either output may be nullptr. Afterwards this
       * object's root still refers to one of the two halves; callers reassign
       * it (insert() and sum() do so via merge()).
       */
      void split(int key, node *&left, node *&right) {
          if (root == nullptr) {
              left = nullptr;
              right = nullptr;
              return;
          }
          right = find(key);
          splay(right);  // make right the root of the tree
          if (right == nullptr) {
              // every key is smaller than `key`
              left = root;
              return;
          }
          left = right->left;
          right->left = nullptr;
          set_parent(left, nullptr);
          update(left);
          update(right);
      }

      /**
       * Joins two trees into one and returns the new root. Every key in left must
       * be smaller than every key in right: left is attached as the left child of
       * right's minimum. Either argument may be nullptr.
       */
      static node *merge(node *left, node *right) {
          if (left == nullptr) return right;
          if (right == nullptr) return left;

          ST rtree(right);
          node *minr = right;
          while (minr->left != nullptr) {
              minr = minr->left;
          }
          rtree.splay(minr);  // minr becomes the root and has no left child
          minr->left = left;
          rtree.update(minr);
          return minr;
      }

      /** Inserts key; does nothing if it is already present. */
      void insert(int key) {
          node *left = nullptr;
          node *right = nullptr;
          split(key, left, right);
          node *nn = nullptr;
          if (right == nullptr || right->key != key) {
              nn = new node(key, (ll)key, nullptr, nullptr, nullptr);
          }
          root = ST::merge(ST::merge(left, nn), right);
      }

      /** Removes key from the tree (freeing its node) if it is present. */
      void erase(int key) {
          node *n = find(key);
          if (n == nullptr || n->key != key) return;

          // No key can exceed INT_MAX, and computing key + 1 there would overflow.
          node *next = (key == INT_MAX) ? nullptr : find(key + 1);
          if (next != nullptr) {
              // n is next's predecessor: splaying next and then n leaves n at the
              // root with next as its right child and next->left empty.
              splay(next);
              splay(n);
              node *nl = n->left;
              next->left = nl;
              set_parent(next, nullptr);
              set_parent(nl, next);
              root = next;
              update(root);
          } else {
              // n is the maximum, so once splayed it has no right child.
              splay(n);
              node *nl = n->left;
              root = nl;
              set_parent(nl, nullptr);
              splay(root);
          }
          delete n;  // n is no longer reachable from the tree
      }

      /** Sum of the keys in the closed range [l, r]; 0 when l > r. */
      ll sum(int l, int r) {
          if (l > r) return 0;
          node *left = nullptr;
          node *middle = nullptr;
          node *right = nullptr;
          split(l, left, middle);  // left: keys < l, middle: keys >= l
          if (r == INT_MAX) {
              // r + 1 would overflow; no key can be larger than INT_MAX anyway.
              right = nullptr;
          } else {
              ST mt(middle);
              mt.split(r + 1, middle, right);  // middle: keys in [l, r], right: keys > r
          }
          const ll ans = (middle != nullptr) ? middle->sum : 0;
          root = ST::merge(ST::merge(left, middle), right);
          return ans;
      }

      /** Prints all keys to std::cout in ascending order, space separated, then a newline. */
      void print() {
          inorder(root);
          std::cout << '\n';
      }

      /**
       * Writes the keys of the subtree rooted at cur to std::cout in ascending
       * order, each followed by a space. Iterative, since splay trees can
       * degenerate into long chains that would overflow a recursive walk.
       */
      void inorder(node *cur) {
          std::vector<node *> pending;
          while (cur != nullptr || !pending.empty()) {
              while (cur != nullptr) {
                  pending.push_back(cur);
                  cur = cur->left;
              }
              cur = pending.back();
              pending.pop_back();
              std::cout << cur->key << ' ';
              cur = cur->right;
          }
      }
  };
  241.  
  242. const int MODULO = 1000000001;
  243.  
  244. int main(){
  245. ios::sync_with_stdio(false);
  246. ST t;
  247. t.insert(0);
  248. t.insert(5);
  249. t.insert(10);
  250. t.insert(15);
  251. t.insert(20);
  252. cout << t.sum(0, 14) << endl; // 0 + 5 + 10
  253. cout << t.sum(5, 20) << endl; // 5 + 10 + 15 + 20
  254. return 0;
  255. }
Success #stdin #stdout 0s 3468KB
stdin
Standard input is empty
stdout
15
50