fork(1) download
  1. #include <algorithm>
  2. #include <iomanip>
  3. #include <iostream>
  4. #include <memory>
  5. #include <vector>
  6. using namespace std;
  7.  
  8. class Counter {
  9. public:
  10. Counter() { ++count_; }
  11. ~Counter() { --count_; }
  12. static int count() { return count_; }
  13. private:
  14. static int count_;
  15. };
  16.  
  17. int Counter::count_ = 0;
  18.  
  19. /// Modified code starts here
  20.  
  21. class MyGraph {
  22. public:
  23. ~MyGraph() {
  24. for(auto n: nodes)
  25. n->children.clear();
  26. }
  27.  
  28. class Node : public Counter, public enable_shared_from_this<Node> {
  29. vector<shared_ptr<Node>> children;
  30. MyGraph* owner;
  31. bool visited;
  32. friend class MyGraph;
  33.  
  34. void setOwner(MyGraph* graph) {
  35. if (owner && owner != graph)
  36. throw "Can't share nodes between graphs";
  37. if (owner)
  38. return;
  39.  
  40. owner = graph;
  41. graph->registerNode(shared_from_this());
  42. }
  43.  
  44. public:
  45. Node() : visited(false), owner(nullptr) {}
  46. void AddChild(const shared_ptr<Node>& node) {
  47. children.push_back(node);
  48. node->setOwner(owner);
  49. }
  50.  
  51.  
  52. void RemoveChild(const shared_ptr<Node>& node) {
  53. children.erase(std::remove(children.begin(), children.end(), node), children.end());
  54. }
  55. };
  56.  
  57. void SetRoot(const shared_ptr<Node>& node) {
  58. root = node;
  59. node->setOwner(this);
  60. }
  61.  
  62. void registerNode(const shared_ptr<Node>& node) {
  63. nodes.push_back(node);
  64. }
  65.  
  66. void ShrinkToFit() {
  67. for(auto n: nodes)
  68. n->visited = false;
  69.  
  70. visit(root);
  71. nodes.erase(std::remove_if(nodes.begin(), nodes.end(), [](shared_ptr<Node>& n) {
  72. if (n->visited)
  73. return false;
  74. n->children.clear();
  75. return true;
  76. }), nodes.end());
  77. }
  78.  
  79. void visit(const shared_ptr<Node>& n) {
  80. if (n->visited)
  81. return;
  82.  
  83. n->visited = true;
  84. for(auto c: n->children)
  85. visit(c);
  86. }
  87.  
  88. static shared_ptr<MyGraph::Node> MakeNode() {
  89. return make_shared<MyGraph::Node>();
  90. }
  91.  
  92. private:
  93. shared_ptr<Node> root;
  94. vector<shared_ptr<Node>> nodes;
  95. };
  96.  
  97. /// End modified code
  98. bool TestCase1() {
  99. MyGraph g;
  100. {
  101. auto a = MyGraph::MakeNode();
  102. g.SetRoot(a);
  103. auto b = MyGraph::MakeNode();
  104. a->AddChild(b);
  105. auto c = MyGraph::MakeNode();
  106. b->AddChild(c);
  107. a->RemoveChild(b);
  108. }
  109. g.ShrinkToFit();
  110. return Counter::count() == 1;
  111. }
  112.  
  113. bool TestCase2() {
  114. MyGraph g;
  115. {
  116. auto a = MyGraph::MakeNode();
  117. g.SetRoot(a);
  118. auto b = MyGraph::MakeNode();
  119. a->AddChild(b);
  120. auto c = MyGraph::MakeNode();
  121. b->AddChild(c);
  122. auto d = MyGraph::MakeNode();
  123. b->AddChild(d);
  124. d->AddChild(b);
  125. a->RemoveChild(b);
  126. }
  127. g.ShrinkToFit();
  128. return Counter::count() == 1;
  129. }
  130.  
  131. bool TestCase3() {
  132. MyGraph g;
  133. {
  134. auto a = MyGraph::MakeNode();
  135. g.SetRoot(a);
  136. auto b = MyGraph::MakeNode();
  137. a->AddChild(b);
  138. auto c = MyGraph::MakeNode();
  139. b->AddChild(c);
  140. auto d = MyGraph::MakeNode();
  141. b->AddChild(d);
  142. d->AddChild(b);
  143. }
  144. g.ShrinkToFit();
  145. return Counter::count() == 4;
  146. }
  147.  
  148.  
  149. bool TestCase4() { // New test case
  150. MyGraph g;
  151. {
  152. auto a = MyGraph::MakeNode();
  153. g.SetRoot(a);
  154. auto b = MyGraph::MakeNode();
  155. a->AddChild(b);
  156. auto c = MyGraph::MakeNode();
  157. b->AddChild(c);
  158. auto d = MyGraph::MakeNode();
  159. b->AddChild(d);
  160. d->AddChild(b);
  161. d->RemoveChild(b);
  162. }
  163. g.ShrinkToFit();
  164. return Counter::count() == 4;
  165. }
  166.  
  167. int main() {
  168. cout.setf(ios::boolalpha);
  169. bool passed1 = TestCase1();
  170. cout << passed1 << endl;
  171. bool passed2 = TestCase2();
  172. cout << passed2 << endl;
  173. bool passed3 = TestCase3();
  174. cout << passed3 << endl;
  175.  
  176. bool passed4 = TestCase4();
  177. cout << passed4 << endl;
  178. return 0;
  179. }
Success #stdin #stdout 0s 3500KB
stdin
Standard input is empty
stdout
true
true
true
true