#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <string.h>

//#include "mmat.h"
#include <stdbool.h>

// ----- mmat.h (inlined)
    // Dense matrix of floats, stored as a table of separately allocated rows.
    typedef struct
    {
        size_t y_dim;       // number of rows
        size_t x_dim;       // number of columns, i.e. floats per row
        float **cells;      // row table; a cell is addressed as cells[y][x]
    } Matrix;

    // Allocate a y_dim-by-x_dim matrix whose cells are all zero.
    // The caller owns the result and releases it with matrix_free().
    Matrix *matrix_construct(size_t y_dim, size_t x_dim);	// constructs zeroed
    // Fill every cell of dst from row_major_floats, which is read row by row
    // and must therefore hold at least y_dim * x_dim values.
    void matrix_set_values(Matrix *dst, float *row_major_floats);
    // Release the rows, the row table and the Matrix object itself.
    void matrix_free(Matrix *m);     // frees m as well as cells

    // Make dst a deep copy of src. dst's previous cells are released and its
    // dimensions become those of src.
    void matrix_assign(Matrix *dst, Matrix const *src);   // may re-size dst
    // True when both matrices have the same dimensions and matching cells.
    bool matrix_is_equal(Matrix const *m1, Matrix const *m2);   

    // Shape of a binary matrix operation: takes two read-only operands and
    // returns a newly allocated result, or NULL when it cannot produce one.
    typedef Matrix *MatrixBinaryOperator(Matrix const *m1, Matrix const *m2);
    // Matrix product m1 * m2; NULL when m1->x_dim differs from m2->y_dim.
    MatrixBinaryOperator matrix_multiply;
    // Compute op(m1, m2) and store the result in m1.
    // Returns true on FAILURE (op returned NULL), false on success.
    bool matrix_operate_self(MatrixBinaryOperator *op, Matrix *m1, Matrix const *m2);

//----- end of mmat.h


// Allocate a y_dim-by-x_dim matrix with every cell set to zero.
//
// Returns the new matrix, which the caller releases with matrix_free(), or
// NULL if any allocation fails. On failure nothing is leaked: every piece
// allocated so far is released before returning.
Matrix *matrix_construct(size_t y_dim, size_t x_dim)
{
	Matrix *m = malloc( sizeof *m );
	if ( !m )
		return NULL;

	// Ask for at least one element even when a dimension is zero: calloc(0, ...)
	// is allowed to return NULL, which would be indistinguishable from failure.
	size_t const table_len = y_dim ? y_dim : 1;
	size_t const row_len = x_dim ? x_dim : 1;

	// calloc (rather than malloc of a product) also rejects a count * size
	// that would overflow size_t.
	m->cells = calloc( table_len, sizeof *m->cells );
	if ( !m->cells )
	{
		free(m);
		return NULL;
	}

// assume IEEE754 floats (all bits zero = 0)
	for ( size_t ii = 0; ii < y_dim; ++ii )
	{
		m->cells[ii] = calloc( row_len, sizeof *m->cells[ii] );
		if ( !m->cells[ii] )
		{
			// Undo the rows that were successfully allocated before this one.
			while ( ii > 0 )
				free(m->cells[--ii]);
			free(m->cells);
			free(m);
			return NULL;
		}
	}

	m->x_dim = x_dim;
	m->y_dim = y_dim;

	return m;
}

// Overwrite every cell of dst with consecutive values read from
// row_major_floats: first all of row 0, then row 1, and so on.
// The source must supply at least dst->y_dim * dst->x_dim floats.
void matrix_set_values(Matrix *dst, float *row_major_floats)
{
	float *cursor = row_major_floats;

	for (size_t row = 0; row < dst->y_dim; ++row)
	{
		for (size_t col = 0; col < dst->x_dim; ++col)
		{
			dst->cells[row][col] = *cursor;
			++cursor;
		}
	}
}

// Report whether two matrices have the same shape and the same contents.
//
// Two cells match when they compare equal as numbers (so +0.0 matches -0.0,
// which a raw byte comparison would reject) or when they are bit-for-bit
// identical (so a matrix containing NaN still equals an exact copy of itself,
// even though NaN != NaN numerically).
//
// Returns true if the dimensions agree and every pair of cells matches.
bool matrix_is_equal(Matrix const *m1, Matrix const *m2)
{
	if ( m1->y_dim != m2->y_dim ) return false;
	if ( m1->x_dim != m2->x_dim ) return false;

	for (size_t y = 0; y < m1->y_dim; ++y)
	{
		float const *row1 = m1->cells[y];
		float const *row2 = m2->cells[y];

		for (size_t x = 0; x < m1->x_dim; ++x)
		{
			if ( row1[x] == row2[x] )
				continue;

			// Numerically different (or NaN): only identical bit patterns still match.
			if ( memcmp(&row1[x], &row2[x], sizeof row1[x]) != 0 )
				return false;
		}
	}

	return true;
}

// Apply a binary operator in place: m1 = op(m1, m2).
//
// The operator builds a fresh result, which is copied into m1 and then
// released. Note the inverted sense of the return value: true means the
// operator FAILED (returned NULL) and m1 was not modified; false means m1
// now holds the result.
bool matrix_operate_self(MatrixBinaryOperator *op, Matrix *m1, Matrix const *m2)
{
	Matrix *outcome = op(m1, m2);

	if ( outcome )
	{
		matrix_assign(m1, outcome);
		matrix_free(outcome);
		return false;
	}

	return true;
}

// Release the storage owned by m (each row, then the row table) without
// freeing the Matrix object itself, and leave m as an empty 0-by-0 matrix
// with no cell table.
static void matrix_deallocate(Matrix *m)
{
	size_t row = 0;

	while ( row < m->y_dim )
	{
		free(m->cells[row]);
		++row;
	}
	free(m->cells);

	m->cells = 0;
	m->y_dim = 0;
	m->x_dim = 0;
}

// Destroy a matrix: release its rows, its row table and the Matrix object.
//
// Passing NULL is a harmless no-op, mirroring free(). This matters because
// operators such as matrix_multiply report failure by returning NULL, and
// callers commonly hand that result straight to matrix_free during cleanup.
void matrix_free(Matrix *m)
{
	if ( !m )
		return;

	matrix_deallocate(m);
	free(m);
}

// Make dst a deep copy of src, adopting src's dimensions.
//
// The copy is built completely before dst's old storage is released, so
// assigning a matrix to itself (dst == src) is safe and dst is never left
// half-destroyed. If the copy cannot be allocated the process exits with
// EXIT_FAILURE, since this function has no way to report an error.
void matrix_assign(Matrix *dst, Matrix const *src)
{
	// note: could optimize this to not allocate if dims the same
	Matrix *copy = matrix_construct(src->y_dim, src->x_dim);
	if ( !copy ) exit(EXIT_FAILURE);

	// Nothing to copy for zero-width rows; skipping also avoids handing
	// memcpy a possibly-NULL row pointer with a zero length.
	if ( src->x_dim > 0 )
		for (size_t y = 0; y < src->y_dim; ++y)
			memcpy(copy->cells[y], src->cells[y], src->x_dim * sizeof copy->cells[0][0]);

	// Only now drop dst's old cells, then take over the copy's storage.
	matrix_deallocate(dst);
	*dst = *copy;
	free(copy);		// the shell only; its cells now belong to dst
}

// Compute the matrix product m1 * m2.
//
// Requires m1->x_dim == m2->y_dim. The result has m1->y_dim rows and
// m2->x_dim columns and is newly allocated; the caller releases it with
// matrix_free(). Returns NULL if the dimensions are incompatible or the
// result cannot be allocated.
Matrix *matrix_multiply(Matrix const *m1, Matrix const *m2)
{
	if ( m1->x_dim != m2->y_dim )
		return NULL;

	Matrix *product = matrix_construct(m1->y_dim, m2->x_dim);
	if ( !product )
		return NULL;

	// Walk the result row by row so writes go through each row in memory
	// order. Every cell is computed independently, so the visiting order
	// does not affect the values produced.
	for (size_t row = 0; row < m1->y_dim; ++row)		// Each row of m1
		for (size_t col = 0; col < m2->x_dim; ++col)	// Each column of m2
		{
		// do dot-product of those two
			float sum = 0;

			for (size_t ii = 0; ii < m2->y_dim; ++ii)
				sum += m2->cells[ii][col] * m1->cells[row][ii];

			product->cells[row][col] = sum;
		}

	return product;
}

// Print a matrix to stdout, one row per line, each row wrapped in "( ... )".
// The prompt is right-aligned in a 20-character field followed by ':' on the
// first row; later rows are indented by blanks so the rows line up.
static void matrix_printf(char const *prompt, Matrix const *mat)
{
	printf("%20.20s:", prompt);

	for (size_t row = 0; row != mat->y_dim; ++row)
	{
		// Rows after the first get blank padding in place of the prompt.
		if ( row != 0 )
			printf("%20.20s ", "");

		printf("( ");
		for (size_t col = 0; col != mat->x_dim; ++col)
		{
			printf("%3.3f ", mat->cells[row][col]);
		}
		printf(")\n");
	}
}
// Smoke test: multiply a 3x2 matrix by a 2x3 matrix, print the operands and
// the product, then check that the in-place variant (matrix_operate_self)
// produces the same product.
//
// Returns EXIT_SUCCESS when everything worked, EXIT_FAILURE if a matrix
// could not be created or an operation reported failure.
int main(void)
{
    float mat1_array[3][2] = {{0, 1}, {3, 4}, {6, 7}};
    float mat2_array[2][3] = {{5, 1, 2}, {3, 4, 5}};
    int status = EXIT_FAILURE;

    Matrix *mat1 = matrix_construct(3, 2);
    Matrix *mat2 = matrix_construct(2, 3);
    Matrix *mat3 = NULL;

    if ( !mat1 || !mat2 )
    {
        fprintf(stderr, "matrix allocation failed\n");
        goto cleanup;
    }

    matrix_set_values(mat1, (float *)&mat1_array);
    matrix_set_values(mat2, (float *)&mat2_array);

    // matrix_multiply signals failure with NULL; do not print or compare it then.
    mat3 = matrix_multiply(mat1, mat2);
    if ( !mat3 )
    {
        fprintf(stderr, "matrix_multiply failed\n");
        goto cleanup;
    }

    matrix_printf("mat1", mat1);
    matrix_printf("mat2", mat2);
    matrix_printf("mat3", mat3);

    // matrix_operate_self returns true on FAILURE.
    if ( matrix_operate_self(matrix_multiply, mat1, mat2) )
    {
        fprintf(stderr, "in-place matrix_multiply failed\n");
        goto cleanup;
    }
    assert( matrix_is_equal(mat1, mat3) );

    status = EXIT_SUCCESS;

cleanup:
    // Only release the matrices that were actually created.
    if ( mat1 ) matrix_free(mat1);
    if ( mat2 ) matrix_free(mat2);
    if ( mat3 ) matrix_free(mat3);
    return status;
}
