#include <cassert>
using namespace std;

class Box;

// The Rabbit side of a one-to-one link between a Rabbit and a Box.
// It holds a non-owning pointer to its Box, and the Box holds a matching
// pointer back; connect()/disconnect() keep both ends in sync.
class Rabbit {
public:
	Rabbit() = default;
	// Breaks the link (if any) so the Box is not left pointing at a destroyed Rabbit.
	~Rabbit() { disconnect(); }

	// Copying would duplicate the back-pointer without the Box knowing,
	// leaving the link inconsistent (destroying the copy would detach the
	// original's Box, and the copy could outlive the Box and dangle).
	// Moves are not declared either, so Rabbit is neither copyable nor movable.
	Rabbit(const Rabbit&) = delete;
	Rabbit& operator=(const Rabbit&) = delete;

	// Links this Rabbit with otherBox, first releasing any Box it held and
	// detaching otherBox from its previous Rabbit. No-op if already linked.
	void connect(Box& otherBox);
	// Breaks the current link on both ends; no-op when not linked.
	void disconnect();
	// Returns the linked Box, or nullptr when not linked.
	Box* getMyBox() const { return box; }
private:
	Box* box = nullptr; // non-owning; nullptr when not linked
};
// The Box side of a one-to-one link between a Rabbit and a Box.
// It holds a non-owning pointer to its Rabbit, and the Rabbit holds a
// matching pointer back; connect()/disconnect() keep both ends in sync.
class Box {
public:
	Box() = default;
	// Breaks the link (if any) so the Rabbit is not left pointing at a destroyed Box.
	~Box() { disconnect(); }

	// Copying would duplicate the back-pointer without the Rabbit knowing,
	// leaving the link inconsistent (destroying the copy would detach the
	// original's Rabbit, and the copy could outlive the Rabbit and dangle).
	// Moves are not declared either, so Box is neither copyable nor movable.
	Box(const Box&) = delete;
	Box& operator=(const Box&) = delete;

	// Links this Box with otherRabbit, first releasing any Rabbit it held and
	// detaching otherRabbit from its previous Box. No-op if already linked.
	void connect(Rabbit& otherRabbit);
	// Breaks the current link on both ends; no-op when not linked.
	void disconnect();
	// Returns the linked Rabbit, or nullptr when not linked.
	Rabbit* getMyRabbit() const { return rabbit; }
private:
	Rabbit* rabbit = nullptr; // non-owning; nullptr when not linked
};

// Links this rabbit with otherBox. Any box currently held is released first,
// and otherBox is told about the new link (which makes it drop whatever
// rabbit it held before).
inline void Rabbit::connect(Box& otherBox)
{
	Box* const target = &otherBox;
	if (box == target)
		return; // already linked to this box; nothing to do

	disconnect(); // release the current box, if any

	// Record our side before notifying the box: Box::connect calls back into
	// this function, and finding box == target there ends the recursion.
	box = target;
	target->connect(*this);
}

inline void Rabbit::disconnect()
{
	if (box != nullptr)
	{
		Box* oldBox = box;
		box = nullptr; // do this before calling disconnect on box - otherwise you get infinite recursion...
		oldBox->disconnect();
	}
}

// Links this box with otherRabbit. Any rabbit currently held is released
// first, and otherRabbit is told about the new link (which makes it leave
// whatever box it was in before).
inline void Box::connect(Rabbit& otherRabbit)
{
	Rabbit* const incoming = &otherRabbit;
	if (rabbit == incoming)
		return; // already holding this rabbit; nothing to do

	disconnect(); // release the current rabbit, if any

	// Record our side before notifying the rabbit: Rabbit::connect calls back
	// into this function, and finding rabbit == incoming there ends the recursion.
	rabbit = incoming;
	incoming->connect(*this);
}

inline void Box::disconnect()
{
	if (rabbit != nullptr)
	{
		Rabbit* oldRabbit = rabbit;
		rabbit = nullptr; // do this before calling disconnect on box - otherwise you get infinite recursion...
		oldRabbit->disconnect();
	}
}

int main() {
	Box b1, b2;
	Rabbit r1;

	b1.connect(r1);	
	assert(b1.getMyRabbit() == &r1);
	assert(r1.getMyBox() == &b1);
	
	r1.connect(b2);
	assert(b1.getMyRabbit() == nullptr);
	assert(r1.getMyBox() == &b2);
	assert(b2.getMyRabbit() == &r1);
	
	b2.disconnect();
	assert(b1.getMyRabbit() == nullptr);
	assert(r1.getMyBox() == nullptr);
	assert(b2.getMyRabbit() == nullptr);
	
	{ 
		Rabbit r2;
		r2.connect(b1);
		assert(b1.getMyRabbit() == &r2);
		assert(r2.getMyBox() == &b1);
	}
	assert(b1.getMyRabbit() == nullptr);
	
		
}