:- set_prolog_flag(verbose,silent).
:- set_prolog_flag(occurs_check,true).
:- op(500,yfx,$).
:- prompt(_, '').
:- use_module(library(readutil)).

%%%% IDEONE compatibility for mutually recursive predicates %%%%
% NOTE(review): in standard Prolog these two lines are ordinary facts of
% '/'/2, i.e. '/'(eqty,2) and '/'(unify_oemap,2); they do not declare
% anything about eqty/2 or unify_oemap/2 themselves. Presumably they only
% silenced a forward-reference complaint on the original host -- confirm
% before removing.
eqty/2.
unify_oemap/2.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% kind(+KC, +Ty, ?K): type Ty has kind K under the kind context KC.
%
% Kinds are o (ordinary types), row (record rows) and K1->K2 (type
% constructors). KC is a list of Name:mono(Kind) entries (see ctx0/2)
% whose tail may be left unbound; kind_lookup/3 records the kind of a
% name that is not yet present by extending that open tail, so each type
% variable gets one consistent kind across all checked goals.
kind(KC, var(X), K) :-
    kind_lookup(KC, X, K).
kind(KC, F $ G, K2) :-
    K2 \== row,                 % result kind may not be row (if already bound)
    kind(KC, F, K1->K2),
    K1 \== row,                 % argument kind may not be row (if already bound)
    kind(KC, G, K1).
kind(KC, A -> B, o) :-
    kind(KC, A, o),
    kind(KC, B, o).
kind(KC, {R}, o) :-
    kind(KC, R, row).
kind(_KC, [], row).
kind(KC, [_Label:T|R], row) :-
    kind(KC, T, o),
    kind(KC, R, row).

% kind_lookup(?KC, +Name, ?K): K is the kind recorded for Name in KC.
% Keys are compared with ==/2, so the lookup is deterministic. If Name is
% absent and KC ends in an unbound tail, a Name:mono(K) entry is appended
% there; on a closed KC lacking Name the lookup fails.
kind_lookup(KC, X, K) :-
    var(KC), !,
    KC = [X:mono(K)|_].
kind_lookup([Y:S|Rest], X, K) :-
    (   X == Y
    ->  S = mono(K)
    ;   kind_lookup(Rest, X, K)
    ).

% type(+KC, +C, +E, ?T)// : expression E has type T in term context C.
% C is a list of Name:mono(Ty) / Name:poly(Ctx,Ty) entries and KC is the
% kind context. The DCG list does not describe input: it collects kind/3
% goals that are returned to the caller (see infer_type/4) and checked
% only after the whole expression has been typed.
%
% Variable: look up its scheme (the first binding for X wins, so inner
% bindings shadow outer ones) and instantiate it.
type(KC,C,var(X),     A) --> { first(X:S,C) }, inst_ty(KC,S,A).
% Lambda: bind X monomorphically in the body; require A->B to have kind o.
type(KC,C,lam(X,E),A->B) --> type(KC,[X:mono(A)|C],E,B),
                             [ kind(KC,A->B,o) ].
% Application: type the function and the argument, then compare the
% parameter type with the argument type using the record-aware eqty/2.
% The cut commits to the first typings of X and Y, so if eqty/2 fails the
% application fails instead of retrying other typings of X or Y.
type(KC,C,X $ Y,      B) --> type(KC,C,X,A->B), type(KC,C,Y,A1),
                             !, { eqty(A,A1) }. % note the cut !
% Let: type E0, then bind X to the scheme poly(C,A); variables of A that
% also occur in C stay shared when instantiated (see inst_ty//3).
type(KC,C,let(X=E0,E),T) --> type(KC,C,E0,A),
                             type(KC,[X:poly(C,A)|C],E,T).
% Record literal {[x1=e1,...]}: split into labels and expressions, type
% each expression, and build the closed row [x1:t1,...].
type(KC,C,{XEs},    {R}) --> { zip_with('=',Xs,Es,XEs) },
                             type_many(KC,C,Es,Ts),
                             { zip_with(':',Xs,Ts,R) }.
% Field selection L.X: R is built by first/2 as an open row containing
% X:T. On backtracking first/2 places X:T further along R, which is how R
% comes to match the row of L's type.
type(KC,C,sel(L,X),   T) --> { first(X:T,R) }, type(KC,C,L,{R}).

% first(?Key:?Val, ?Map): look up Key in Map, a list of Key:Value pairs.
% The list is walked entry by entry. Key:Val is unified with the current
% entry, and the walk only moves past that entry when its key is not
% identical (\==) to Key. So the first entry for Key shadows any later
% one. On an open-ended Map, a missing entry is created at the tail, and
% backtracking can place it further along.
first(Key:Val, [Key:Val|_]).
first(Key:Val, [Other:_|Rest]) :-
    Key \== Other,
    first(Key:Val, Rest).

% inst_ty(+KC, +Scheme, ?T)// : instantiate a type scheme to the type T.
%  - poly(C,T): copy T, but unify the copy of C back with C itself, so
%    every variable that also occurs in C stays shared and only the
%    remaining variables of T are renamed. free_variables/2 lists the
%    variables of T and of its copy in corresponding order. samekinds//3
%    then emits kind/3 goals requiring each renamed variable to have the
%    same kind as its original.
%  - mono(T): T itself, with no kind goals.
inst_ty(KC,poly(C,T),T2) --> { copy_term(t(C,T),t(C,T1)), 
                               free_variables(T,Xs),
                               free_variables(T1,Xs1) },
                             samekinds(KC,Xs,Xs1), { T1=T2 }.
inst_ty(KC,mono(T),   T) --> [].

% samekinds(+KC, +Xs, +Ys)// : Xs are the variables of a type scheme and
% Ys the corresponding variables of its fresh instance (same order).
% For every pair that was actually renamed, emit goals requiring both to
% have the same kind. A pair that is the same variable (shared with the
% context) emits nothing.
%
% The last clause tests X == Y explicitly. Unifying X and Y in the head
% would let backtracking identify a generalised variable with its instance.
samekinds(_KC, [], []) --> [].
samekinds(KC, [X|Xs], [Y|Ys]) --> { X \== Y },
    [ kind(KC,X,K), kind(KC,Y,K) ],
    samekinds(KC, Xs, Ys).
samekinds(KC, [X|Xs], [Y|Ys]) --> { X == Y },
    samekinds(KC, Xs, Ys).

% zip_with(+F, ?Xs, ?Ys, ?Ps): Ps is the element-wise pairing of Xs and Ys
% under the binary functor F, i.e. each P in Ps is F(X,Y). It works both
% to build Ps and to split a given Ps back into Xs and Ys.
zip_with(F, Xs, Ys, Ps) :-
    maplist(zip_with_cell(F), Xs, Ys, Ps).

% zip_with_cell(+F, ?X, ?Y, ?FXY): FXY is the term F(X,Y).
zip_with_cell(F, X, Y, FXY) :-
    FXY =.. [F, X, Y].

% type_many(+KC, +C, ?Es, ?Ts)// : Ts are the types of the expressions Es
% (pairwise, in order) in context C. This is written with the two
% difference-list arguments spelled out. It is still called as a
% non-terminal from DCG bodies, and the kind goals produced by each
% type//4 call are threaded through in order.
type_many(_KC, _C, [], [], Goals, Goals).
type_many(KC, C, [E|Es], [T|Ts], Goals0, Goals) :-
    type(KC, C, E, T, Goals0, Goals1),
    type_many(KC, C, Es, Ts, Goals1, Goals).

% variablize(?Ty): bind the free type variable Ty to var(Name), where Name
% is a fresh atom t1, t2, ... from gensym/2.
variablize(Ty) :-
    Ty = var(Name),
    gensym(t, Name).

% infer_type(+KC, +C, +E, -T): infer the type T of expression E under the
% kind context KC and term context C, then check the kind goals that
% type//4 collected.
%
% The kind goals are run on a copy (copy_term/2), so naming free type
% variables with variablize/1 does not bind variables of T.
% Gs is a free variable of the bagof/3 goal. bagof/3 therefore unifies
% its witness back with Gs, and the collected types share their variables
% with Gs; findall/3 would return disconnected copies.
% The if-then-else keeps "no kind goals" (bagof/3 fails) from becoming a
% backtracking alternative that would re-run the goals unnamed.
infer_type(KC, C, E, T) :-
    phrase(type(KC, C, E, T), Gs0),
    copy_term(Gs0, Gs),
    (   bagof(Ty, X^Y^member(kind(X, Ty, Y), Gs), Tys)
    ->  true
    ;   Tys = []                    % no kind goals were emitted
    ),
    free_variables(Tys, Xs),
    maplist(variablize, Xs),        % replace free tyvar to var(t)
    maplist(call, Gs).              % run all goals in Gs

% ctx0(-KindCtx, -TermCtx): the initial contexts.
% KindCtx gives the kinds of the type constructors. Its tail is left
% unbound so that kind checking can record further type variables.
% TermCtx gives the types/schemes of the data constructors; poly([],Ty)
% generalises every variable of Ty. Note that 'Nil' and 'Cons' share the
% same variable A.
ctx0(KindCtx, TermCtx) :-
    KindCtx = [ 'Nat'  : mono(o)
              , 'List' : mono(o->o)
              , 'Pair' : mono(o->o->o)
              | _
              ],
    TermCtx = [ 'Zero' : mono(var('Nat'))
              , 'Succ' : mono(var('Nat') -> var('Nat'))
              , 'Nil'  : poly([], var('List')$A)
              , 'Cons' : poly([], A->((var('List')$A)->(var('List')$A)))
              , 'Pair' : poly([], A0->B0->var('Pair')$A0$B0)
              ].

% run(?N, -T): T is the inferred type of the N-th (0-based) example
% expression, typed in the initial contexts from ctx0/2. It fails when the
% example is ill-typed (example 9 is expected to fail).
%
% The term context also provides 'Zero', 'Succ', 'Nil' and 'Cons'
% (written var('Zero') etc.). They are not used by these examples, so no
% unused bindings for them are kept here.
run(N, T) :-
    ctx0(KC, C),
    Pair = var('Pair'),
    E0 = let(id=lam(x,var(x)), var(id)$var(id)),      % A->A
    E1 = lam(y, let(x=lam(z,var(y)), var(x)$var(x))), % A->A
    E2 = {[z=lam(x,var(x))]},                         % {[z:A->A]}
    E3 = lam(r, sel(var(r),x)),                       % {[x:A |R]} -> A
    E4 = lam(r, Pair$sel(var(r),y)$sel(var(r),x)),    % {[y:A,x:B |R]} -> Pair$A$B
    E5 = {[x=lam(x,var(x)), y=lam(x,var(x))]},        % {[x:(A->A),y:(B->B)]}
    E6 = E4 $ E5,                                     % Pair $ B->B $ A->A
    E7 = lam(r, Pair$sel(var(r),y)$var(r)),           % {[y:A |R]} -> Pair $ A $ {[y:A| R]}
    E8 = E7 $ {[y=lam(x,var(x))]},                    % Pair $ B->B $ {[y:(B->B)]} (row closed)
    E9 = E7 $ {[]},                                   % ill-typed: {[]} has no field y
    E10 = {[]},                                       % {[]}
    Es = [E0,E1,E2,E3,E4,E5,E6,E7,E8,E9,E10],
    nth0(N, Es, E),
    infer_type(KC, C, E, T).


% related paper:
%
% Membership-Constraints and Complexity in Logic Programming with Sets,
% Frieder Stolzenburg (1996).
% http://link.springer.com/chapter/10.1007%2F978-94-009-0349-4_15
% http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.54.8356

% eqty(?A1, ?A2): type equality used when checking applications. It is
% more permissive than plain unification:
%  - an unbound side is unified directly with the other;
%  - record types {R1}/{R2} are compared with unify_oemap/2, so field
%    order may differ and open row tails are extended as needed;
%  - arrow types compare their parameter types with the arguments swapped
%    (eqty(A2,A1)), per the original note "in case of subtyping". The
%    second cut commits to the first solution of that comparison before
%    the result types are compared;
%  - everything else (e.g. F$G) falls back to ordinary unification, so
%    records nested under a type constructor are NOT compared up to
%    field order.
eqty(A1,A2) :- (var(A1); var(A2)), !, A1=A2.
eqty({R1},{R2}) :- !, unify_oemap(R1,R2). % permutation(R1,R2), !.
eqty(A1->B1,A2->B2) :- !, eqty(A2,A1), !, eqty(B1,B2). % in case of subtyping
eqty(A,A).

% memb(?X, ?L): X is an element of L, treating L as a set.
% An element is only skipped when it is not identical (\==) to X. So when
% X is already bound, the first identical element is the only answer and
% no duplicate answers are produced. (\==/2 is extra-logical.)
memb(X, [Y|L]) :-
    (   X = Y
    ;   X \== Y,
        memb(X, L)
    ).

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Inspired by the idea in the paper above, unification of
% finite maps can be implemented as follows.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% unify_map(?M1, ?M2): unify two closed finite maps (lists of Key:Value).
% Each map must be a submap of the other. Values stored under a common
% key are compared with eqty/2 by submap_of/2.
unify_map(Left, Right) :-
    submap_of(Left, Right),
    submap_of(Right, Left).

% submap_of(?Sub, +M): every Key:Value entry of Sub has an entry for the
% same key in M (found with first/2) whose value is eqty/2-equal to it.
submap_of(Sub, M) :-
    maplist(submap_entry(M), Sub).

% submap_entry(+M, ?Entry): Entry's key is found in M with an
% eqty/2-compatible value.
submap_entry(M, X:V) :-
    first(X:V1, M),
    eqty(V, V1).


% mapminus(+A, +B, -C): C is the finite map A with every key of B removed.
% For a key present in both maps, the two values are unified with eqty/2
% (committing to its first solution) and the key is dropped. If the values
% are incompatible, the whole call fails.
%
% The lookup is committed with a plain cut. Previously the clause read
% `first, !, eqty -> mapminus`, which parses as an if-then-else whose cut
% is local to the condition. The last clause could then still keep a key
% that is present in B.
mapminus(A, [], A) :- !.
mapminus([], _, []).
mapminus([X:V|Ps], B, C) :-
    first(X:V1, B), !,
    once(eqty(V, V1)),
    mapminus(Ps, B, C).
mapminus([X:V|Ps], B, [X:V|C]) :-
    mapminus(Ps, B, C).



% unify_oemap(?A, ?B): unify two open-ended maps, i.e. lists of Key:Value
% entries whose tail may be an unbound variable.
% If either side is itself unbound, the two are simply unified. Otherwise
% both are split into their explicit entries and tails, the entries are
% normalised with make_map/2, and unify_oe_map/2 does the work.
unify_oemap(A, B) :-
    (   ( var(A) ; var(B) )
    ->  A = B
    ;   split_heads(A, HeadsA-TailA),
        make_map(HeadsA, MapA),
        split_heads(B, HeadsB-TailB),
        make_map(HeadsB, MapB),
        unify_oe_map(MapA-TailA, MapB-TailB)
    ).

% make_map(+L, -M): M holds the Key:Value entries of L, sorted, with
% identical entries removed (setof/3).
% first/2 is called with an unbound key here, so its \== test always
% succeeds and every entry of L is enumerated. Entries with the same key
% but different values are therefore NOT merged.
% setof/3 fails when L is empty, hence the second clause.
make_map(L,M) :- setof(X:V,first(X:V,L),M). % remove duplicates
make_map([],[]).

% split_heads(+L, -Hs-T): split a list of Key:Value entries into its
% explicit entries Hs and its tail T. T is the unbound tail variable of an
% open-ended list, or [] for a closed one.
% NOTE: an unbound L is bound to [] by the first clause.
split_heads([],[]-[]).
split_heads([X:V|T],[X:V]-T) :- var(T), !, true.  % last entry before an unbound tail
split_heads([X:V|Ps],[X:V|Hs]-T) :- split_heads(Ps,Hs-T).

% unify_oe_map(+M1-T1, +M2-T2): helper for unify_oemap/2. M1 and M2 are
% the explicit entries (from make_map/2); T1 and T2 are the tails (an
% unbound variable or []).
%  1. both tails []: the entry sets must unify exactly (unify_map/2);
%  2. T1 == []: every entry of M2 must occur in M1, and T2 receives the
%     entries of M1 whose keys are absent from M2;
%  3. T2 == []: symmetric to case 2;
%  4. otherwise: T1 receives M2's entries missing from M1 and T2 receives
%     M1's entries missing from M2, both followed by a shared fresh tail T.
%     mapminus/3 also unifies (eqty/2) the values of common keys.
% NOTE(review): the clauses have no cuts or mutually exclusive guards, so
% on backtracking case 4 is also tried after cases 1-3.
unify_oe_map(Xs-T1,Ys-T2) :- T1==[], T2==[], unify_map(Xs,Ys).
unify_oe_map(Xs-T1,Ys-T2) :- T1==[], submap_of(Ys,Xs), mapminus(Xs,Ys,T2).
unify_oe_map(Xs-T1,Ys-T2) :- T2==[], submap_of(Xs,Ys), mapminus(Ys,Xs,T1).
unify_oe_map(Xs-T1,Ys-T2) :- 
        mapminus(Ys,Xs,L1), append(L1,T,T1),
        mapminus(Xs,Ys,L2), append(L2,T,T2).

%% ?- unify_oemap([z:string,y:bool|M1],[y:T,x:int|M2]).
%% M1 = [x:int|_G1426],
%% T = bool,
%% M2 = [z:string|_G1426].

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% main: run all examples and exit. The exit status is 0 when process/0
% succeeds and 1 when it fails. Without this, a failing process/0 would
% skip halt and leave the script running instead of reporting an error.
main :-
    (   process
    ->  halt(0)
    ;   halt(1)
    ).

% process: infer and print the type of each example of run/2, one per line.
% Example 9 is expected to be ill-typed. "fail" is printed when it is
% rejected; otherwise its (unexpected) type is reported.
%
% Each example runs independently (forall/2). A failing example therefore
% cannot make Prolog backtrack into earlier run/2 calls and re-print
% alternative typings; process/0 simply fails.
process :-
    forall(between(0, 10, N), process_example(N)).

% process_example(+N): print the outcome for example N.
process_example(9) :- !,
    (   run(9, T9)
    ->  write("must fail but "), print(T9)
    ;   print(fail)
    ),
    nl.
process_example(N) :-
    run(N, T),
    print(T),
    nl.

% Run the examples as soon as this directive is reached while loading.
:- main.


