#include <benchmark/benchmark.h>
#include <vector>
#include <random>

using namespace std;

// Modulus for all matrix arithmetic in this file; entries are kept in [0, MOD).
const int MOD = 998'244'353;

// (a + b) mod MOD for operands already reduced into [0, MOD).
// The raw sum is below 2 * MOD = 1'996'488'706, which still fits in an int.
int add (int a, int b) {
    const int sum = a + b;
    return sum < MOD ? sum : sum - MOD;
}

// (a - b) mod MOD for operands already reduced into [0, MOD).
int sub (int a, int b) {
    const int diff = a - b;
    return diff >= 0 ? diff : diff + MOD;
}

// (a * b) mod MOD; the product is formed in long long so it cannot overflow.
int mul (int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % MOD);
}

// Dense row-major matrix of ints stored contiguously in the inherited
// std::vector<int>; element (i, j) lives at index i * m + j.
struct Matrix : std::vector<int> {
    // n = number of rows, m = number of columns
    int n, m;

    // Creates an n x m matrix with every entry value-initialized to 0.
    Matrix (int n, int m) :
        std::vector<int>(n * m), n(n), m(m) {}

    // Creates a matrix from row-major values `init` split into `row` rows;
    // the column count is init.size() / row. Initializers are written in the
    // order they actually run (base class, then n, then m) and use `row`
    // directly instead of the freshly set member. row == 0 yields m == 0
    // instead of dividing by zero.
    Matrix (std::initializer_list<int> init, int row) :
        std::vector<int>(init.begin(), init.end()),
        n(row),
        m(row == 0 ? 0 : static_cast<int>(init.size() / row)) {}

    // Pointer to the first element of row i, so mat[i][j] indexes like a 2-D array.
    int* operator[] (int i) { return data() + i * m; }
    // Read-only row access; data() is already const here, so no cast is needed.
    const int* operator[] (int i) const { return data() + i * m; }
};

// Baseline kernel: textbook row/col/k triple loop that reduces modulo MOD
// after every multiply-add. The inner k loop walks b down a column, i.e. with
// a stride of n ints between consecutive reads.
static void matMulOriginal (benchmark::State &state) {
    const int n = state.range(0);
    Matrix a(n, n), b(n, n);

    // Fixed seed so every run multiplies the same inputs; draws alternate
    // between a and b for each cell, visiting cells in row-major order.
    mt19937 gen(21);
    for (int cell = 0; cell < n * n; ++cell) {
        a.data()[cell] = gen() % MOD;
        b.data()[cell] = gen() % MOD;
    }

    for (auto _ : state) {
        Matrix c(n, n);
        for (int row = 0; row < n; row++)
            for (int col = 0; col < n; col++)
                for (int k = 0; k < n; k++)
                    c[row][col] = add(c[row][col], mul(a[row][k], b[k][col]));

        // Keep the product observable so the kernel is not optimized away.
        benchmark::DoNotOptimize(c.data());
        benchmark::ClobberMemory();
    }
}
// Square sizes 2, 4, 8, ..., 1024.
BENCHMARK(matMulOriginal)
    ->RangeMultiplier(2)
    ->Range(1 << 1, 1 << 10);

// Same arithmetic as the baseline, but b is first copied into bT (its
// transpose) so the inner k loop reads both operands along contiguous rows.
// The transpose is rebuilt inside the timed loop, so its cost is measured.
static void matMulTranspose (benchmark::State &state) {
    const int n = state.range(0);
    Matrix a(n, n), b(n, n);

    // Fixed seed so every run multiplies the same inputs; draws alternate
    // between a and b for each cell, visiting cells in row-major order.
    mt19937 gen(21);
    for (int cell = 0; cell < n * n; ++cell) {
        a.data()[cell] = gen() % MOD;
        b.data()[cell] = gen() % MOD;
    }

    for (auto _ : state) {
        Matrix bT(n, n), c(n, n);
        for (int row = 0; row < n; row++)
            for (int col = 0; col < n; col++) bT[row][col] = b[col][row];
        for (int row = 0; row < n; row++)
            for (int col = 0; col < n; col++)
                for (int k = 0; k < n; k++)
                    c[row][col] = add(c[row][col], mul(a[row][k], bT[col][k]));

        // Keep the product observable so the kernel is not optimized away.
        benchmark::DoNotOptimize(c.data());
        benchmark::ClobberMemory();
    }
}
// Square sizes 2, 4, 8, ..., 1024.
BENCHMARK(matMulTranspose)
    ->RangeMultiplier(2)
    ->Range(1 << 1, 1 << 10);

// Edge length of the square tiles used by the blocked (tiled) kernel.
constexpr int TILESIZE = 16;
// Scratch buffer for one transposed TILESIZE x TILESIZE tile of b;
// starts zeroed and is overwritten tile by tile.
int bCached[TILESIZE][TILESIZE] = {};

// Blocked (tiled) kernel: a, b and c are processed in TILESIZE x TILESIZE
// tiles. For each (iTile, jTile, kTile) the matching tile of b is copied,
// transposed, into the global scratch buffer bCached so the innermost k loop
// reads it contiguously; a is read in place. Products are accumulated in 64
// bits and reduced modulo MOD once per tile rather than after every
// multiply-add.
static void matMulTiling (benchmark::State &state) {
    // `hold` below starts at a value < MOD and gains at most TILESIZE products,
    // each <= (MOD - 1)^2, before `hold %= MOD`. Fail the build if a larger
    // TILESIZE or MOD would let that sum wrap around 64 bits and silently
    // corrupt the result.
    static_assert((MOD - 1ULL) * (MOD - 1ULL) <= (~0ULL - (MOD - 1ULL)) / TILESIZE,
                  "lazy reduction in matMulTiling would overflow unsigned long long");

    int n = state.range(0);
    Matrix a(n, n), b(n, n);

    // Same seed and fill order as the other benchmarks, so all kernels see identical inputs.
    mt19937 rng(21);
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            a[i][j] = rng() % MOD, b[i][j] = rng() % MOD;
    
    for (auto _ : state) {
        Matrix c(a.n, b.m);
        for (int iTile = 0; iTile < a.n; iTile += TILESIZE) {
            // the last tile along a dimension may be partial
            int iSize = min(TILESIZE, a.n - iTile);
            for (int jTile = 0; jTile < b.m; jTile += TILESIZE) {
                int jSize = min(TILESIZE, b.m - jTile);
                for (int kTile = 0; kTile < a.m; kTile += TILESIZE) {
                    int kSize = min(TILESIZE, a.m - kTile);
                    // transfer data to be cached for b + in-place transpose
                    for (int k = 0; k < kSize; k++)
                        for (int j = 0; j < jSize; j++)
                            bCached[j][k] = b[k + kTile][j + jTile];
                    
                    // perform matrix multiplication for current block
                    for (int i = 0; i < iSize; i++) {
                        // dot product of a's row (read in place) with a cached, transposed row of b
                        for (int j = 0; j < jSize; j++) {
                            unsigned long long hold = c[i + iTile][j + jTile];
                            for (int k = 0; k < kSize; k++)
                                hold += 1ULL * a[i + iTile][k + kTile] * bCached[j][k];
                            // single reduction per tile; static_assert above guarantees no wrap
                            hold %= MOD, c[i + iTile][j + jTile] = hold;
                        }
                    }
                }
            }
        }
        
        // Keep the product observable so the kernel is not optimized away.
        benchmark::DoNotOptimize(c.data());
        benchmark::ClobberMemory();
    }
}
BENCHMARK(matMulTiling)
    ->RangeMultiplier(2)
    ->Range(1 << 1, 1 << 10);

// Google Benchmark macro that defines main() to run the registered benchmarks.
BENCHMARK_MAIN();