/** Catamorphisms */

#include <algorithm>
#include <boost/algorithm/string/join.hpp>
#include <boost/range/algorithm.hpp>
#include <boost/range/adaptors.hpp>
#include <boost/range/numeric.hpp>
#include <boost/variant.hpp>
#include <cassert>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using namespace boost::adaptors;


//--------------------------------------------------------
// Recursive operation
//--------------------------------------------------------

// Numeric value type; used as the mapped type of the variable environment.
using nb = int;
// Variable identifier (its name).
using id = std::string;

// Empty tag types: passed as the first template argument of `op` so that
// additions and multiplications are distinct types.
struct add_tag {};
struct mul_tag {};

// N-ary operation node.
// Tag selects the kind of operation (add_tag or mul_tag) and R is the type
// of each operand ("rand").
template<typename Tag, typename R>
struct op
{
private:
	// Operands, stored in the order they were supplied.
	std::vector<R> m_operands;

public:
	// Builds an operation with no operands.
	op() = default;

	// Copies the operands out of any range that exposes begin() / end().
	template<typename Range>
	explicit op (Range const& operands)
		: m_operands(operands.begin(), operands.end())
	{}

	// Read-only access to the operands.
	std::vector<R> const& rands() const
	{
		return m_operands;
	}
};

// Operation nodes specialised by tag; R is the type of each operand.
template<typename R> using add_op = op<add_tag, R>;
template<typename R> using mul_op = op<mul_tag, R>;

// One layer of an expression: an integer constant, a variable name, or an
// addition / multiplication whose operands have type R.
template<typename R>
using expression_r = boost::variant<int, id, add_op<R>, mul_op<R>>;

// The recursive expression type: an expression_r whose operands are
// themselves expressions. The wrapped expression_r<expression> is reached
// through get(), which is how the rest of this file inspects a node.
struct expression : boost::recursive_wrapper<expression_r<expression>>
{
	// Inherit the wrapper's constructors so an expression can be built
	// directly from an expression_r<expression> (or a value convertible to it).
	using boost::recursive_wrapper<expression_r<expression>>::recursive_wrapper;
};

//--------------------------------------------------------
// Smart constructors
//--------------------------------------------------------

// Builds a constant expression holding the integer i.
expression cst(int i)
{
	return expression(i);
}

// Builds a variable expression referring to the name id.
expression var(id id)
{
	// The name was received by value, so move it into the node rather than
	// copying the string a second time.
	return expression(std::move(id));
}

// Builds an addition node whose operands are copies of rands.
expression add(std::vector<expression> const& rands)
{
	using node = add_op<expression>;
	return expression(node(rands));
}

// Builds a multiplication node whose operands are copies of rands.
expression mul(std::vector<expression> const& rands)
{
	using node = mul_op<expression>;
	return expression(node(rands));
}

//--------------------------------------------------------
// Query
//--------------------------------------------------------

// Returns a pointer to the integer held by e, or a null pointer when e does
// not currently hold a constant.
template <typename T>
int const* get_as_cst(expression_r<T> const& e)
{
	int const* const constant = boost::get<int>(&e);
	return constant;
}

// Returns a pointer to the variable name held by e, or a null pointer when
// e does not currently hold a variable.
template <typename T>
id const* get_as_var(expression_r<T> const& e)
{
	id const* const name = boost::get<id>(&e);
	return name;
}

// Returns a pointer to the addition node held by e, or a null pointer when
// e does not currently hold an addition.
template <typename T>
add_op<T> const* get_as_add(expression_r<T> const& e)
{
	add_op<T> const* const sum = boost::get<add_op<T>>(&e);
	return sum;
}

// Returns a pointer to the multiplication node held by e, or a null pointer
// when e does not currently hold a multiplication.
template <typename T>
mul_op<T> const* get_as_mul(expression_r<T> const& e)
{
	mul_op<T> const* const product = boost::get<mul_op<T>>(&e);
	return product;
}

// Signals that a chain of get_as_* checks failed to recognise the node it
// was given. Always throws std::logic_error.
//
// Marked [[noreturn]] so that callers ending with a call to this function
// are not seen by the compiler as falling off the end of a non-void function.
[[noreturn]] void throw_missing_pattern_matching_clause()
{
	throw std::logic_error("Missing case in pattern matching");
}


//--------------------------------------------------------
// FUNCTOR INSTANCE
//--------------------------------------------------------

// Functor map over one expression layer.
// Applies map to every operand of an addition or multiplication node and
// rebuilds the node with the results; constants and variables are copied
// unchanged. The result is an expression_r over map's return type.
// Throws std::logic_error (through throw_missing_pattern_matching_clause)
// if e holds none of the four known alternatives.
template<typename A, typename M>
auto fmap(M map, expression_r<A> const& e)
{
	using B = decltype(map(std::declval<A>()));
	using Out = expression_r<B>;

	// Leaves carry no operands: nothing to map.
	if (auto* constant = get_as_cst(e)) return Out(*constant);
	if (auto* name = get_as_var(e)) return Out(*name);

	if (auto* sum = get_as_add(e))
	{
		auto mapped = sum->rands() | transformed(map);
		return Out(add_op<B>(mapped));
	}

	if (auto* product = get_as_mul(e))
	{
		auto mapped = product->rands() | transformed(map);
		return Out(mul_op<B>(mapped));
	}

	throw_missing_pattern_matching_clause();
}


//--------------------------------------------------------
// CATAMORPHISM
//--------------------------------------------------------

// Catamorphism: folds ast bottom-up with the algebra f.
// Every operand of the current node is first folded to an Out by a recursive
// call; f then receives the node with its operands replaced by those results
// (an expression_r<Out>) and produces the Out for the node itself.
template<typename Out, typename Algebra>
Out cata(Algebra f, expression const& ast)
{
	auto fold_child = [f](expression const& child) -> Out {
		return cata<Out>(f, child);
	};
	return f(fmap(fold_child, ast.get()));
}

//--------------------------------------------------------
// PARAMORPHISM
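// The algebra f now takes an expression_r of (Out, expression const*) pairs:
// each operand carries its computed result together with a pointer to the
// sub-expression it was computed from, giving f access to that context.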
//--------------------------------------------------------

template<typename Out, typename Algebra>
Out para(Algebra f, expression const& ast)
{
	return f(
		fmap(
			[f](expression const& e) -> std::pair<Out, expression const*> {
				return { para<Out>(f, e), &e };
			},
			ast.get()));
}


//--------------------------------------------------------
// DISPLAY
//--------------------------------------------------------

// Renders an operation in prefix form: "(<op_repr> <rand> <rand> ...)".
// The operands of e are the already-rendered sub-expressions.
template<typename Tag>
std::string print_op(op<Tag, std::string> const& e, std::string const& op_repr)
{
	std::string rendered("(");
	rendered += op_repr;
	rendered += " ";
	rendered += boost::algorithm::join(e.rands(), " ");
	rendered += ")";
	return rendered;
}

// Algebra for cata that renders an expression in prefix notation.
// Operands arrive already rendered as strings.
std::string print_alg(expression_r<std::string> const& e)
{
	if (auto* constant = get_as_cst(e)) return std::to_string(*constant);
	if (auto* name = get_as_var(e)) return *name;
	if (auto* sum = get_as_add(e)) return print_op(*sum, "+");
	if (auto* product = get_as_mul(e)) return print_op(*product, "*");
	throw_missing_pattern_matching_clause();
}


//--------------------------------------------------------
// DISPLAY (INFIX)
//--------------------------------------------------------

// Infix rendering of an addition: the rendered operands joined by " + ".
std::string print_infix_op_bad(op<add_tag, std::string> const& e)
{
	std::string const separator = " + ";
	return boost::algorithm::join(e.rands(), separator);
}

// Returns s surrounded by a pair of parentheses.
std::string with_parens(std::string const& s)
{
	std::string wrapped("(");
	wrapped += s;
	wrapped += ")";
	return wrapped;
}

// Infix rendering of a multiplication: every rendered operand is wrapped in
// parentheses (whatever it is), then the operands are joined by " * ".
std::string print_infix_op_bad(op<mul_tag, std::string> const& e)
{
	auto const parenthesised = e.rands() | transformed(with_parens);
	return boost::algorithm::join(parenthesised, " * ");
}

// Algebra for cata that renders an expression in infix notation.
// It only sees the rendered operands, so multiplications parenthesise every
// operand unconditionally.
std::string print_infix_bad(expression_r<std::string> const& e)
{
	if (auto* constant = get_as_cst(e)) return std::to_string(*constant);
	if (auto* name = get_as_var(e)) return *name;
	if (auto* sum = get_as_add(e)) return print_infix_op_bad(*sum);
	if (auto* product = get_as_mul(e)) return print_infix_op_bad(*product);
	throw_missing_pattern_matching_clause();
}

//--------------------------------------------------------

// Infix rendering of an addition for para: keeps only the rendered text of
// each operand (the first member of the pair) and joins them with " + ".
std::string print_op_infix(op<add_tag, std::pair<std::string, expression const*>> const& e)
{
   auto rendered_text = [](auto const& operand) { return operand.first; };
   auto const texts = e.rands() | transformed(rendered_text);
   return boost::algorithm::join(texts, " + ");
}

// Infix rendering of a multiplication for para: an operand is parenthesised
// only when the sub-expression it was rendered from is an addition; the
// resulting texts are joined with " * ".
std::string print_op_infix(op<mul_tag, std::pair<std::string, expression const*>> const& e)
{
   auto parenthesise_sums = [](auto const& operand) {
      bool const is_sum = get_as_add(operand.second->get()) != nullptr;
      return is_sum ? with_parens(operand.first) : operand.first;
   };
   auto const texts = e.rands() | transformed(parenthesise_sums);
   return boost::algorithm::join(texts, " * ");
}

// Algebra for para that renders an expression in infix notation.
// Each operand is a (rendered text, original sub-expression) pair.
std::string print_infix(expression_r<std::pair<std::string, expression const*>> const& e)
{
   if (auto* constant = get_as_cst(e)) return std::to_string(*constant);
   if (auto* name = get_as_var(e)) return *name;
   if (auto* sum = get_as_add(e)) return print_op_infix(*sum);
   if (auto* product = get_as_mul(e)) return print_op_infix(*product);
   throw_missing_pattern_matching_clause();
}


//--------------------------------------------------------
// EVALUATION
//--------------------------------------------------------

// Variable environment: maps a variable name to its integer value.
using env = std::map<id, nb>;

// Builds the algebra used by cata to evaluate an expression to an int.
//
// The returned closure keeps a reference to env, so env must outlive it.
// Given one layer whose operands are already evaluated, it returns:
//   - the sum of the operands for an addition (0 when there are none),
//   - the product of the operands for a multiplication (1 when there are none),
//   - the value bound to the name in env for a variable,
//   - the integer itself for a constant.
//
// Throws std::out_of_range (a std::logic_error) when a variable has no
// binding in env; the lookup previously dereferenced end(), which is
// undefined behaviour.
auto eval_alg(env const& env)
{
	return [&env] (expression_r<int> const& e) -> int
	{
		if (auto* o = get_as_add(e))
			return boost::accumulate(o->rands(), 0, std::plus<int>());

		if (auto* o = get_as_mul(e))
			return boost::accumulate(o->rands(), 1, std::multiplies<int>());

		if (auto* v = get_as_var(e))
		{
			auto const binding = env.find(*v);
			if (binding == env.end())
				throw std::out_of_range("Unbound variable: " + *v);
			return binding->second;
		}

		if (auto* i = get_as_cst(e)) return *i;
		throw_missing_pattern_matching_clause();
	};
}

// Evaluates expr to an integer, reading variable values from env.
int eval(env const& env, expression const& expr)
{
	auto const algebra = eval_alg(env);
	return cata<int>(algebra, expr);
}


//--------------------------------------------------------
// DEPENDENCIES
//--------------------------------------------------------

// Returns the union of all the operand sets of op.
template<typename Tag>
std::set<id> join_sets(op<Tag, std::set<id>> const& op)
{
	std::set<id> out;
	// Iterate by const reference: iterating by value copied every operand
	// set once per iteration for no benefit.
	for (auto const& r: op.rands())
		out.insert(r.begin(), r.end());
	return out;
}

// Algebra for cata that collects variable names.
// A variable yields the set containing its own name, an addition or a
// multiplication yields the union of its operands' sets, and anything else
// yields the empty set.
std::set<id> dependencies_alg(expression_r<std::set<id>> const& e)
{
	using names = std::set<id>;

	if (auto* name = get_as_var(e)) return names{ *name };
	if (auto* sum = get_as_add(e)) return join_sets(*sum);
	if (auto* product = get_as_mul(e)) return join_sets(*product);
	return names{};
}

// Returns the names of all the variables that appear in e.
std::set<id> dependencies(expression const& e)
{
	using names = std::set<id>;
	return cata<names>(dependencies_alg, e);
}


//--------------------------------------------------------
// OPTIMIZATIONS
//--------------------------------------------------------

// Constant-folds the operands of one operation node.
//
// All constant operands are combined into a single value with step,
// starting from neutral; the other operands are kept in their original order.
//   - no non-constant operand: the result is the folded constant;
//   - otherwise the folded constant is appended after the kept operands,
//     unless it equals neutral;
//   - a single remaining operand is returned as is, without a wrapping node.
template<typename Tag, typename Step>
expression optimize_op(op<Tag, expression> const& e, int neutral, Step step)
{
	int folded = neutral;
	std::vector<expression> kept;

	for (expression const& operand: e.rands())
	{
		auto* constant = get_as_cst(operand.get());
		if (!constant)
		{
			kept.push_back(operand);
			continue;
		}
		folded = step(folded, *constant);
	}

	if (kept.empty())
		return cst(folded);

	if (folded != neutral)
		kept.push_back(cst(folded));

	if (kept.size() == 1)
		return kept.front();

	return expression(op<Tag, expression>(kept));
}

template<typename Range>
bool has_zero(Range const& subs)
{
	return end(subs) != boost::find_if(subs, [](expression const& sub) {
		auto* i = get_as_cst(sub.get());
		return i && *i == 0;
	});
}

// Simplifies an addition by folding its constant operands (neutral element 0).
// Any other kind of node is returned unchanged.
expression opt_add_alg(expression_r<expression> const& e)
{
	auto* sum = get_as_add(e);
	if (!sum)
		return e;
	return optimize_op(*sum, 0, std::plus<int>());
}

// Simplifies a multiplication: it collapses to the constant 0 as soon as one
// operand is the constant 0, otherwise its constant operands are folded
// (neutral element 1). Any other kind of node is returned unchanged.
expression opt_mul_alg(expression_r<expression> const& e)
{
	auto* product = get_as_mul(e);
	if (!product)
		return e;

	if (has_zero(product->rands()))
		return cst(0);

	return optimize_op(*product, 1, std::multiplies<int>());
}

// Algebra for cata that applies the addition simplification and then the
// multiplication simplification to one layer.
expression optimize_alg(expression_r<expression> const& e)
{
	expression const after_add = opt_add_alg(e);
	return opt_mul_alg(after_add.get());
}


//--------------------------------------------------------
// PARTIAL EVAL
//--------------------------------------------------------

// Builds the algebra that substitutes known variables.
// The returned closure keeps a reference to env, so env must outlive it.
// A variable bound in env becomes the constant it is bound to; an unbound
// variable and every other kind of node are returned unchanged.
auto partial_eval_alg(env const& env)
{
	return [&env] (expression_r<expression> const& e) -> expression
	{
		auto* name = get_as_var(e);
		if (!name)
			return e;

		auto const binding = env.find(*name);
		if (binding == env.end())
			return var(*name);
		return cst(binding->second);
	};
}

// Substitutes the variables bound in env and simplifies the result, in a
// single bottom-up traversal of e: each layer is first substituted, then
// passed through optimize_alg.
expression partial_eval(env const& env, expression const& e)
{
	auto substitute_then_simplify =
		[&env](expression_r<expression> const& layer) -> expression {
			expression const substituted = partial_eval_alg(env)(layer);
			return optimize_alg(substituted.get());
		};
	return cata<expression>(substitute_then_simplify, e);
}


//--------------------------------------------------------
// EVALUATION (Different implementations)
//--------------------------------------------------------

// Reports the variables that could not be resolved. Always throws
// std::logic_error; the message is each name in dependencies followed by a
// single space.
//
// Marked [[noreturn]] so that callers ending with a call to this function
// are not seen by the compiler as falling off the end of a non-void function.
[[noreturn]] void throw_missing_variables(std::set<id> const& dependencies)
{
	std::ostringstream s;
	for (auto const& d: dependencies)
		s << d << " ";
	throw std::logic_error(s.str());
}

// Evaluates e by partially evaluating it against env and checking that the
// result collapsed to a single constant.
// Returns that constant; otherwise throws std::logic_error (through
// throw_missing_variables) listing the variables left in the reduced
// expression.
int eval_2(env const& env, expression const& e)
{
	auto const reduced = partial_eval(env, e);
	auto* value = get_as_cst(reduced.get());

	// Handle the failure first so the function ends on a return statement:
	// the success return used to be followed by a call to a throwing helper,
	// which the compiler saw as a path reaching the end of a non-void function.
	if (!value)
		throw_missing_variables(dependencies(reduced));

	return *value;
}


//--------------------------------------------------------
// Tests
//--------------------------------------------------------

int main()
{
	expression e = add({
		cst(1),
		cst(2),
		mul({cst(0), var("x"), var("y")}),
		mul({cst(1), var("y"), add({cst(2), var("x")})}),
		add({cst(0), var("x")})
		});
		
	env full_env = {{"x", 1}, {"y", 2}};
	std::cout << cata<std::string>(print_alg, e) << std::endl;
	std::cout << cata<std::string>(print_infix_bad, e) << std::endl;
	std::cout << para<std::string>(print_infix, e) << std::endl;
	std::cout << eval(full_env, e) << std::endl;
	std::cout << eval_2(full_env, e) << std::endl;

	auto e2 = cata<expression>(partial_eval_alg(full_env), e);
	env empty_env;
	std::cout << cata<std::string>(print_alg, e2) << std::endl; //TODO - chain optimize and partial
	std::cout << eval(empty_env, e2) << std::endl;
	std::cout << eval_2(empty_env, e2) << std::endl;

	auto e3 = cata<expression>(optimize_alg, e);
	std::cout << cata<std::string>(print_alg, e3) << std::endl;
	std::cout << eval(full_env, e3) << std::endl;
	std::cout << eval_2(full_env, e3) << std::endl;

	try {
		eval_2(empty_env, e);
	} catch (std::logic_error const& e) {
		std::cout << e.what() << std::endl;
	}
	return 0;
}