#include <algorithm>
#include <iomanip>
#include <iostream>
#include <memory>
#include <vector>
using namespace std;
/// Tracks how many Counter (sub)objects are currently alive.
/// Every constructor (default and copy) increments the tally and the
/// destructor decrements it. The tally is a plain static int and is not
/// synchronized between threads.
class Counter {
public:
    Counter() { ++count_; }
    // A copy is a new live object and must be counted too. The implicit copy
    // constructor would skip the increment while the destructor still
    // decrements, driving the tally negative.
    Counter(const Counter&) { ++count_; }
    // Assignment neither creates nor destroys an object: the count is unchanged.
    Counter& operator=(const Counter&) = default;
    ~Counter() { --count_; }
    /// Returns the number of Counter (sub)objects currently alive.
    static int count() { return count_; }
private:
    static int count_;
};
int Counter::count_ = 0;
/// Modified code starts here
/// A directed graph of shared_ptr-owned nodes that may contain cycles.
///
/// Children are held by strong references, so reference counting alone
/// cannot free a cycle. The graph therefore keeps a registry (`nodes`) of
/// every node that has joined it. ShrinkToFit() and the destructor break
/// cycles by clearing the child lists of nodes that are no longer needed.
/// A node joining through SetRoot() or AddChild() brings its whole
/// not-yet-owned subgraph along, so the children of registered nodes are
/// registered too.
class MyGraph {
public:
    MyGraph() = default;

    // Nodes keep a raw back-pointer to their graph, and the destructor clears
    // every registered node's edges. A copy would share that registry and
    // wipe the original's edges when destroyed. A move would leave nodes
    // pointing at the moved-from object. Both are therefore disabled.
    MyGraph(const MyGraph&) = delete;
    MyGraph& operator=(const MyGraph&) = delete;
    MyGraph(MyGraph&&) = delete;
    MyGraph& operator=(MyGraph&&) = delete;

    /// Breaks all cycles among registered nodes and detaches them, so nodes
    /// still held by outside code are not left pointing at a dead graph.
    ~MyGraph() {
        for (const auto& n : nodes) {
            n->children.clear();
            n->owner = nullptr;
        }
    }

    /// A graph vertex. Create it with MyGraph::MakeNode(): shared_from_this()
    /// requires the node to be owned by a std::shared_ptr.
    class Node : public Counter, public std::enable_shared_from_this<Node> {
        std::vector<std::shared_ptr<Node>> children;  // strong edges; may form cycles
        MyGraph* owner = nullptr;  // graph this node is registered in, or null
        bool visited = false;      // reachability mark written by MyGraph::visit
        friend class MyGraph;

        /// Registers this node, and every not-yet-owned node reachable from
        /// it, with `graph`. Does nothing when `graph` is null (the parent is
        /// not in a graph yet) or when this node is already in `graph`.
        /// Throws a const char* if a node on the way belongs to another graph.
        void setOwner(MyGraph* graph) {
            if (graph == nullptr || owner == graph)
                return;
            // Iterative walk, so long chains cannot overflow the call stack.
            std::vector<std::shared_ptr<Node>> pending{shared_from_this()};
            while (!pending.empty()) {
                std::shared_ptr<Node> current = std::move(pending.back());
                pending.pop_back();
                if (current->owner == graph)
                    continue;  // already registered; this also ends cycles
                if (current->owner != nullptr)
                    throw "Can't share nodes between graphs";
                current->owner = graph;
                graph->registerNode(current);
                pending.insert(pending.end(), current->children.begin(), current->children.end());
            }
        }
    public:
        Node() = default;
        // A copy would claim `owner` without being registered there.
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;

        /// Appends `node` as a child; a null `node` is ignored.
        /// If this node already belongs to a graph, `node` joins that graph
        /// before the edge is added. If `node` belongs to a different graph,
        /// the call throws and no edge is created. A node outside any graph
        /// may collect children freely; they join a graph together with it later.
        void AddChild(const std::shared_ptr<Node>& node) {
            if (!node)
                return;
            node->setOwner(owner);  // may throw; nothing has been linked yet
            children.push_back(node);
        }

        /// Removes every edge from this node to `node`. The node stays
        /// registered in the graph until the next MyGraph::ShrinkToFit().
        void RemoveChild(const std::shared_ptr<Node>& node) {
            children.erase(std::remove(children.begin(), children.end(), node), children.end());
        }
    };

    /// Makes `node` the root that ShrinkToFit() keeps reachable, registering
    /// it (and its not-yet-owned subgraph) with this graph. If `node` belongs
    /// to another graph, the call throws and the root is unchanged. Passing
    /// nullptr clears the root, so a later ShrinkToFit() releases every node.
    void SetRoot(const std::shared_ptr<Node>& node) {
        if (node)
            node->setOwner(this);
        root = node;
    }

    /// Low-level: appends `node` to the registry without touching its owner.
    /// Normally called only from Node::setOwner().
    void registerNode(const std::shared_ptr<Node>& node) {
        nodes.push_back(node);
    }

    /// Drops every registered node that is not reachable from the root.
    /// Dropped nodes have their child lists cleared (breaking cycles) and
    /// are detached from this graph, so outside code that still holds one
    /// can add it back later.
    void ShrinkToFit() {
        for (const auto& n : nodes)
            n->visited = false;
        visit(root);
        // Move reachable nodes to the front. The tail keeps strong references
        // while it is cleared, so no node is destroyed in the middle of the pass.
        const auto firstUnreachable = std::stable_partition(
            nodes.begin(), nodes.end(),
            [](const std::shared_ptr<Node>& n) { return n->visited; });
        for (auto it = firstUnreachable; it != nodes.end(); ++it) {
            (*it)->children.clear();
            (*it)->owner = nullptr;
        }
        nodes.erase(firstUnreachable, nodes.end());
    }

    /// Marks `n` and every node reachable from it as visited. Visited nodes
    /// are not expanded again, which also ends cycles. A null `n` is ignored.
    /// Uses an explicit stack so deep graphs cannot overflow the call stack.
    void visit(const std::shared_ptr<Node>& n) {
        if (!n)
            return;
        // Raw pointers are safe here: no node is released during the walk.
        std::vector<Node*> pending{n.get()};
        while (!pending.empty()) {
            Node* current = pending.back();
            pending.pop_back();
            if (current->visited)
                continue;
            current->visited = true;
            for (const auto& child : current->children)
                if (child)
                    pending.push_back(child.get());
        }
    }

    /// Creates a new node that does not belong to any graph yet.
    static std::shared_ptr<MyGraph::Node> MakeNode() {
        return std::make_shared<MyGraph::Node>();
    }
private:
    std::shared_ptr<Node> root;               // entry point for reachability
    std::vector<std::shared_ptr<Node>> nodes;  // every node registered here
};
/// End modified code
bool TestCase1() {
MyGraph g;
{
auto a = MyGraph::MakeNode();
g.SetRoot(a);
auto b = MyGraph::MakeNode();
a->AddChild(b);
auto c = MyGraph::MakeNode();
b->AddChild(c);
a->RemoveChild(b);
}
g.ShrinkToFit();
return Counter::count() == 1;
}
bool TestCase2() {
MyGraph g;
{
auto a = MyGraph::MakeNode();
g.SetRoot(a);
auto b = MyGraph::MakeNode();
a->AddChild(b);
auto c = MyGraph::MakeNode();
b->AddChild(c);
auto d = MyGraph::MakeNode();
b->AddChild(d);
d->AddChild(b);
a->RemoveChild(b);
}
g.ShrinkToFit();
return Counter::count() == 1;
}
bool TestCase3() {
MyGraph g;
{
auto a = MyGraph::MakeNode();
g.SetRoot(a);
auto b = MyGraph::MakeNode();
a->AddChild(b);
auto c = MyGraph::MakeNode();
b->AddChild(c);
auto d = MyGraph::MakeNode();
b->AddChild(d);
d->AddChild(b);
}
g.ShrinkToFit();
return Counter::count() == 4;
}
bool TestCase4() { // New test case
MyGraph g;
{
auto a = MyGraph::MakeNode();
g.SetRoot(a);
auto b = MyGraph::MakeNode();
a->AddChild(b);
auto c = MyGraph::MakeNode();
b->AddChild(c);
auto d = MyGraph::MakeNode();
b->AddChild(d);
d->AddChild(b);
d->RemoveChild(b);
}
g.ShrinkToFit();
return Counter::count() == 4;
}
/// Runs every test case in order and prints one true/false line per case.
/// Returns 0 only when all cases pass, so scripts and CI can detect failures.
int main() {
    cout.setf(ios::boolalpha);
    bool (*const tests[])() = {TestCase1, TestCase2, TestCase3, TestCase4};
    bool allPassed = true;
    for (auto test : tests) {
        const bool passed = test();  // always run, even after a failure
        // endl flushes, so earlier results stay visible if a later case crashes.
        cout << passed << endl;
        allPassed = allPassed && passed;
    }
    return allPassed ? 0 : 1;
}