from sympy import *
import numpy as np
import string


# This module implements Gaussian elimination (Gauss_elimination) to solve
# a system of linear equations given as an augmented matrix.

# TODO: the code around line 119 (inside back_substitution) still needs editing.

def isZero(_vect):
    """Return True when every entry of the 1-D array ``_vect`` equals zero.

    The comparison is exact (no tolerance), and an empty vector counts as
    zero.
    """
    length = _vect.shape[0]
    return not any(_vect[k] != 0 for k in range(length))


def swap(a, b):
    """Exchange the contents of ``a`` and ``b`` in place.

    Both arguments must support slice assignment (for example two rows of a
    numpy array, which are views into it). The objects themselves are not
    rebound; their elements are overwritten.
    """
    # Both right-hand sides are copied before either assignment runs.
    a[:], b[:] = b[:].copy(), a[:].copy()


def rank(A):
    """Count the rows of ``A`` that contain at least one nonzero entry.

    This equals the matrix rank only when ``A`` is in row echelon form.
    Zero tests are exact.
    """
    row_count = A.shape[0]
    return sum(
        1 for i in range(row_count) if any(entry != 0 for entry in A[i])
    )


def getMatrix(A):
    """Return a copy of ``A`` with its last column removed.

    For an augmented matrix ``[M | b]`` this yields the coefficient matrix
    ``M``. The result is an independent copy, so modifying it leaves ``A``
    unchanged.
    """
    return A[..., :-1].copy()


def isAllBelowZero(A, index):
    """Return True when column ``index`` is zero in every row below ``index``.

    Only rows ``index + 1`` .. last are inspected; row ``index`` itself is
    ignored. If there are no rows below, the result is True.
    """
    below = range(index + 1, A.shape[0])
    return not any(A[row][index] != 0 for row in below)


def Gauss_elimination(A):
    """Reduce the matrix ``A`` to row echelon form, in place.

    Columns are scanned left to right while a separate counter tracks the
    row that should receive the next pivot. For each column, the first row
    at or below that counter with a nonzero entry becomes the pivot row
    (swapped into place if necessary). A multiple of it is then subtracted
    from every row beneath it, zeroing the column there. A column with no
    nonzero entry left is skipped without consuming a pivot row. As a
    result, each nonzero row's leading entry lies strictly right of the one
    above it, and all-zero rows end up at the bottom.

    Zero tests are exact (``!= 0``); no tolerance is applied.

    Args:
        A: 2-D numpy array (in this file, an augmented matrix ``[M | b]``),
            modified in place. Use a floating-point dtype: with an integer
            dtype the fractional row updates are truncated when written back
            into ``A``.

    Returns:
        None; the reduced matrix is ``A`` itself.
    """
    rows, cols = A.shape
    pivot_row = 0
    for col in range(cols):
        if pivot_row >= rows:
            break
        pivot = next(
            (k for k in range(pivot_row, rows) if A[k][col] != 0), None
        )
        if pivot is None:
            # Column is already zero from pivot_row down: no pivot here, and
            # the same row stays available for the next column.
            continue
        if pivot != pivot_row:
            # Fancy indexing copies the right-hand side, so this swaps rows.
            A[[pivot_row, pivot]] = A[[pivot, pivot_row]]
        for j in range(pivot_row + 1, rows):
            scaling_value = A[j][col] / A[pivot_row][col]
            A[j] = A[j] - scaling_value * A[pivot_row]
            # The exact result is 0; storing it avoids a round-off residue
            # that would make an otherwise-zero row look nonzero.
            A[j][col] = 0
        pivot_row += 1


def back_substitution(A):
    """Solve the linear system whose augmented matrix is ``A``.

    ``A`` holds the coefficients in every column except the last, which is
    the right-hand side. It is first reduced in place by
    ``Gauss_elimination``. The system is then classified by comparing the
    number of nonzero rows of the coefficient part (``rA``) with that of the
    whole augmented matrix (``rAB``):

    * ``rA + 1 == rAB``: no solution;
    * ``rA == rAB == number of unknowns``: one solution, found by back
      substitution;
    * otherwise: infinitely many solutions, written with the sympy symbols
      ``a``, ``b``, ... standing for the free unknowns (at most 26).

    Args:
        A: 2-D numpy array (augmented matrix), modified in place.

    Returns:
        For a unique solution, a numpy float array with one value per
        unknown. For infinitely many solutions, a list with one entry per
        unknown (sympy expressions in the free symbols). When there is no
        solution, a zero array is returned only as a placeholder; it is NOT
        a solution (see the printed message).
    """
    Gauss_elimination(A)
    rows, cols = A.shape
    cols -= 1  # number of unknowns; the last column is the right-hand side
    coeffs = getMatrix(A)
    rA = rank(coeffs)
    rAB = rank(A)

    x = np.zeros(cols)

    if rA + 1 == rAB:
        print(f"Setequation have no solution")
    elif rA == rAB and rA == cols:
        print(f"Setequation have one solution: ")
        # Full column rank: in echelon form, row i pivots on A[i][i].
        for i in range(cols - 1, -1, -1):
            # x[i] is still 0 here, so the dot product only picks up the
            # already-solved unknowns to its right.
            x[i] = (A[i][cols] - x.dot(coeffs[i])) / A[i][i]
    else:
        alphabet = symbols(list(string.ascii_lowercase))
        print(f"Setequation have infinity solution with {cols - rA} ")
        # (row, pivot column) for every row whose coefficient part is nonzero.
        pivots = []
        for r in range(rows):
            nonzero_cols = [c for c in range(cols) if coeffs[r][c] != 0]
            if nonzero_cols:
                pivots.append((r, nonzero_cols[0]))
        pivot_cols = {c for _, c in pivots}
        # Free unknowns become the symbols a, b, c, ... in column order.
        # Pivot unknowns start as the integer 0, so they contribute nothing
        # to the dot products below until they have been solved.
        y = [0] * cols
        free_cols = [c for c in range(cols) if c not in pivot_cols]
        for k, c in enumerate(free_cols):
            y[c] = alphabet[k]
        # Solve bottom-up: in echelon form a row only involves unknowns at or
        # right of its pivot, which are free symbols or pivots of lower rows
        # that were solved earlier.
        for r, c in reversed(pivots):
            s = coeffs[r].dot(y)
            y[c] = (A[r][cols] - s) / A[r][c]
        x = y

    return x


def main():
    """Run the solver on a small example system and print each stage.

    The example's third equation is the average of the first two, so the
    system has infinitely many solutions. The matrix
    [[1, 2, 2, 9], [3, 4, 4, 1], [5, 6, 7, 0]] is an alternative example
    with a unique solution.
    """
    # Float dtype: elimination writes fractional values back into the array,
    # which an integer array would silently truncate.
    matrix = np.array([[1, 2, 3, 0],
                       [3, 2, 3, 2],
                       [2, 2, 3, 1]], dtype=float)
    print(matrix)
    Gauss_elimination(matrix)
    print(matrix)
    print(back_substitution(matrix))


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
