#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
#include <string>

// One index per printable ASCII character, ' ' (0x20) through '~' (0x7E).
constexpr int symbol_count = '~' - ' ' + 1;

// Square tables indexed [ context group ][ symbol ].
using table_f = float[ symbol_count ][ symbol_count ];
using table_i = int[ symbol_count ][ symbol_count ];

// prediction[ g ][ s ]: how many times context group g was followed by symbol s.
table_i prediction;
// total[ g ]: number of pairs whose first character belongs to group g.
int total[ symbol_count ];

// combo_cost[ lo ][ hi ]: nats lost by merging groups lo and hi into one.
table_f combo_cost;

// group[ g ]: the symbols currently merged into group g.
std::string group[ symbol_count ];

void print_column_headings( std::ostream & stats ) {
	stats << '\t';
	for ( int c = 0; c != symbol_count; ++ c ) {
		stats << char( ' ' + c ) << '\t';
	};
	stats << '\n';
}

// Natural logarithm of a count, with logz( 0 ) == 0 so that x * logz( x )
// evaluates to 0 for a zero count instead of involving log( 0 ).
// std::log is named explicitly: <cmath> only guarantees the std:: overloads,
// not the global ::log.
float logz( int x ) { return x ? std::log( x ) : 0; }

// Nats-to-bytes conversion factor: one byte is 8 bits, and one bit is ln 2 nats.
// ln 2 is written as a literal because log() is not usable in a constant
// expression in standard C++17; compilers other than GCC reject log( 2 ) here.
constexpr float bytes_per_nat = 1. / ( 0.693147180559945309417232121458176568 * 8 );

// Computes (a+b) ln(a+b) - a ln a - b ln b, using logz so that zero counts
// contribute nothing.
float symbol_combo_cost( int a, int b ) {
	const int merged = a + b;
	return merged * logz( merged ) - a * logz( a ) - b * logz( b );
}

void update_table_combo_cost( int a, int b ) {
	combo_cost[ a ][ b ] = symbol_combo_cost( total[ a ], total[ b ] ) - std::inner_product(
		std::begin( prediction[ a ] ), std::end( prediction[ a ] ),
		std::begin( prediction[ b ] ), 0.,
		std::plus< float >(),
		symbol_combo_cost
	);
}

// Reads the corpus line by line. For every pair of adjacent characters within
// a line, increments prediction[ first ][ second ]. Pairs never span a line
// break. A pair is skipped if either character lies outside printable ASCII
// (' '..'~'), and each line containing such a pair is reported once on stderr.
// Reports "Input failure" if reading stopped for any reason other than EOF.
void collect_samples( std::istream & corpus ) {
	const auto printable = []( char c ) { return c >= ' ' && c <= '~'; };

	std::string item;
	while ( std::getline( corpus, item ) ) {
		bool reported = false;
		for ( std::size_t cx = 1; cx < item.size(); ++ cx ) {
			const char prev = item[ cx - 1 ];
			const char next = item[ cx ];
			// Both characters index the table, so both must be checked: a
			// trailing '\r', a tab, or a byte >= 0x80 (negative when char is
			// signed) would otherwise write outside prediction.
			if ( ! printable( prev ) || ! printable( next ) ) {
				if ( ! reported ) {
					std::cerr << "bad line: " << item << '\n';
					reported = true;
				}
				continue;
			}
			++ prediction[ prev - ' ' ][ next - ' ' ];
		}
	}
	if ( ! corpus.eof() ) std::cerr << "Input failure\n";
}

// Size, in nats, of coding every counted successor with a separate frequency
// table per context group:
//   sum over groups t of ( total[ t ] ln total[ t ] - sum over s of n ln n ),
// where n = prediction[ t ][ s ].
float coded_size() {
	float nats = 0;
	for ( int row = 0; row < symbol_count; ++ row ) {
		for ( int count : prediction[ row ] ) {
			nats -= count * logz( count );
		}
		nats += total[ row ] * logz( total[ row ] );
	}
	return nats;
}

// Entry point.
// Input: the file named by the single argument, or standard input when no
// argument is given.
// It counts adjacent-character pairs and prints the count table and the
// initial merge-cost table. It then repeatedly merges the pair of context
// groups with the smallest combo_cost, logging each merge, until no finite
// cost remains.
// Output goes to stats.txt when a file is named, otherwise to standard output.
int main( int argc, char ** argv ) {
	if ( argc > 2 ) {
		std::cerr << "Usage: " << argv[0] << " <filename>\n";
		return EXIT_FAILURE;
	}
	
	const bool use_pipes = argc != 2;
	
	std::ifstream corpus_file;
	if ( ! use_pipes ) corpus_file.open( argv[1] );
	
	std::istream & corpus = use_pipes? std::cin : corpus_file;
	
	if ( ! corpus ) {
		// With no argument, argv[1] is a null pointer; streaming it is undefined.
		std::cerr << "Could not open " << ( use_pipes? "standard input" : argv[1] ) << "\n";
		std::perror( nullptr );
		return EXIT_FAILURE;
	}
	
	collect_samples( corpus );
	
	// total[ g ] = number of pairs whose first character is in group g.
	for ( int sx = 0; sx != symbol_count; ++ sx ) {
		total[ sx ] = std::accumulate( std::begin( prediction[ sx ] ), std::end( prediction[ sx ] ), 0 );
	}
	
	const float infinity = std::numeric_limits< float >::infinity();
	
	// Only the strict upper triangle (lo < hi) holds real costs. Infinity
	// marks cells that must never be chosen as a merge.
	for ( int lo = 0; lo != symbol_count; ++ lo ) {
		for ( int hi = 0; hi != lo + 1; ++ hi ) {
			combo_cost[ lo ][ hi ] = infinity;
		}
		for ( int hi = lo + 1; hi != symbol_count; ++ hi ) {
			update_table_combo_cost( lo, hi );
		}
	}
	
	std::ofstream stats_file;
	if ( ! use_pipes ) {
		stats_file.open( "stats.txt" );
		if ( ! stats_file ) {
			std::cerr << "Could not open stats.txt\n";
			std::perror( nullptr );
			return EXIT_FAILURE;
		}
	}
	std::ostream & stats = use_pipes? std::cout : stats_file;
	{
		print_column_headings( stats );
		
		for ( int row = 0; row != symbol_count; ++ row ) {
			stats << char( ' ' + row ) << ':' << '\t';
			
			for ( int col = 0; col != symbol_count; ++ col ) {
				stats << prediction[ row ][ col ] << '\t';
			}
			stats << '\n';
		}
		stats << '\n';
		
		// From here on, floating-point values print with no decimal places.
		stats.precision( 0 );
		stats.setf( std::ios::fixed );
		
		print_column_headings( stats );
		for ( int lo = 0; lo != symbol_count; ++ lo ) {
			// One tab leaves the label cell and lo + 1 more skip columns 0..lo,
			// which hold no cost for this row.
			stats << char( ' ' + lo ) << ':' << std::string( lo + 2, '\t' );
			
			for ( int hi = lo + 1; hi != symbol_count; ++ hi ) {
				stats << combo_cost[ lo ][ hi ] << '\t';
			}
			stats << '\n';
		}
	}
	
	// Running size in nats; each merge adds its cost.
	float size_check = coded_size();
	stats << "ideal size = " << size_check * bytes_per_nat << '\n';
	
	for ( int n = 0; n != symbol_count; ++ n ) group[ n ] += char( ' ' + n );
	
	// combo_cost scanned as one flat run of symbol_count * symbol_count floats.
	float * const cost_begin = & combo_cost[0][0];
	float * const cost_end = cost_begin + symbol_count * symbol_count;
	
	for ( float * best; * ( best = std::min_element( cost_begin, cost_end ) ) != infinity; ) {
		const int flat = best - cost_begin,
			lo = flat / symbol_count,
			hi = flat % symbol_count;
		
		size_check += * best;
		
		stats << * best * bytes_per_nat << " x " << group[ lo ] << " += " << group[ hi ];
		group[ lo ] += group[ hi ];
		std::sort( group[ lo ].begin(), group[ lo ].end() );
		group[ hi ].clear();
		
		// Fold group hi's counts into group lo.
		for ( int nx = 0; nx != symbol_count; ++ nx ) {
			prediction[ lo ][ nx ] += prediction[ hi ][ nx ];
			prediction[ hi ][ nx ] = 0;
		}
		total[ lo ] += total[ hi ];
		total[ hi ] = 0;
		
		// Group hi no longer exists, so it is retired from every pairing.
		// Row hi at or below the diagonal is already infinite.
		std::fill( std::begin( combo_cost[ hi ] ) + hi + 1, std::end( combo_cost[ hi ] ), infinity );
		for ( int row = 0; row != hi; ++ row ) combo_cost[ row ][ hi ] = infinity;
		
		// Group lo's counts changed: refresh its cost against every surviving group.
		for ( int other = 0; other != lo; ++ other ) {
			if ( group[ other ].empty() ) continue;
			update_table_combo_cost( other, lo );
		}
		for ( int other = lo + 1; other != symbol_count; ++ other ) {
			if ( group[ other ].empty() ) continue;
			update_table_combo_cost( lo, other );
		}
		stats << '\n';
	}
	
	// Merges always fold the higher index into the lower, so group 0 is
	// never absorbed and total[0] ends up holding every counted pair.
	stats << "single-table size = " << size_check * bytes_per_nat << " = " << coded_size() * bytes_per_nat << " vs " << total[0] << '\n';
}
