import os 
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as spio
from scipy.stats import norm 
from scipy.stats import multivariate_normal
import time
import sys

# Smallest positive normalised double (not referenced elsewhere in the visible code).
flt_min = sys.float_info.min

# Equivalent of the IPython line magic "%matplotlib notebook".  The magic
# syntax itself is not valid Python, so it is issued through the IPython API
# when a shell is available and skipped when run as a plain script.
try:
    get_ipython().run_line_magic('matplotlib', 'notebook')
except NameError:
    pass

# Ground-truth mixture of Gaussians that the synthetic data is drawn from.
#   'k'      -- number of mixture components
#   'd'      -- dimensionality of each data point
#   'weight' -- length-k array of mixing proportions
#   'mean'   -- d by k array, one column per component
#   'cov'    -- d by d by k array, one covariance matrix per component
mixGaussTrue = dict()
mixGaussTrue['k'] = 3
mixGaussTrue['d'] = 2
mixGaussTrue['weight'] = np.array([0.1309, 0.3966, 0.4725])
mixGaussTrue['mean'] = np.array([[4.0491, 4.8597], [7.7578, 1.6335], [11.9945, 8.9206]]).T
mixGaussTrue['cov'] = np.zeros(shape=(mixGaussTrue['d'], mixGaussTrue['d'], mixGaussTrue['k']))
mixGaussTrue['cov'][:, :, 0] = np.array([[4.2534, 0.4791], [0.4791, 0.3522]])
mixGaussTrue['cov'][:, :, 1] = np.array([[0.9729, 0.8723], [0.8723, 2.6317]])
mixGaussTrue['cov'][:, :, 2] = np.array([[0.9886, -1.2244], [-1.2244, 3.0187]])

def sampleFromDiscrete(probDist):
    """
    Draws a random sample from a discrete probability distribution using a rejection sampling method.

    Keyword arguments:
    probDist -- discrete probability distribution to sample from.

    Returns:
    r -- sampled index into probDist.

    Raises:
    ValueError -- if probDist is empty or has no positive entry (the
                  rejection loop could never accept a sample).
    """
    probDist = np.asarray(probDist, dtype=float)
    nIndex = len(probDist)
    if nIndex == 0 or not np.any(probDist > 0):
        raise ValueError('probDist must contain at least one positive probability')

    while True:
        # choose random index
        r = int(np.floor(np.random.uniform() * nIndex))
        # choose random height
        randHeight = np.random.uniform()
        # accept if the height falls under the histogram bar at this index
        if randHeight < probDist[r]:
            break
    return r


def getGaussian2SD(m, s, angle1):
    """
    Find position in xy co-ordinates at 2SD out for a certain angle.

    Keyword arguments:
    m -- mean of the Gaussian; m[0] and m[1] are used as the x and y offsets.
    s -- 2 by 2 covariance matrix, or a vector (1-D or single column) that is
         expanded to a diagonal covariance matrix.
    angle1 -- direction from the mean, in radians.

    Returns:
    x, y -- co-ordinates of the point on the 2SD contour in that direction.
    """
    s = np.asarray(s)
    if s.ndim == 1 or s.shape[1] == 1:
        # np.diag on a 2-D column would *extract* a diagonal, so flatten first
        s = np.diag(s.ravel())

    vec = np.array([np.cos(angle1), np.sin(angle1)])

    # scale the unit vector so that its squared Mahalanobis length equals 4 (= 2SD)
    factor = 4 / (vec @ np.linalg.solve(s, vec))
    radius = np.sqrt(factor)

    x = vec[0] * radius + m[0]
    y = vec[1] * radius + m[1]

    return x, y


def drawGaussianOutline(m, s, w):
    """
    Draw the 2SD outline of a 2D Gaussian on the current matplotlib axes.

    Keyword arguments:
    m -- mean of the Gaussian.
    s -- covariance of the Gaussian (see getGaussian2SD).
    w -- weight of the Gaussian; a larger weight gives a darker red outline.
    """
    angleInc = 0.1
    # clip so a weight marginally outside [0, 1] still gives a valid RGB value
    c = (float(np.clip(0.9 * (1 - w), 0.0, 1.0)), 0, 0)

    # one extra angle past the last step closes the outline
    angles = np.arange(0, 2 * np.pi, angleInc)
    angles = np.append(angles, angles[-1] + angleInc)

    outline = [getGaussian2SD(m, s, cAngle) for cAngle in angles]
    xs, ys = zip(*outline)
    plt.plot(xs, ys, '-', linewidth=2, color=c)
    return


def drawEMData2d(data, mixGauss):
    """
    Plot the data and plot the mixture of Gaussians model, mixGauss, on top of it.

    Keyword arguments:
    data -- d by n matrix of data points (the first two rows are plotted).
    mixGauss -- dict with mixture of gaussian information.

    """
    plt.cla()
    plt.plot(data[0, :], data[1, :], 'k.')

    for cGauss in range(mixGauss['k']):
        drawGaussianOutline(mixGauss['mean'][:, cGauss],
                            mixGauss['cov'][:, :, cGauss],
                            mixGauss['weight'][cGauss])

    return

def mixGaussGen(mixGauss, nData):
    """
    Generates data from a d-dimensional mixture of Gaussians model.

    Keyword arguments:
    mixGauss -- dict containing the mixture of gaussians arguments
                ('d', 'weight', 'mean' as d by k, 'cov' as d by d by k).
    nData -- number of data points to generate.

    Returns:
    data -- d by nData, generated data points.

    """
    # create space for output data
    data = np.zeros(shape=(mixGauss['d'], nData))

    for cData in range(nData):
        # randomly choose a Gaussian according to the mixing weights
        h = sampleFromDiscrete(mixGauss['weight'])
        # draw one sample from the chosen component
        data[:, cData] = np.random.multivariate_normal(mixGauss['mean'][:, h],
                                                       mixGauss['cov'][:, :, h])
    return data
    
    
    #define number of samples to generate
nData = 400

# generate data from the true mixture of Gaussians
data = mixGaussGen(mixGaussTrue, nData)

# draw data, true Gaussians
drawEMData2d(data, mixGaussTrue)

def _mixGaussLogJoint(data, mixGauss):
    """Return the k by n matrix of log(weight_k * Norm(x_n; mean_k, cov_k))."""
    nDims, nData = data.shape
    nGauss = mixGauss['k']
    logJoint = np.empty((nGauss, nData))

    for cGauss in range(nGauss):
        cov = mixGauss['cov'][:, :, cGauss]
        diff = data - mixGauss['mean'][:, cGauss][:, np.newaxis]
        # squared Mahalanobis distance of every point, without forming inv(cov)
        mahal = np.sum(diff * np.linalg.solve(cov, diff), axis=0)
        _, logDet = np.linalg.slogdet(cov)
        # a zero weight legitimately maps to -inf here
        with np.errstate(divide='ignore'):
            logWeight = np.log(mixGauss['weight'][cGauss])
        logJoint[cGauss, :] = logWeight - 0.5 * (nDims * np.log(2 * np.pi) + logDet + mahal)

    return logJoint


def _logSumExpColumns(logValues):
    """Return log(sum(exp(logValues), axis=0)), shifted by the column maximum to avoid underflow."""
    peak = np.max(logValues, axis=0)
    # a column that is entirely -inf must not produce (-inf) - (-inf) = nan
    peak = np.where(np.isfinite(peak), peak, 0.0)
    with np.errstate(divide='ignore'):
        return peak + np.log(np.sum(np.exp(logValues - peak), axis=0))


def getMixGaussLogLike(data, mixGaussEst):
    """
    Calculate the log likelihood for the whole dataset under a mixture of Gaussians model.

    Keyword arguments:
    data -- d by n matrix containing data points.
    mixGaussEst -- dict containing the mixture of gaussians parameters
                   ('k', 'weight', 'mean' as d by k, 'cov' as d by d by k).

    Returns:
    logLike -- scalar (Python float) containing the log likelihood.

    """
    data = np.atleast_2d(data)

    # log of each weighted component density for each point, then marginalise
    # over components in the log domain and add up over data points
    logJoint = _mixGaussLogJoint(data, mixGaussEst)
    logLike = np.sum(_logSumExpColumns(logJoint))

    return float(logLike)


def fitMixGauss(data, k):
    """
    Estimate a k MoG model that would fit the data. Incrementally plots the outcome.

    Keyword arguments:
    data -- d by n matrix containing data points.
    k -- scalar representing the number of gaussians to use in the MoG model.

    Returns:
    mixGaussEst -- dict containing the estimated MoG parameters.

    """

    #     MAIN E-M ROUTINE
    #     In the E-M algorithm, we calculate a complete posterior distribution over
    #     the (nData) hidden variables in the E-Step.
    #     In the M-Step, we update the parameters of the Gaussians (mean, cov, w).

    nDims, nData = data.shape

    # we will initialize the values to random values
    mixGaussEst = dict()
    mixGaussEst['d'] = nDims
    mixGaussEst['k'] = k
    mixGaussEst['weight'] = (1 / k) * np.ones(shape=(k))
    mixGaussEst['mean'] = 2 * np.random.randn(nDims, k)
    mixGaussEst['cov'] = np.zeros(shape=(nDims, nDims, k))
    for cGauss in range(k):
        # random spherical covariance; the parentheses keep the 2.5 offset on
        # the diagonal only instead of adding it to every matrix entry
        mixGaussEst['cov'][:, :, cGauss] = (2.5 + 1.5 * np.random.uniform()) * np.eye(nDims)

    # calculate current likelihood
    logLike = getMixGaussLogLike(data, mixGaussEst)
    print('Log Likelihood Iter 0 : {:4.3f}\n'.format(logLike))

    nIter = 30

    # responsibilities below this total are treated as "component owns no data"
    tiny = np.finfo(float).tiny

    fig, _ = plt.subplots(1, 1)

    for cIter in range(nIter):

        # ===================== =====================
        # Expectation step
        # ===================== =====================
        # posterior probability that each data point came from each Gaussian:
        # weighted component densities normalised over the components
        logJoint = _mixGaussLogJoint(data, mixGaussEst)
        postHidden = np.exp(logJoint - _logSumExpColumns(logJoint))

        # ===================== =====================
        # Maximization Step
        # ===================== =====================
        totalResp = np.sum(postHidden)
        for cGauss in range(k):
            resp = postHidden[cGauss, :]
            respSum = np.sum(resp)

            # weight: share of the total posterior mass owned by this Gaussian
            mixGaussEst['weight'][cGauss] = respSum / totalResp

            if respSum <= tiny:
                # no data assigned: keep the previous mean/cov rather than divide by zero
                continue

            # mean: responsibility-weighted average of the data (all dimensions)
            newMean = data @ resp / respSum
            mixGaussEst['mean'][:, cGauss] = newMean

            # covariance: responsibility-weighted average of the outer products
            # of the deviations from the updated mean
            diff = data - newMean[:, np.newaxis]
            mixGaussEst['cov'][:, :, cGauss] = (diff * resp) @ diff.T / respSum

        # draw the new solution
        drawEMData2d(data, mixGaussEst)
        time.sleep(0.7)
        fig.canvas.draw()

        # calculate the log likelihood
        logLike = getMixGaussLogLike(data, mixGaussEst)
        print('Log Likelihood After Iter {} : {:4.3f}\n'.format(cIter, logLike))

    return mixGaussEst
    
    
    
    #define number of components to estimate
nGaussEst = 3

# fit mixture of Gaussians (Pretend someone handed you some data. Now what?)
mixGaussEst = fitMixGauss(data, nGaussEst)