%%
% This script adds two points on an elliptic curve given in short
% Weierstrass form over the finite field GF(p) of prime order p.
% Note that p should not be excessively large, as some computations here
% (such as listing all points on the curve) rely on brute-force search.


clc
clear
% Characteristic of the field GF(p): must be a prime other than 2 and 3
p = 17;
% Coefficients of the short Weierstrass curve  y^2 = x^3 + A*x + B
A = 2;
B = 2;

% The two summands. Each one is either an affine point [x,y] or the scalar
% Inf, which stands for the neutral element (the point at infinity).
% The list of all points on the curve is printed regardless of P and Q,
% so P = Inf; Q = Inf; only lists the points (their sum is just Inf).
P = [5,1];
Q = [7,6];

% DO NOT CHANGE ANYTHING BELOW THIS LINE
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


% Residual of the curve equation reduced mod p: F(x,y) == 0 exactly when
% x^3 + A*x + B - y^2 is divisible by p. Element-wise operators allow
% evaluation on whole arrays of coordinates at once.
F = @(u, v) mod(u.^3 + A.*u + B - v.^2, p);


%% perform checks
% Validate the field first: the discriminant test below is only meaningful
% modulo a prime p > 3, and checking it first could report a misleading
% "not an elliptic curve" error for an invalid p.
assert(p~=2 && p~=3 && isprime(p),'p=2,3 or p is not prime');
% For p > 3 the curve is nonsingular iff 4*A^3 + 27*B^2 is nonzero mod p.
assert(mod(4*A^3 + 27*B^2, p) ~=0,'A,B do not define an elliptic curve');
% Each summand must be the neutral element (first entry Inf) or a
% two-element affine point [x,y] satisfying the curve equation mod p.
% The numel check turns a malformed scalar input into an assertion failure
% instead of an index-out-of-bounds error.
assert(P(1) == Inf || (numel(P) == 2 && F(P(1),P(2)) == 0), 'P is not on the curve');
assert(Q(1) == Inf || (numel(Q) == 2 && F(Q(1),Q(2)) == 0), 'Q is not on the curve');

%% find all points
% Test every pair (x,y) with 0 <= x,y <= p-1 against the curve equation.
% ndgrid puts y along the rows and x along the columns. Column-major logical
% indexing therefore returns the points ordered by x first and then by y.
[grid_y, grid_x] = ndgrid(0:p-1, 0:p-1);
on_curve = F(grid_x, grid_y) == 0;
all_points = [grid_x(on_curve), grid_y(on_curve)];
disp('all points on the curve are the following:');
disp('     inf');
disp(all_points)


%% tree
% Compute and display P + Q using the chord-and-tangent group law on
% y^2 = x^3 + A*x + B over GF(p). Inf denotes the neutral element.

% Reduce affine coordinates to their canonical representatives 0..p-1.
% The membership checks accept unreduced inputs such as [22,1] or [5,-1]
% (F reduces mod p), and the equality tests below require reduced values.
if ~isinf(P(1)); P = mod(P, p); end
if ~isinf(Q(1)); Q = mod(Q, p); end

if isinf(P(1))     % P is neutral element
    disp('Result is:')
    disp(Q)
elseif isinf(Q(1)) % Q is neutral element
    disp('Result is:')
    disp(P)
else
    disp('apparently none of the summands is the neutral element...');
    if P(1) == Q(1) && mod(P(2)+Q(2),p) == 0 % P == -Q, so P+Q is the neutral element
        disp('Result is:')
        disp(Inf)
    else  % P ~= -Q -> the line through P and Q is not vertical
        disp('P and Q are apparently not additive inverses to each other...');
        x0 = P(1);
        y0 = P(2);
        x1 = Q(1);
        y1 = Q(2);
        if x0 == x1 && y0 == y1 % P == Q -> use the tangent line at P
            disp('we have to double P == Q ...');
            % m = (3*x0^2 + A) / (2*y0) mod p. 2*y0 is invertible because
            % y0 == 0 (P == -P) was handled above. gcd returns Bezout
            % coefficients with k*(2*y0) + v*p == 1, so k is the inverse.
            [~, k] = gcd(mod(2*y0, p), p);
            m = mod((3*x0^2 + A) * k, p);
        else  % chord through two points with different x-coordinates
            disp('we can calculate a nonvertical line throu P and Q...');
            % m = (y1-y0) / (x1-x0) mod p. Equal x with different y means
            % Q == -P, which was handled above, so x1-x0 is invertible.
            [~, k] = gcd(mod(x1-x0, p), p);
            m = mod((y1-y0) * k, p);
        end
        % Intercept of the line y = m*x + q, which passes through P.
        q = mod(y0 - m*x0, p);

        % The line meets the curve where
        %   x^3 + A*x + B - (m*x+q)^2 = x^3 - m^2*x^2 + (A-2*m*q)*x + (B-q^2)
        % vanishes. Its roots are x0, x1 and the third intersection X, so by
        % Vieta's formulas x0 + x1 + X = m^2 (mod p).
        X = mod(m^2 - x0 - x1, p);
        Y = mod(m*X + q, p);
        switch numel(unique([x0, x1, X]))
            case 3
                disp('the polynomial has 3 distinct roots...');
            case 2
                disp('the polynomial has 2 distinct roots...');
            otherwise
                disp('the polynomial has 1 distinct root...');
        end
        disp('Result is:');
        % P+Q is the reflection (X,-Y) of the third intersection point.
        disp(mod([X,-Y],p));
    end
end
    

