fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  4. #ifdef LOCAL
  5. #define DEBUG(...) debug(#__VA_ARGS__, __VA_ARGS__)
  6. #else
  7. #define DEBUG(...) 6
  8. #endif
  9.  
  10. template<typename T, typename S> ostream& operator << (ostream &os, const pair<T, S> &p) {return os << "(" << p.first << ", " << p.second << ")";}
  11. template<typename C, typename T = decay<decltype(*begin(declval<C>()))>, typename enable_if<!is_same<C, string>::value>::type* = nullptr>
  12. ostream& operator << (ostream &os, const C &c) {bool f = true; os << "["; for (const auto &x : c) {if (!f) os << ", "; f = false; os << x;} return os << "]";}
  13. template<typename T> void debug(string s, T x) {cerr << "\033[1;35m" << s << "\033[0;32m = \033[33m" << x << "\033[0m\n";}
  14. template<typename T, typename... Args> void debug(string s, T x, Args... args) {for (int i=0, b=0; i<(int)s.size(); i++) if (s[i] == '(' || s[i] == '{') b++; else
  15. if (s[i] == ')' || s[i] == '}') b--; else if (s[i] == ',' && b == 0) {cerr << "\033[1;35m" << s.substr(0, i) << "\033[0;32m = \033[33m" << x << "\033[31m | "; debug(s.substr(s.find_first_not_of(' ', i + 1)), args...); break;}}
  16.  
  17. #include <ext/pb_ds/assoc_container.hpp>
  18. using namespace __gnu_pbds;
  19.  
  20. using ordered_set = tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>;
  21.  
  22. // https://c...content-available-to-author-only...s.com/blog/entry/95575
  23.  
  24. #define len(a) (ll)a.size()
  25. typedef long long ll;
  26. typedef long double ld;
  27.  
  28. bool test=false;
  29. ll mod1=1e9+7;
  30. ll mod2=998244353;
  31. ll inf=1e10+5;
  32.  
  33. enum colour {red, black};
  34.  
  35. struct Node {
  36.  
  37. ll data, size;
  38. bool colour;
  39. Node *parent, *left, *right;
  40. };
  41.  
  42. class RBTree {
  43. public: // added this to gain access to tNil
  44.  
  45. Node *root, *tNil;
  46.  
  47. Node* findNode(Node *z) {
  48.  
  49. Node *x=this->root;
  50.  
  51. while(x!=tNil) {
  52.  
  53. if(x->data==z->data) {
  54.  
  55. return x;
  56. } else if(z->data<x->data) {
  57.  
  58. x=x->left;
  59. } else {
  60.  
  61. x=x->right;
  62. }
  63. }
  64.  
  65. return tNil;
  66. }
  67.  
  68. Node* findMin(Node *z) {
  69.  
  70. while(z->left!=tNil) {
  71.  
  72. z=z->left;
  73. }
  74.  
  75. return z;
  76. }
  77.  
  78. void leftRotate(Node *x) {
  79.  
  80. if(x->right==tNil) {
  81.  
  82. return;
  83. }
  84.  
  85. Node *y=x->right;
  86. x->right=y->left;
  87.  
  88. if(y->left!=tNil) {
  89.  
  90. y->left->parent=x;
  91. }
  92.  
  93. y->parent=x->parent;
  94.  
  95. if(x->parent==tNil) {
  96.  
  97. this->root=y;
  98. } else if(x==x->parent->left) {
  99.  
  100. x->parent->left=y;
  101. } else {
  102.  
  103. x->parent->right=y;
  104. }
  105.  
  106. y->left=x;
  107. x->parent=y;
  108. y->size=x->size;
  109. x->size=x->left->size+x->right->size+1;
  110. }
  111.  
  112. void rightRotate(Node *y) {
  113.  
  114. if(y->left==tNil) {
  115.  
  116. return;
  117. }
  118.  
  119. Node *x=y->left;
  120. y->left=x->right;
  121.  
  122. if(x->right!=tNil) {
  123.  
  124. x->right->parent=y;
  125. }
  126.  
  127. x->parent=y->parent;
  128.  
  129. if(y->parent==tNil) {
  130.  
  131. this->root=x;
  132. } else if(y==y->parent->right) {
  133.  
  134. y->parent->right=x;
  135. } else {
  136.  
  137. y->parent->left=x;
  138. }
  139.  
  140. x->right=y;
  141. y->parent=x;
  142. x->size=y->size;
  143. y->size=y->left->size+y->right->size+1;
  144. }
  145.  
  146. void insertFixUp(Node *z) {
  147.  
  148. while(z->parent->colour==red) {
  149.  
  150. if(z->parent==z->parent->parent->left) {
  151.  
  152. Node *y=z->parent->parent->right;
  153.  
  154. if(y->colour==red) {
  155.  
  156. z->parent->colour=black;
  157. y->colour=black;
  158.  
  159. if(z->parent->parent!=this->root) {
  160.  
  161. z->parent->parent->colour=red;
  162. }
  163. z=z->parent->parent;
  164. } else {
  165.  
  166. if(z==z->parent->right) {
  167.  
  168. z=z->parent;
  169. leftRotate(z);
  170. }
  171.  
  172. z->parent->colour=black;
  173. z->parent->parent->colour=red;
  174. rightRotate(z->parent->parent);
  175. }
  176. } else {
  177.  
  178. Node *y=z->parent->parent->left;
  179.  
  180. if(y->colour==red) {
  181.  
  182. z->parent->colour=black;
  183. y->colour=black;
  184.  
  185. if(z->parent->parent!=this->root) {
  186.  
  187. z->parent->parent->colour=red;
  188. }
  189. z=z->parent->parent;
  190. } else {
  191.  
  192. if(z==z->parent->left) {
  193.  
  194. z=z->parent;
  195. rightRotate(z);
  196. }
  197.  
  198. z->parent->colour=black;
  199. z->parent->parent->colour=red;
  200. leftRotate(z->parent->parent);
  201. }
  202. }
  203. }
  204. }
  205.  
  206. void insertNode(Node *z) {
  207.  
  208. if(this->root==tNil) {
  209.  
  210. z->colour=black;
  211. this->root=z;
  212. this->root->size=1;
  213. return;
  214. }
  215.  
  216. Node *y=tNil, *x=this->root;
  217.  
  218. while(x!=tNil) {
  219.  
  220. y=x;
  221. x->size++;
  222.  
  223. if(z->data<x->data) {
  224.  
  225. x=x->left;
  226. } else {
  227.  
  228. x=x->right;
  229. }
  230. }
  231.  
  232. z->parent=y;
  233.  
  234. if(z->data<y->data) {
  235.  
  236. y->left=z;
  237. } else {
  238.  
  239. y->right=z;
  240. }
  241.  
  242. insertFixUp(z);
  243. }
  244.  
  245. void transplant(Node *u, Node *v) {
  246.  
  247. if(u->parent==tNil) {
  248.  
  249. this->root=v;
  250. } else if(u==u->parent->left) {
  251.  
  252. u->parent->left=v;
  253. } else {
  254.  
  255. u->parent->right=v;
  256. }
  257. v->parent=u->parent;
  258. }
  259.  
  260. void deleteFixUp(Node *x) {
  261.  
  262. while(x!=this->root && x->colour==black) {
  263.  
  264. if(x==x->parent->left) {
  265.  
  266. Node *w=x->parent->right;
  267.  
  268. if(w->colour==red) {
  269.  
  270. w->colour=black;
  271. x->parent->colour=red;
  272. leftRotate(x->parent);
  273.  
  274. w=x->parent->right;
  275. }
  276.  
  277. if(w->left->colour==black && w->right->colour==black) {
  278.  
  279. w->colour=red;
  280. x=x->parent;
  281. } else {
  282.  
  283. if(w->right->colour==black) {
  284.  
  285. w->left->colour=black;
  286. w->colour=red;
  287. rightRotate(w);
  288.  
  289. w=x->parent->right;
  290. }
  291.  
  292. w->colour=x->parent->colour;
  293. x->parent->colour=red;
  294. w->right->colour=black;
  295. leftRotate(x->parent);
  296. x=this->root;
  297. }
  298. } else {
  299.  
  300. Node *w=x->parent->left;
  301.  
  302. if(w->colour==red) {
  303.  
  304. w->colour=black;
  305. x->parent->colour=red;
  306. rightRotate(x->parent);
  307.  
  308. w=x->parent->left;
  309. }
  310.  
  311. if(w->right->colour==black && w->left->colour==black) {
  312.  
  313. w->colour=red;
  314. x=x->parent;
  315. } else {
  316.  
  317. if(w->left->colour==black) {
  318.  
  319. w->right->colour=black;
  320. w->colour=red;
  321. leftRotate(w);
  322.  
  323. w=x->parent->left;
  324. }
  325.  
  326. w->colour=x->parent->colour;
  327. x->parent->colour=red;
  328. w->left->colour=black;
  329. rightRotate(x->parent);
  330. x=this->root;
  331. }
  332. }
  333. }
  334. x->colour=black;
  335. }
  336.  
  337. void deleteNode(Node *z) {
  338.  
  339. Node *y=z, *x;
  340. bool originalColour=y->colour;
  341.  
  342. if(z->left==tNil) {
  343.  
  344. x=z->right;
  345. transplant(z, z->right);
  346. } else if(z->right==tNil) {
  347.  
  348. x=z->left;
  349. transplant(z, z->left);
  350. } else {
  351.  
  352. y=findMin(z->right);
  353. originalColour=y->colour;
  354.  
  355. x=y->right;
  356.  
  357. if(y->parent==z) {
  358.  
  359. x->parent=y;
  360. } else {
  361.  
  362. transplant(y, y->right);
  363. y->right=z->right;
  364. y->right->parent=y;
  365.  
  366. Node *s=x->parent;
  367.  
  368. while(s!=tNil && s!=y) {
  369.  
  370. s->size--;
  371. s=s->parent;
  372. }
  373. }
  374.  
  375. transplant(z, y);
  376.  
  377. y->left=z->left;
  378. y->left->parent=y;
  379. y->colour=z->colour;
  380.  
  381. y->size=y->left->size+y->right->size+1;
  382. }
  383.  
  384. if(originalColour==black) {
  385.  
  386. deleteFixUp(x);
  387. }
  388. }
  389.  
  390. void inOrderHelper(Node *node) {
  391.  
  392. if(node==tNil) {
  393.  
  394. return;
  395. }
  396.  
  397. inOrderHelper(node->left);
  398.  
  399. std::cout<<node->data<<" ";
  400.  
  401. inOrderHelper(node->right);
  402. }
  403.  
  404. public:
  405.  
  406. RBTree() {
  407.  
  408. tNil=new Node();
  409. tNil->colour=black;
  410. tNil->size=0;
  411.  
  412. tNil->left=tNil;
  413. tNil->right=tNil;
  414. tNil->parent=tNil;
  415.  
  416. this->root=tNil;
  417. }
  418.  
  419. Node* getRoot() {
  420.  
  421. return this->root;
  422. }
  423.  
  424. Node* find(ll key) {
  425.  
  426. Node *z=new Node();
  427.  
  428. z->data=key;
  429.  
  430. return findNode(z);
  431. }
  432.  
  433. void insert(ll key) {
  434.  
  435. Node *z=new Node();
  436. z->data=key;
  437. z->colour=red;
  438. z->size=1;
  439.  
  440. z->left=tNil;
  441. z->right=tNil;
  442. z->parent=tNil;
  443.  
  444. insertNode(z);
  445. }
  446.  
  447. void erase(ll key) {
  448.  
  449. Node *z=find(key);
  450.  
  451. if(z==tNil) {
  452.  
  453. return;
  454. }
  455.  
  456. Node *s=z->parent;
  457.  
  458. while(s!=tNil) {
  459.  
  460. s->size--;
  461. s=s->parent;
  462. }
  463.  
  464. deleteNode(z);
  465. }
  466.  
  467. ll osSelect(Node *x, ll i) {
  468.  
  469. ll r=x->left->size+1;
  470.  
  471. if(i==r) {
  472.  
  473. return x->data;
  474. } else if(i<r) {
  475.  
  476. return osSelect(x->left, i);
  477. } else {
  478.  
  479. return osSelect(x->right, i-r);
  480. }
  481. }
  482.  
  483. ll osRank(Node *x) {
  484.  
  485. ll r=x->left->size+1;
  486.  
  487. Node *y=x;
  488.  
  489. while(y!=this->root) {
  490.  
  491. if(y==y->parent->right) {
  492.  
  493. r+=y->parent->left->size+1;
  494. }
  495. y=y->parent;
  496. }
  497.  
  498. return r;
  499. }
  500.  
  501. void inOrder() {
  502.  
  503. inOrderHelper(this->root);
  504. }
  505. };
  506.  
  507. const int N = 1e6;
  508. mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
  509. vector<int> input;
  510. clock_t start;
  511. ordered_set st;
  512. RBTree rbt;
  513.  
  514. void testInsert() {
  515. uniform_int_distribution<int> uid(1, 1e9);
  516. set<int> uniq;
  517. while (uniq.size() < N)
  518. uniq.insert(uid(rng));
  519. input.insert(input.end(), uniq.begin(), uniq.end());
  520. shuffle(input.begin(), input.end(), rng);
  521.  
  522. cout << "INSERT" << endl;
  523.  
  524. start = clock();
  525. for (int x : input)
  526. st.insert(x);
  527. cout << "PBDS: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  528.  
  529. start = clock();
  530. for (int x : input)
  531. rbt.insert(x);
  532. cout << "Red Black: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  533. }
  534.  
  535. void testFind() {
  536. uniform_int_distribution<int> uid(1, 1e9);
  537. vector<int> vals(input.begin(), input.end());
  538. vector<bool> ans1, ans2;
  539. for (int i=0; i<N; i+=2)
  540. vals[i] = uid(rng);
  541. shuffle(vals.begin(), vals.end(), rng);
  542. ans1.reserve(N);
  543. ans2.reserve(N);
  544.  
  545. cout << "FIND" << endl;
  546.  
  547. start = clock();
  548. for (int x : vals)
  549. ans1.push_back(st.find(x) != st.end());
  550. cout << "PBDS: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  551.  
  552. start = clock();
  553. for (int x : vals)
  554. ans2.push_back(rbt.find(x) != rbt.tNil);
  555. cout << "Red Black: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  556.  
  557. assert(ans1 == ans2);
  558. }
  559.  
  560. void testRank() {
  561. vector<int> vals(input.begin(), input.end()), ans1, ans2;
  562. shuffle(vals.begin(), vals.end(), rng);
  563. ans1.reserve(N);
  564. ans2.reserve(N);
  565.  
  566. cout << "RANK" << endl;
  567.  
  568. start = clock();
  569. for (int x : vals)
  570. ans1.push_back(st.order_of_key(x));
  571. cout << "PBDS: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  572.  
  573. start = clock();
  574. for (int x : vals)
  575. ans2.push_back(rbt.osRank(rbt.find(x)));
  576. cout << "Red Black: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  577.  
  578. // order_of_key returns strictly smaller instead of less than or equal, so I decrement 1 from ans2 to compensate
  579. for (int &x : ans2)
  580. x--;
  581. assert(ans1 == ans2);
  582. }
  583.  
  584. void testKth() {
  585. uniform_int_distribution<int> uid(0, (int) st.size() - 1);
  586. vector<int> vals(N), vals2, ans1, ans2;
  587. for (int &x : vals)
  588. x = uid(rng);
  589. vals2 = vals;
  590. for (int &x : vals2)
  591. x++; // increment 1 cause RBTree is one-indexed
  592. ans1.reserve(N);
  593. ans2.reserve(N);
  594.  
  595. cout << "KTH" << endl;
  596.  
  597. start = clock();
  598. for (int x : vals)
  599. ans1.push_back(*st.find_by_order(x));
  600. cout << "PBDS: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  601.  
  602. start = clock();
  603. for (int x : vals2)
  604. ans2.push_back(rbt.osSelect(rbt.getRoot(), x));
  605. cout << "Red Black: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  606.  
  607. assert(ans1 == ans2);
  608. }
  609.  
  610. void testErase() {
  611. cout << "ERASE" << endl;
  612.  
  613. start = clock();
  614. for (int x : input)
  615. st.erase(x);
  616. cout << "PBDS: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  617.  
  618. start = clock();
  619. for (int x : input)
  620. rbt.erase(x);
  621. cout << "Red Black: " << (double) (clock() - start) / CLOCKS_PER_SEC << endl;
  622. }
  623.  
  624. int main() {
  625. ios_base::sync_with_stdio(false);
  626. cin.tie(NULL);
  627.  
  628. testInsert();
  629. testFind();
  630. testRank();
  631. testKth();
  632. testErase();
  633.  
  634. return 0;
  635. }
  636.  
Time limit exceeded #stdin #stdout 5s 239392KB
stdin
Standard input is empty
stdout
INSERT
PBDS: 0.711953
Red Black: 0.577574
FIND
PBDS: 0.667535
Red Black: 0.621313
RANK
PBDS: 0.627829