#include <algorithm>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <memory>
#include <vector>
using namespace std;
/// Tracks how many Counter objects (including Counter base subobjects of
/// derived classes) are currently alive: +1 on every construction, including
/// copies, and -1 on every destruction.
/// The tally is a plain static int with no synchronization.
class Counter {
public:
    Counter() { ++count_; }
    // A copy is a new live object that will be destroyed on its own, so it
    // must be counted too. The implicit copy constructor would skip this
    // increment while the destructor still decrements, skewing the tally.
    Counter(const Counter&) { ++count_; }
    // Assignment creates no new object, so the tally is unchanged.
    Counter& operator=(const Counter&) = default;
    ~Counter() { --count_; }
    /// Number of Counter objects currently alive.
    static int count() { return count_; }
private:
    static int count_;
};
int Counter::count_ = 0;
/// Modified code starts here
//
/// A graph of reference-counted nodes.
///
/// A node owns the children it adopts first through `shared_ptr`
/// (`Node::children`) and only observes further children through `weak_ptr`
/// (`Node::others`). AddChild marks an adopted node with `hasParent`, so a
/// node sits in at most one `children` list at a time. Nodes detached with
/// Node::RemoveChild are queued in a garbage list and released by
/// ShrinkToFit (also run by the destructor), which dismantles the detached
/// subtrees iteratively instead of through nested destructors.
///
/// NOTE(review): nodes keep a raw back-pointer to their graph. Copying a
/// MyGraph (copy operations are implicitly available) or letting nodes
/// outlive it leaves that pointer referring to the wrong or a destroyed
/// graph. Consider deleting the copy operations.
class MyGraph {
public:
    /// A graph node; every node is counted through Counter.
    class Node : public Counter {
        std::vector<std::shared_ptr<Node>> children;  // owning edges
        std::vector<std::weak_ptr<Node>> others;       // non-owning edges (target already had a parent)
        MyGraph* graph = nullptr;                      // owning graph; assigned at most once
        bool hasParent = false;                        // true while some node holds this one in `children`
        friend class MyGraph;

        // Attaches this node to `g`, together with every descendant reachable
        // through owning edges that is not attached yet. Nodes that already
        // have a graph keep it, and their subtrees are skipped. An explicit
        // stack is used so a deep subtree cannot overflow the call stack.
        void setGraph(MyGraph* g) {
            if (graph != nullptr || g == nullptr)
                return;
            std::vector<Node*> pending{this};
            while (!pending.empty()) {
                Node* current = pending.back();
                pending.pop_back();
                if (current->graph != nullptr)
                    continue;
                current->graph = g;
                for (const auto& child : current->children)
                    pending.push_back(child.get());
            }
        }

    public:
        Node() = default;

        /// Adds an edge from this node to `node`.
        ///
        /// If `node` has no parent yet, it becomes an owning child.
        /// Otherwise only a weak reference is stored. A weak reference is
        /// also used for an edge to this node itself, because an owning
        /// self-edge would keep the node alive forever.
        /// A null `node` is ignored. The child inherits this node's graph
        /// if it has none yet.
        void AddChild(const std::shared_ptr<Node>& node) {
            if (!node)
                return;
            if (node->hasParent || node.get() == this) {
                others.push_back(node);
            } else {
                node->hasParent = true;
                children.push_back(node);
            }
            node->setGraph(graph);
        }

        /// Removes the edge from this node to `node`.
        ///
        /// If `node` is an owning child, it is detached and marked
        /// parentless. It is then queued on the graph's garbage list for
        /// ShrinkToFit; if this node has no graph, it is only dropped from
        /// this node. Otherwise every weak reference to `node` is removed,
        /// and weak references whose target has already expired are pruned
        /// too.
        void RemoveChild(const std::shared_ptr<Node>& node) {
            auto it = std::find(children.begin(), children.end(), node);
            if (it == children.end()) {
                // Two-argument erase (erase-remove idiom): the single-iterator
                // overload would erase at most one element, and is undefined
                // behaviour when remove_if returns end().
                others.erase(std::remove_if(others.begin(), others.end(),
                                            [&node](const std::weak_ptr<Node>& w) {
                                                auto target = w.lock();
                                                return !target || target == node;
                                            }),
                             others.end());
                return;
            }
            children.erase(it);
            node->hasParent = false;
            if (graph != nullptr)
                graph->Remove(node);
        }
    };

    /// Makes `node` the graph's root and attaches it to this graph, along
    /// with any owned descendants not attached yet. The graph keeps the
    /// root alive. Passing nullptr clears the root.
    /// NOTE(review): the root is not marked `hasParent`, so AddChild can
    /// still adopt it as an owning child. If the adopter is one of the
    /// root's own descendants, the resulting shared_ptr cycle is never freed.
    void SetRoot(const std::shared_ptr<Node>& node) {
        root = node;
        if (node)
            node->setGraph(this);
    }

    /// Queues `node` for release by the next ShrinkToFit.
    void Remove(const std::shared_ptr<Node>& node) {
        garbage.push_back(node);
    }

    /// Releases every queued node together with its owned subtree.
    ///
    /// The queue is walked breadth-first. Each node's owning children are
    /// appended to the queue and its own child list is emptied, so clearing
    /// the queue never triggers a chain of nested destructors, which could
    /// overflow the stack on a deep subtree. Queued nodes that are still
    /// referenced elsewhere survive, but lose their owning children.
    void ShrinkToFit() {
        for (std::size_t i = 0; i < garbage.size(); ++i) {
            // Move the child list out (leaving it empty): the queue, not the
            // node, now owns those children.
            std::vector<std::shared_ptr<Node>> detached;
            detached.swap(garbage[i]->children);
            garbage.insert(garbage.end(), detached.begin(), detached.end());
        }
        garbage.clear();
    }

    /// Releases any queued garbage.
    /// NOTE(review): the tree still reachable from `root` is destroyed
    /// afterwards by `root`'s own destructor, through nested Node
    /// destructors. A very deep live tree is therefore not protected from
    /// stack overflow the way queued garbage is.
    ~MyGraph() {
        ShrinkToFit();
    }

    /// Creates a new node that is not attached to any graph.
    static std::shared_ptr<MyGraph::Node> MakeNode() {
        return std::make_shared<MyGraph::Node>();
    }

private:
    std::shared_ptr<Node> root;                  // keeps the root and its owned subtree alive
    std::vector<std::shared_ptr<Node>> garbage;  // detached nodes awaiting ShrinkToFit
};
/// End modified code
/// Detaching b (which owns c) from the root must reclaim both b and c,
/// leaving only the root `a`, which the graph keeps alive.
bool TestCase1() {
    // Counter::count() is one static tally shared by every test. Measure
    // against the count on entry, so a leak in an earlier test cannot also
    // make this one fail.
    const int baseline = Counter::count();
    MyGraph g;
    {
        auto a = MyGraph::MakeNode();
        g.SetRoot(a);
        auto b = MyGraph::MakeNode();
        a->AddChild(b);
        auto c = MyGraph::MakeNode();
        b->AddChild(c);
        a->RemoveChild(b);
    }
    g.ShrinkToFit();
    return Counter::count() - baseline == 1;
}
/// Detaching b, whose subtree contains a weak back-edge (d -> b), must
/// still reclaim b, c and d; only the root `a` survives.
bool TestCase2() {
    // Counter::count() is one static tally shared by every test. Measure
    // against the count on entry, so a leak in an earlier test cannot also
    // make this one fail.
    const int baseline = Counter::count();
    MyGraph g;
    {
        auto a = MyGraph::MakeNode();
        g.SetRoot(a);
        auto b = MyGraph::MakeNode();
        a->AddChild(b);
        auto c = MyGraph::MakeNode();
        b->AddChild(c);
        auto d = MyGraph::MakeNode();
        b->AddChild(d);
        d->AddChild(b);
        a->RemoveChild(b);
    }
    g.ShrinkToFit();
    return Counter::count() - baseline == 1;
}
/// With no removals, all four nodes stay alive through the root held by
/// the graph; the weak back-edge d -> b must not create an owning cycle.
bool TestCase3() {
    // Counter::count() is one static tally shared by every test. Measure
    // against the count on entry, so a leak in an earlier test cannot also
    // make this one fail.
    const int baseline = Counter::count();
    MyGraph g;
    {
        auto a = MyGraph::MakeNode();
        g.SetRoot(a);
        auto b = MyGraph::MakeNode();
        a->AddChild(b);
        auto c = MyGraph::MakeNode();
        b->AddChild(c);
        auto d = MyGraph::MakeNode();
        b->AddChild(d);
        d->AddChild(b);
    }
    g.ShrinkToFit();
    return Counter::count() - baseline == 4;
}
/// Removing only a weak edge (d -> b) must not release b, which is still
/// owned by a, so all four nodes stay alive.
bool TestCase4() {
    // Counter::count() is one static tally shared by every test. Measure
    // against the count on entry, so a leak in an earlier test cannot also
    // make this one fail.
    const int baseline = Counter::count();
    MyGraph g;
    {
        auto a = MyGraph::MakeNode();
        g.SetRoot(a);
        auto b = MyGraph::MakeNode();
        a->AddChild(b);
        auto c = MyGraph::MakeNode();
        b->AddChild(c);
        auto d = MyGraph::MakeNode();
        b->AddChild(d);
        d->AddChild(b);
        d->RemoveChild(b);
    }
    g.ShrinkToFit();
    return Counter::count() - baseline == 4;
}
/// b and c refer to each other only weakly (both are owned by a). After
/// dropping b -> c and detaching b from a, b is reclaimed while a and c
/// survive.
bool TestCase5() {
    // Counter::count() is one static tally shared by every test. Measure
    // against the count on entry, so a leak in an earlier test cannot also
    // make this one fail.
    const int baseline = Counter::count();
    MyGraph g;
    {
        auto a = MyGraph::MakeNode();
        g.SetRoot(a);
        auto b = MyGraph::MakeNode();
        a->AddChild(b);
        auto c = MyGraph::MakeNode();
        a->AddChild(c);
        b->AddChild(c);
        c->AddChild(b);
        b->RemoveChild(c);
        a->RemoveChild(b);
    }
    g.ShrinkToFit();
    return Counter::count() - baseline == 2;
}
/// Runs every test case in order, printing one true/false line per test.
/// The exit status is non-zero if any test failed, so scripts and CI can
/// detect failures.
int main() {
    cout.setf(ios::boolalpha);
    bool (*const tests[])() = {TestCase1, TestCase2, TestCase3, TestCase4, TestCase5};
    bool allPassed = true;
    for (auto test : tests) {
        // Every test runs (no short-circuit) so each result is reported.
        const bool passed = test();
        // endl flushes, so results printed so far survive a crash in a later test.
        cout << passed << endl;
        allPassed = allPassed && passed;
    }
    return allPassed ? 0 : 1;
}