#include <cassert>
using namespace std;
class Box;

/// One side of a one-to-one, bidirectional link with a Box.
///
/// Together with Box it keeps the invariant: if getMyBox() == &b then
/// b.getMyRabbit() == this. Neither side owns the other; each stores only a
/// non-owning pointer, and the destructor clears the link on both sides.
///
/// Copying is deleted. A copy would hold the same Box* while the Box still
/// points back at the original: destroying the copy would silently unlink
/// the original, and if the original were destroyed first the copy would be
/// left with a pointer the Box no longer tracks (dangling once the Box dies).
/// (With only a user-declared destructor, implicit moves would also have
/// fallen back to those broken copies.) Moving transfers the link instead.
class Rabbit {
public:
    Rabbit() = default;
    Rabbit(const Rabbit&) = delete;
    Rabbit& operator=(const Rabbit&) = delete;

    /// Takes over `other`'s link (if any); `other` is left unconnected.
    Rabbit(Rabbit&& other) noexcept
    {
        if (Box* const otherBox = other.box)
        {
            other.disconnect();
            connect(*otherBox);
        }
    }

    /// Drops this rabbit's current link, then takes over `other`'s link
    /// (if any); `other` is left unconnected.
    Rabbit& operator=(Rabbit&& other) noexcept
    {
        if (this != &other)
        {
            Box* const otherBox = other.box;
            other.disconnect();
            disconnect();
            if (otherBox != nullptr)
                connect(*otherBox);
        }
        return *this;
    }

    /// Unlinks from the current box (if any) so the box is not left
    /// pointing at a destroyed rabbit.
    ~Rabbit() { disconnect(); }

    /// Links this rabbit with `otherBox`, breaking any existing link on
    /// either side first. No-op if already linked to `otherBox`.
    void connect(Box& otherBox);
    /// Breaks the current link (if any) on both sides. No-op if unlinked.
    void disconnect();
    /// Returns the linked box, or nullptr if unlinked.
    Box* getMyBox() const { return box; }
private:
    Box* box = nullptr;  // non-owning back-link; nullptr when unconnected
};
class Box {
public:
~Box() { disconnect(); }
void connect(Rabbit& otherRabbit);
void disconnect();
Rabbit* getMyRabbit() const { return rabbit; }
private:
Rabbit* rabbit = nullptr;
};
/// Links this rabbit with `otherBox`. The rabbit's previous link is broken
/// here; the box's previous link is broken inside Box::connect.
/// Because `box` is assigned before Box::connect is called, that function's
/// call back into this one hits the early return, ending the mutual recursion.
inline void Rabbit::connect(Box& otherBox)
{
    Box* const target = &otherBox;
    if (box == target)
        return;  // already linked to this box

    disconnect();
    box = target;
    target->connect(*this);  // let the box record the back-link
}
inline void Rabbit::disconnect()
{
if (box != nullptr)
{
Box* oldBox = box;
box = nullptr; // do this before calling disconnect on box - otherwise you get infinite recursion...
oldBox->disconnect();
}
}
/// Links this box with `otherRabbit`. The box's previous link is broken
/// here; the rabbit's previous link is broken inside Rabbit::connect.
/// Because `rabbit` is assigned before Rabbit::connect is called, that
/// function's call back into this one hits the early return, ending the
/// mutual recursion.
inline void Box::connect(Rabbit& otherRabbit)
{
    Rabbit* const target = &otherRabbit;
    if (rabbit == target)
        return;  // already linked to this rabbit

    disconnect();
    rabbit = target;
    target->connect(*this);  // let the rabbit record the back-link
}
inline void Box::disconnect()
{
if (rabbit != nullptr)
{
Rabbit* oldRabbit = rabbit;
rabbit = nullptr; // do this before calling disconnect on box - otherwise you get infinite recursion...
oldRabbit->disconnect();
}
}
int main() {
Box b1, b2;
Rabbit r1;
b1.connect(r1);
assert(b1.getMyRabbit() == &r1);
assert(r1.getMyBox() == &b1);
r1.connect(b2);
assert(b1.getMyRabbit() == nullptr);
assert(r1.getMyBox() == &b2);
assert(b2.getMyRabbit() == &r1);
b2.disconnect();
assert(b1.getMyRabbit() == nullptr);
assert(r1.getMyBox() == nullptr);
assert(b2.getMyRabbit() == nullptr);
{
Rabbit r2;
r2.connect(b1);
assert(b1.getMyRabbit() == &r2);
assert(r2.getMyBox() == &b1);
}
assert(b1.getMyRabbit() == nullptr);
}