//
//  main.cpp
//  Prim's Algorithm
//
//  Created by Himanshu on 28/05/23.
//

#include <cstddef>
#include <iostream>
#include <limits>
#include <queue>
#include <vector>

using namespace std;

// Sentinel key meaning "no edge to this vertex found yet".
constexpr int INF = std::numeric_limits<int>::max();
// Number of vertices in the example graph.
constexpr int N = 5;

// Adjacency-list entry: pair (node, weight).
using graphNode = std::pair<int, int>;
// Global adjacency list: graph[u] holds the edges leaving vertex u (0..N-1).
std::vector<std::vector<graphNode>> graph(N);

void printVector(vector<int> vec) {
    
    for (int i=0; i<N; i++) {
        if (vec[i] == INF) {
            cout<<"INF, ";
        } else {
            cout<<vec[i]<<", ";
        }
    }
    
    cout<<endl;
}

// Prints the edges of the spanning tree recorded in `key` / `parent`.
//
// key    - key[v] is the weight of the edge that attaches v to the tree.
// parent - parent[v] is the vertex v is attached through, or -1 if v has no
//          parent (the source, or a vertex that was never reached).
//
// Each vertex with a parent is printed as "parent - vertex<TAB>weight".
// Every index is scanned, starting at 0, and vertices without a parent are
// skipped. The output is therefore correct whichever vertex was the source,
// not only vertex 0. The loop is bounded by the vectors' own sizes rather
// than the global N.
void printMST(const std::vector<int>& key, const std::vector<int>& parent) {

    std::cout << "Minimum Spanning Tree Edges:" << '\n' << '\n';

    std::cout << "Edge \tWeight" << '\n';

    for (std::size_t i = 0; i < parent.size() && i < key.size(); ++i) {
        if (parent[i] != -1) {
            std::cout << parent[i] << " - " << i << "\t" << key[i] << '\n';
        }
    }
}

void primMST(int source, vector<int> &key, vector<int> &parent, vector<bool> &inMST) {

    int i = 1;

    // Custom comparator for the priority queue
    auto cmp = [](const graphNode& a, const graphNode& b) {
        return a.second > b.second;
    };

    priority_queue<graphNode, vector<graphNode>, decltype(cmp)> pq(cmp);

    key[source] = 0; // Start with the source node
    pq.push({source, key[source]});

    while (!pq.empty()) {
        int u = pq.top().first;
        pq.pop();

        inMST[u] = true;

        for (auto edge : graph[u]) {
            int v = edge.first;
            int weight = edge.second;

            if (!inMST[v] && weight < key[v]) {
                parent[v] = u;
                key[v] = weight;
                pq.push({v, key[v]});
                
                cout<<"Iteration "<<i++<<":"<<endl;
 
                cout<<"weight: "<<weight<<endl;
                cout<<"edge: "<<u<<" - "<<v<<endl;
 
                cout<<"Key values: ";
                printVector(key);
 
                cout<<"Parent values: ";
                printVector(parent);
                cout<<endl;

                
            }
        }
    }

}

// Fills the global adjacency list with the example graph used by main().
//
// Prim's algorithm finds a spanning tree of an *undirected* graph, so each
// edge is stored in the adjacency lists of BOTH endpoints. If an edge were
// stored only under one endpoint, the algorithm could not use it from the
// other side. For example, edge 1 - 2 would be invisible from vertex 2, and
// the tree grown from vertex 0 would not be minimal.
void initializeGraph() {

    // Records the undirected edge u - v with the given weight in both
    // adjacency lists. Entries are pair {node, weight}.
    auto addEdge = [](int u, int v, int weight) {
        graph[u].push_back({v, weight});
        graph[v].push_back({u, weight});
    };

    // Example graph: addEdge(u, v, weight)
    addEdge(0, 1, 4);
    addEdge(0, 2, 3);

    addEdge(1, 2, 1);
    addEdge(1, 3, 2);

    addEdge(2, 3, 4);

    addEdge(3, 4, 2);
}

int main() {
    
    vector<int> parent(N, -1); // stores the parent of each node in the MST
    vector<int> key(N, INF); // stores the minimum weight to reach each node
    vector<bool> inMST(N, false); // indicates whether a node is already in the MST
    int source = 0; // source node
    
    initializeGraph();

    primMST(source, key, parent, inMST);
    
    printMST(key, parent);

    return 0;
}
