#include <algorithm>
#include <chrono>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;
using namespace chrono;
 
// Largest matrix dimension that strassenMultiply hands straight to the naive
// method instead of splitting further (value chosen experimentally).
constexpr int THRESHOLD = 1;
 
// Prints a preview of `matrix` to stdout: at most the first 4 rows and, within
// each row, at most the first 4 values (each value is followed by one space).
// A row holding more values than were shown ends with "...", and a final "..."
// line is printed when rows were left out. An empty matrix prints nothing.
void printMatrix(const vector<vector<int>>& matrix) {
    const int maxShown = 4; // preview limit applied to both rows and columns

    const int rowCount = static_cast<int>(matrix.size());
    const int shownRows = min(rowCount, maxShown);

    for (int i = 0; i < shownRows; i++) {
        const vector<int>& row = matrix[i];

        // Measure every row on its own: this never touches matrix[0] of an
        // empty matrix and never reads past the end of a shorter row.
        const int colCount = static_cast<int>(row.size());
        const int shownCols = min(colCount, maxShown);

        for (int j = 0; j < shownCols; j++) {
            cout << row[j] << " ";
        }

        // If there are more columns, indicate there are more
        if (colCount > shownCols) {
            cout << "...";
        }

        cout << '\n';
    }

    // If there are more rows, indicate that there are more
    if (rowCount > shownRows) {
        cout << "..." << '\n';
    }
}
 
// Schoolbook matrix multiplication.
//
// Multiplies A (r x k) by B (k x c) and returns the r x c product. Square
// operands of equal size behave exactly as before; rectangular operands are
// supported as long as the inner dimensions agree.
//
// Throws invalid_argument when a row of A does not have exactly B.size()
// entries, or when the rows of B do not all have the same length, instead of
// indexing out of bounds.
vector<vector<int>> naiveMultiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const size_t rows = A.size();
    const size_t inner = B.size();
    const size_t cols = B.empty() ? 0 : B[0].size();

    for (const vector<int>& rowA : A) {
        if (rowA.size() != inner) {
            throw invalid_argument("naiveMultiply: every row of A must have as many entries as B has rows");
        }
    }
    for (const vector<int>& rowB : B) {
        if (rowB.size() != cols) {
            throw invalid_argument("naiveMultiply: all rows of B must have the same length");
        }
    }

    vector<vector<int>> C(rows, vector<int>(cols, 0));

    // i-k-j order: the innermost loop walks one row of B and one row of C
    // contiguously instead of striding down a column of B.
    for (size_t i = 0; i < rows; i++) {
        vector<int>& rowC = C[i];
        for (size_t k = 0; k < inner; k++) {
            const int a = A[i][k];
            const vector<int>& rowB = B[k];
            for (size_t j = 0; j < cols; j++) {
                rowC[j] += a * rowB[j];
            }
        }
    }

    return C;
}
 
// Element-wise sum of two matrices. The result is dim x dim, where dim is the
// number of rows of A; both operands are indexed as dim x dim.
vector<vector<int>> add(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const int dim = static_cast<int>(A.size());
    vector<vector<int>> sum(dim, vector<int>(dim));

    for (int row = 0; row < dim; ++row) {
        const vector<int>& lhs = A[row];
        const vector<int>& rhs = B[row];
        vector<int>& out = sum[row];
        for (int col = 0; col < dim; ++col) {
            out[col] = lhs[col] + rhs[col];
        }
    }

    return sum;
}
 
// Element-wise difference A - B. The result is dim x dim, where dim is the
// number of rows of A; both operands are indexed as dim x dim.
vector<vector<int>> subtract(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const int dim = static_cast<int>(A.size());
    vector<vector<int>> diff(dim, vector<int>(dim));

    for (int row = 0; row < dim; ++row) {
        const vector<int>& minuend = A[row];
        const vector<int>& subtrahend = B[row];
        vector<int>& out = diff[row];
        for (int col = 0; col < dim; ++col) {
            out[col] = minuend[col] - subtrahend[col];
        }
    }

    return diff;
}
 
// Copies the four quadrants of `original` into A11 (top-left), A12 (top-right),
// A21 (bottom-left) and A22 (bottom-right). Each quadrant is half x half with
// half = original.size() / 2; the output matrices must already have that size,
// because they are written by index and never resized here.
void splitMatrix(const vector<vector<int>>& original, vector<vector<int>>& A11, vector<vector<int>>& A12,
                 vector<vector<int>>& A21, vector<vector<int>>& A22) {
    const int half = static_cast<int>(original.size() / 2);

    for (int r = 0; r < half; ++r) {
        const vector<int>& upper = original[r];
        const vector<int>& lower = original[r + half];
        for (int c = 0; c < half; ++c) {
            A11[r][c] = upper[c];
            A12[r][c] = upper[c + half];
            A21[r][c] = lower[c];
            A22[r][c] = lower[c + half];
        }
    }
}
 
// Writes four equally sized quadrants back into `result`: A11 top-left, A12
// top-right, A21 bottom-left, A22 bottom-right. With half = A11.size(), the
// function writes rows and columns 0 .. 2*half - 1 of `result`, which must
// already be allocated at least that large.
void joinMatrices(const vector<vector<int>>& A11, const vector<vector<int>>& A12,
                  const vector<vector<int>>& A21, const vector<vector<int>>& A22,
                  vector<vector<int>>& result) {
    const int half = static_cast<int>(A11.size());

    for (int r = 0; r < half; ++r) {
        vector<int>& upper = result[r];
        vector<int>& lower = result[r + half];
        for (int c = 0; c < half; ++c) {
            upper[c] = A11[r][c];
            upper[c + half] = A12[r][c];
            lower[c] = A21[r][c];
            lower[c + half] = A22[r][c];
        }
    }
}
 
// Hybrid Strassen matrix multiplication.
//
// Multiplies two square matrices of the same size n (n is taken from
// A.size(); both operands are indexed as n x n). While n is even and larger
// than THRESHOLD, the operands are split into quadrants and combined with
// Strassen's seven recursive products; otherwise the naive method is used.
//
// Odd sizes go to the naive method because n / 2 truncates: the quadrants
// would cover only n - 1 rows and columns, leaving the last row and column of
// the result uncomputed.
vector<vector<int>> strassenMultiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    const int n = static_cast<int>(A.size());

    // Base case: small matrices, and sizes that cannot be halved exactly
    if (n <= THRESHOLD || n % 2 != 0) {
        return naiveMultiply(A, B);
    }

    // Splitting matrices into submatrices
    const int newSize = n / 2;
    vector<vector<int>> A11(newSize, vector<int>(newSize));
    vector<vector<int>> A12(newSize, vector<int>(newSize));
    vector<vector<int>> A21(newSize, vector<int>(newSize));
    vector<vector<int>> A22(newSize, vector<int>(newSize));
    vector<vector<int>> B11(newSize, vector<int>(newSize));
    vector<vector<int>> B12(newSize, vector<int>(newSize));
    vector<vector<int>> B21(newSize, vector<int>(newSize));
    vector<vector<int>> B22(newSize, vector<int>(newSize));

    splitMatrix(A, A11, A12, A21, A22);
    splitMatrix(B, B11, B12, B21, B22);

    // Computing the 7 products using Strassen's method
    const vector<vector<int>> M1 = strassenMultiply(add(A11, A22), add(B11, B22));
    const vector<vector<int>> M2 = strassenMultiply(add(A21, A22), B11);
    const vector<vector<int>> M3 = strassenMultiply(A11, subtract(B12, B22));
    const vector<vector<int>> M4 = strassenMultiply(A22, subtract(B21, B11));
    const vector<vector<int>> M5 = strassenMultiply(add(A11, A12), B22);
    const vector<vector<int>> M6 = strassenMultiply(subtract(A21, A11), add(B11, B12));
    const vector<vector<int>> M7 = strassenMultiply(subtract(A12, A22), add(B21, B22));

    // Calculating C submatrices:
    //   C11 = M1 + M4 - M5 + M7      C12 = M3 + M5
    //   C21 = M2 + M4                C22 = M1 + M3 - M2 + M6
    const vector<vector<int>> C11 = add(subtract(add(M1, M4), M5), M7);
    const vector<vector<int>> C12 = add(M3, M5);
    const vector<vector<int>> C21 = add(M2, M4);
    const vector<vector<int>> C22 = add(subtract(add(M1, M3), M2), M6);

    // Joining the 4 submatrices into the result matrix
    vector<vector<int>> C(n, vector<int>(n));
    joinMatrices(C11, C12, C21, C22, C);

    return C;
}
 
// Runs multiplyFunc(A, B) once, prints how long that call took (in seconds)
// under the label `name`, then prints a preview of the product via printMatrix.
void measureExecutionTime(const string& name, vector<vector<int>> (*multiplyFunc)(const vector<vector<int>>&, const vector<vector<int>>&), const vector<vector<int>>& A, const vector<vector<int>>& B) {
    // steady_clock is monotonic, so the measured interval cannot be skewed by
    // adjustments to the system's wall clock.
    const auto start = steady_clock::now();
    vector<vector<int>> C = multiplyFunc(A, B);
    const auto end = steady_clock::now();

    // Convert straight to floating-point seconds: going through whole
    // milliseconds first would report 0 for any run shorter than 1 ms.
    const double elapsedSeconds = duration<double>(end - start).count();
    cout << endl << name << " took " << elapsedSeconds << " seconds.\n";

    cout << "Result Matrix C:" << endl;
    printMatrix(C);
}
 
// Demo driver: builds two n x n matrices of ones whose last row and last
// column are zero, prints them, then times the naive and the hybrid Strassen
// multiplication on the same operands.
int main() {
    const int n = 4; // Example size; a power of 2 keeps every Strassen split even
    vector<vector<int>> A(n, vector<int>(n, 1));
    vector<vector<int>> B(n, vector<int>(n, 1));

    // Zero the last row and the last column of both operands. The index is
    // derived from n so the pattern stays in bounds if the size is changed.
    const int last = n - 1;
    for (int i = 0; i < n; i++) {
        A[last][i] = 0;
        A[i][last] = 0;
        B[last][i] = 0;
        B[i][last] = 0;
    }

    cout << "Matrix A:" << endl;
    printMatrix(A);

    cout << endl << "Matrix B:" << endl;
    printMatrix(B);

    measureExecutionTime("Naive Multiplication", naiveMultiply, A, B);
    measureExecutionTime("Hybrid Multiplication", strassenMultiply, A, B);

    return 0;
}