fork download
  1. /** Catamorphisms */
  2.  
  #include <algorithm>
  #include <cassert>
  #include <functional>
  #include <iostream>
  #include <map>
  #include <memory>
  #include <set>
  #include <sstream>
  #include <stdexcept>
  #include <string>
  #include <utility>
  #include <vector>

  #include <boost/algorithm/string/join.hpp>
  #include <boost/range/adaptors.hpp>
  #include <boost/range/algorithm.hpp>
  #include <boost/range/numeric.hpp>
  #include <boost/variant.hpp>
  17.  
  18. using namespace boost::adaptors;
  19.  
  20.  
  21. //--------------------------------------------------------
  22. // Recursive operation
  23. //--------------------------------------------------------
  24.  
  /** Numeric type used for variable values (see the `env` map). */
  using nb = int;

  /** Identifier type used to name variables. */
  using id = std::string;
  27.  
  /** Empty tag type that marks an `op` as an addition. */
  struct add_tag
  {
  };

  /** Empty tag type that marks an `op` as a multiplication. */
  struct mul_tag
  {
  };
  30.  
  /**
   * N-ary operation node.
   *
   * @tparam Tag  Distinguishes the kind of operation at the type level.
   * @tparam R    Type stored for each operand.
   */
  template<typename Tag, typename R>
  struct op
  {
  private:
      std::vector<R> m_rands;

  public:
      /** Builds an operation without operands. */
      op() = default;

      /**
       * Copies every element of `rng` into the operand list.
       * `rng` must expose `begin()` / `end()` callable on a const object.
       */
      template<typename Range>
      explicit op(Range const& rng)
          : m_rands(rng.begin(), rng.end())
      {
      }

      /** Read-only access to the operands, in insertion order. */
      std::vector<R> const& rands() const
      {
          return m_rands;
      }
  };
  44.  
  45. template<typename R> using add_op = op<add_tag, R>;
  46. template<typename R> using mul_op = op<mul_tag, R>;
  47.  
  48. template<typename R>
  49. using expression_r = boost::variant<int, id, add_op<R>, mul_op<R>>;
  50.  
  51. struct expression : boost::recursive_wrapper<expression_r<expression>>
  52. {
  53. using boost::recursive_wrapper<expression_r<expression>>::recursive_wrapper;
  54. };
  55.  
  56. //--------------------------------------------------------
  57. // Smart constructors
  58. //--------------------------------------------------------
  59.  
  60. expression cst(int i) { return expression(i); };
  61.  
  62. expression var(id id) { return expression(id); };
  63.  
  64. expression add(std::vector<expression> const& rands)
  65. {
  66. return expression(add_op<expression>{ rands });
  67. }
  68.  
  69. expression mul(std::vector<expression> const& rands)
  70. {
  71. return expression(mul_op<expression>{ rands });
  72. }
  73.  
  74. //--------------------------------------------------------
  75. // Query
  76. //--------------------------------------------------------
  77.  
  78. template <typename T>
  79. int const* get_as_cst(expression_r<T> const& e)
  80. {
  81. return boost::get<int>(&e);
  82. }
  83.  
  84. template <typename T>
  85. id const* get_as_var(expression_r<T> const& e)
  86. {
  87. return boost::get<id>(&e);
  88. }
  89.  
  90. template <typename T>
  91. add_op<T> const* get_as_add(expression_r<T> const& e)
  92. {
  93. return boost::get<add_op<T>>(&e);
  94. }
  95.  
  96. template <typename T>
  97. mul_op<T> const* get_as_mul(expression_r<T> const& e)
  98. {
  99. return boost::get<mul_op<T>>(&e);
  100. }
  101.  
  /**
   * Signals that a chain of `get_as_*` tests did not recognise the
   * alternative held by an expression.
   *
   * Declared [[noreturn]] because it always throws: callers with a
   * non-void return type end with a call to this function, and the
   * attribute tells the compiler control never flows past it.
   *
   * @throws std::logic_error always.
   */
  [[noreturn]] void throw_missing_pattern_matching_clause()
  {
      throw std::logic_error("Missing case in pattern matching");
  }
  106.  
  107.  
  108. //--------------------------------------------------------
  109. // FUNCTOR INSTANCE
  110. //--------------------------------------------------------
  111.  
  112. template<typename A, typename M>
  113. auto fmap(M map, expression_r<A> const& e)
  114. {
  115. using B = decltype(map(std::declval<A>()));
  116. using Out = expression_r<B>;
  117.  
  118. if (auto* o = get_as_add(e))
  119. return Out(add_op<B>(o->rands() | transformed(map)));
  120.  
  121. if (auto* o = get_as_mul(e))
  122. return Out(mul_op<B>(o->rands() | transformed(map)));
  123.  
  124. if (auto* i = get_as_cst(e)) return Out(*i);
  125. if (auto* v = get_as_var(e)) return Out(*v);
  126. throw_missing_pattern_matching_clause();
  127. }
  128.  
  129.  
  130. //--------------------------------------------------------
  131. // CATAMORPHISM
  132. //--------------------------------------------------------
  133.  
  134. template<typename Out, typename Algebra>
  135. Out cata(Algebra f, expression const& ast)
  136. {
  137. return f(
  138. fmap(
  139. [f](expression const& e) -> Out {
  140. return cata<Out>(f, e);
  141. },
  142. ast.get()));
  143. }
  144.  
  145. //--------------------------------------------------------
  146. // PARAMORPHISM
  147. // The algebra f now takes an expression_r of pairs (Out, expression const*):
  148. // each operand's result together with the sub-expression it was computed from.
  149. //--------------------------------------------------------
  150.  
  151. template<typename Out, typename Algebra>
  152. Out para(Algebra f, expression const& ast)
  153. {
  154. return f(
  155. fmap(
  156. [f](expression const& e) -> std::pair<Out, expression const*> {
  157. return { para<Out>(f, e), &e };
  158. },
  159. ast.get()));
  160. }
  161.  
  162.  
  163. //--------------------------------------------------------
  164. // DISPLAY
  165. //--------------------------------------------------------
  166.  
  167. template<typename Tag>
  168. std::string print_op(op<Tag, std::string> const& e, std::string const& op_repr)
  169. {
  170. return std::string("(") + op_repr + " " + boost::algorithm::join(e.rands(), " ") + ")";
  171. }
  172.  
  173. std::string print_alg(expression_r<std::string> const& e)
  174. {
  175. if (auto* o = get_as_add(e)) return print_op(*o, "+");
  176. if (auto* o = get_as_mul(e)) return print_op(*o, "*");
  177. if (auto* i = get_as_cst(e)) return std::to_string(*i);
  178. if (auto* v = get_as_var(e)) return *v;
  179. throw_missing_pattern_matching_clause();
  180. }
  181.  
  182.  
  183. //--------------------------------------------------------
  184. // DISPLAY (INFIX)
  185. //--------------------------------------------------------
  186.  
  187. std::string print_infix_op_bad(op<add_tag, std::string> const& e)
  188. {
  189. return boost::algorithm::join(e.rands(), " + ");
  190. }
  191.  
  /** Returns `s` surrounded by a pair of parentheses. */
  std::string with_parens(std::string const& s)
  {
      std::string wrapped;
      wrapped.reserve(s.size() + 2);
      wrapped += '(';
      wrapped += s;
      wrapped += ')';
      return wrapped;
  }
  196.  
  197. std::string print_infix_op_bad(op<mul_tag, std::string> const& e)
  198. {
  199. return boost::algorithm::join(e.rands() | transformed(with_parens), " * ");
  200. }
  201.  
  202. std::string print_infix_bad(expression_r<std::string> const& e)
  203. {
  204. if (auto* o = get_as_add(e)) return print_infix_op_bad(*o);
  205. if (auto* o = get_as_mul(e)) return print_infix_op_bad(*o);
  206. if (auto* i = get_as_cst(e)) return std::to_string(*i);
  207. if (auto* v = get_as_var(e)) return *v;
  208. throw_missing_pattern_matching_clause();
  209. }
  210.  
  211. //--------------------------------------------------------
  212.  
  213. std::string print_op_infix(op<add_tag, std::pair<std::string, expression const*>> const& e)
  214. {
  215. auto fst = [](auto const& e) { return e.first; };
  216. return boost::algorithm::join(e.rands() | transformed(fst), " + ");
  217. }
  218.  
  219. std::string print_op_infix(op<mul_tag, std::pair<std::string, expression const*>> const& e)
  220. {
  221. auto wrap_addition = [](auto const& sub_expr) {
  222. if (get_as_add(sub_expr.second->get()))
  223. return with_parens(sub_expr.first);
  224. return sub_expr.first;
  225. };
  226. return boost::algorithm::join(e.rands() | transformed(wrap_addition), " * ");
  227. }
  228.  
  229. std::string print_infix(expression_r<std::pair<std::string, expression const*>> const& e)
  230. {
  231. if (auto* o = get_as_add(e)) return print_op_infix(*o);
  232. if (auto* o = get_as_mul(e)) return print_op_infix(*o);
  233. if (auto* i = get_as_cst(e)) return std::to_string(*i);
  234. if (auto* v = get_as_var(e)) return *v;
  235. throw_missing_pattern_matching_clause();
  236. }
  237.  
  238.  
  239. //--------------------------------------------------------
  240. // EVALUATION
  241. //--------------------------------------------------------
  242.  
  243. using env = std::map<id, nb>;
  244.  
  245. auto eval_alg(env const& env)
  246. {
  247. return [&env] (expression_r<int> const& e)
  248. {
  249. if (auto* o = get_as_add(e))
  250. return boost::accumulate(o->rands(), 0, std::plus<int>());
  251.  
  252. if (auto* o = get_as_mul(e))
  253. return boost::accumulate(o->rands(), 1, std::multiplies<int>());
  254.  
  255. if (auto* v = get_as_var(e)) return env.find(*v)->second;
  256. if (auto* i = get_as_cst(e)) return *i;
  257. throw_missing_pattern_matching_clause();
  258. };
  259. }
  260.  
  261. int eval(env const& env, expression const& expr)
  262. {
  263. return cata<int>(eval_alg(env), expr);
  264. }
  265.  
  266.  
  267. //--------------------------------------------------------
  268. // DEPENDENCIES
  269. //--------------------------------------------------------
  270.  
  271. template<typename Tag>
  272. std::set<id> join_sets(op<Tag, std::set<id>> const& op)
  273. {
  274. std::set<id> out;
  275. for (auto r: op.rands())
  276. out.insert(r.begin(), r.end());
  277. return out;
  278. }
  279.  
  280. std::set<id> dependencies_alg(expression_r<std::set<id>> const& e)
  281. {
  282. if (auto* o = get_as_add(e)) return join_sets(*o);
  283. if (auto* o = get_as_mul(e)) return join_sets(*o);
  284. if (auto* v = get_as_var(e)) return {*v};
  285. return {};
  286. }
  287.  
  288. std::set<id> dependencies(expression const& e)
  289. {
  290. return cata<std::set<id>>(dependencies_alg, e);
  291. }
  292.  
  293.  
  294. //--------------------------------------------------------
  295. // OPTIMIZATIONS
  296. //--------------------------------------------------------
  297.  
  298. template<typename Tag, typename Step>
  299. expression optimize_op(op<Tag, expression> const& e, int neutral, Step step)
  300. {
  301. int res = neutral;
  302. std::vector<expression> subs;
  303.  
  304. for (expression const& sub: e.rands())
  305. {
  306. if (auto* i = get_as_cst(sub.get()))
  307. {
  308. res = step(res, *i);
  309. }
  310. else
  311. {
  312. subs.push_back(sub);
  313. }
  314. }
  315.  
  316. if (subs.empty()) return cst(res);
  317. if (res != neutral) subs.push_back(cst(res));
  318. if (subs.size() == 1) return subs.front();
  319. return expression(op<Tag, expression>(subs));
  320. }
  321.  
  322. template<typename Range>
  323. bool has_zero(Range const& subs)
  324. {
  325. return end(subs) != boost::find_if(subs, [](expression const& sub) {
  326. auto* i = get_as_cst(sub.get());
  327. return i && *i == 0;
  328. });
  329. }
  330.  
  331. expression opt_add_alg(expression_r<expression> const& e)
  332. {
  333. if (auto* op = get_as_add(e))
  334. return optimize_op(*op, 0, std::plus<int>());
  335. return e;
  336. }
  337.  
  338. expression opt_mul_alg(expression_r<expression> const& e)
  339. {
  340. if (auto* op = get_as_mul(e))
  341. {
  342. if (has_zero(op->rands()))
  343. return cst(0);
  344. return optimize_op(*op, 1, std::multiplies<int>());
  345. }
  346. return e;
  347. }
  348.  
  349. expression optimize_alg(expression_r<expression> const& e)
  350. {
  351. return opt_mul_alg(opt_add_alg(e).get());
  352. }
  353.  
  354.  
  355. //--------------------------------------------------------
  356. // PARTIAL EVAL
  357. //--------------------------------------------------------
  358.  
  359. auto partial_eval_alg(env const& env)
  360. {
  361. return [&env] (expression_r<expression> const& e) -> expression
  362. {
  363. if (auto* v = get_as_var(e))
  364. {
  365. auto it = env.find(*v);
  366. if (it != env.end()) return cst(it->second);
  367. return var(*v);
  368. }
  369. return e;
  370. };
  371. }
  372.  
  373. expression partial_eval(env const& env, expression const& e)
  374. {
  375. return cata<expression>(
  376. [&env](expression_r<expression> const& e) -> expression {
  377. return optimize_alg(partial_eval_alg(env)(e).get());
  378. },
  379. e);
  380. }
  381.  
  382.  
  383. //--------------------------------------------------------
  384. // EVALUATION (Different implementations)
  385. //--------------------------------------------------------
  386.  
  387. void throw_missing_variables(std::set<id> const& dependencies)
  388. {
  389. std::ostringstream s;
  390. for (auto const& d: dependencies)
  391. s << d << " ";
  392. throw std::logic_error(s.str());
  393. }
  394.  
  395. int eval_2(env const& env, expression const& e)
  396. {
  397. auto reduced = partial_eval(env, e);
  398. if (auto* i = get_as_cst(reduced.get())) return *i;
  399. throw_missing_variables(dependencies(reduced));
  400. }
  401.  
  402.  
  403. //--------------------------------------------------------
  404. // Tests
  405. //--------------------------------------------------------
  406.  
  407. int main()
  408. {
  409. expression e = add({
  410. cst(1),
  411. cst(2),
  412. mul({cst(0), var("x"), var("y")}),
  413. mul({cst(1), var("y"), add({cst(2), var("x")})}),
  414. add({cst(0), var("x")})
  415. });
  416.  
  417. env full_env = {{"x", 1}, {"y", 2}};
  418. std::cout << cata<std::string>(print_alg, e) << std::endl;
  419. std::cout << cata<std::string>(print_infix_bad, e) << std::endl;
  420. std::cout << para<std::string>(print_infix, e) << std::endl;
  421. std::cout << eval(full_env, e) << std::endl;
  422. std::cout << eval_2(full_env, e) << std::endl;
  423.  
  424. auto e2 = cata<expression>(partial_eval_alg(full_env), e);
  425. env empty_env;
  426. std::cout << cata<std::string>(print_alg, e2) << std::endl; //TODO - chain optimize and partial
  427. std::cout << eval(empty_env, e2) << std::endl;
  428. std::cout << eval_2(empty_env, e2) << std::endl;
  429.  
  430. auto e3 = cata<expression>(optimize_alg, e);
  431. std::cout << cata<std::string>(print_alg, e3) << std::endl;
  432. std::cout << eval(full_env, e3) << std::endl;
  433. std::cout << eval_2(full_env, e3) << std::endl;
  434.  
  435. try {
  436. eval_2(empty_env, e);
  437. } catch (std::logic_error const& e) {
  438. std::cout << e.what() << std::endl;
  439. }
  440. return 0;
  441. }
Success #stdin #stdout 0s 15288KB
stdin
Standard input is empty
stdout
(+ 1 2 (* 0 x y) (* 1 y (+ 2 x)) (+ 0 x))
1 + 2 + (0) * (x) * (y) + (1) * (y) * (2 + x) + 0 + x
1 + 2 + 0 * x * y + 1 * y * (2 + x) + 0 + x
10
10
(+ 1 2 (* 0 1 2) (* 1 2 (+ 2 1)) (+ 0 1))
10
10
(+ (* y (+ x 2)) x 3)
10
10
x y