#include <iostream>
#include <vector>
#include <numeric>
using namespace std;

// Disjoint-set state shared by find() and merge():
//   parent[i] - parent link of element i; an element with parent[i] == i is a root.
//   sz[r]     - element count of the set rooted at r (only updated for roots).
vector<int> parent, sz;

// Returns the root of the set containing i (the element reached by following
// parent links until one points to itself). As a side effect, every element
// visited along the way is re-linked to point directly at that root.
int find(int i) {
    // Pass 1: climb the parent links to locate the root.
    int root = i;
    while (parent[root] != root) {
        root = parent[root];
    }

    // Pass 2: walk the same path again, attaching each element to the root.
    int node = i;
    while (node != root) {
        const int next = parent[node];
        parent[node] = root;
        node = next;
    }
    return root;
}

// Unites the sets containing i and j. The root with the smaller recorded size
// is attached under the other root; when the sizes are equal, j's root is
// attached under i's root. Does nothing if both are already in the same set.
void merge(int i, int j) {
    const int rootI = find(i);
    const int rootJ = find(j);
    if (rootI == rootJ) {
        return;
    }

    // Pick which root becomes the child: strictly smaller set loses, ties keep rootI on top.
    const bool iIsSmaller = sz[rootI] < sz[rootJ];
    const int child = iIsSmaller ? rootI : rootJ;
    const int keeper = iIsSmaller ? rootJ : rootI;

    parent[child] = keeper;
    sz[keeper] += sz[child];
}

int main() {
	vector<vector<int>> allowedSwaps={{0,4},{4,2},{1,3},{1,4}};

    int n=5;	//hard-code for now
    sz.resize(n,1);
    parent.resize(n);
    iota(begin(parent),end(parent),0);

	cout<<"Parents before: \n";
    for(auto e: parent)
        cout<<e<<" ";
    cout<<"\n";

    for(vector<int>& currswap: allowedSwaps) {
        merge(currswap[0],currswap[1]);
    }
    
    cout<<"Parents after: \n";
    for(auto e: parent)
        cout<<e<<" ";
    cout<<"\n";

	return 0;
}