import math
import matplotlib.pyplot
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.collections import PolyCollection
from matplotlib.colors import colorConverter
import numpy as np
import matplotlib.pyplot as plt
from numpy import *
from pylab import *
import pylab


class vec(object):
    """Arithmetic helpers for 3-component vectors plus the stepping routines.

    Every method receives its operands as arguments (any indexable with at
    least three items) and returns a new 3-tuple; the triple stored by
    ``__init__`` is not read by any method of this class.

    ``function``, ``rungekutta1`` and ``rungekutta2`` read the module-level
    names ``h`` and ``c1`` .. ``c5`` at call time, so those must be defined
    before these methods are called.
    """

    def __init__(self, x, y, z):
        """Store the components as the tuple ``self.vec``."""
        self.vec = (x, y, z)

    def unit(self, vec):
        """Return ``vec`` divided by its Euclidean length.

        Raises ZeroDivisionError when a float ``vec`` has length zero.
        """
        dist = math.sqrt(vec[0]**2 + vec[1]**2 + vec[2]**2)
        return (vec[0]/dist, vec[1]/dist, vec[2]/dist)

    def add(self, vector1, vector2):
        """Return the component-wise sum of two vectors."""
        return (vector1[0] + vector2[0],
                vector1[1] + vector2[1],
                vector1[2] + vector2[2])

    def addfour(self, vec1, vec2, vec3, vec4):
        """Return ``(vec1 + vec2) + (vec3 + vec4)``, summed pairwise."""
        # Use self rather than a module-level instance so the method works
        # on any instance, whatever globals happen to exist.
        first_pair = self.add(vec1, vec2)
        second_pair = self.add(vec3, vec4)
        return self.add(first_pair, second_pair)

    def smultiply(self, scalar, vec):
        """Return ``vec`` with every component multiplied by ``scalar``."""
        return (scalar*vec[0], scalar*vec[1], scalar*vec[2])

    def innerproduct(self, vec1, vec2):
        """Return the dot product of two vectors."""
        return vec1[0]*vec2[0] + vec1[1]*vec2[1] + vec1[2]*vec2[2]

    def crossproduct(self, vec1, vec2):
        """Return the cross product ``vec1 x vec2``."""
        return (vec1[1]*vec2[2] - vec1[2]*vec2[1],
                vec1[2]*vec2[0] - vec1[0]*vec2[2],
                vec1[0]*vec2[1] - vec1[1]*vec2[0])

    def function(self, vec1, vec2):
        """Return ``h`` times the five-term combination of vec1 and vec2.

        The combination is::

            c1*(vec2 . vec2)*vec2 + c2*(vec2 x vec1) + c3*(0, 0, 1)
                + c4*vec1[2]*vec1 + c5*vec1[2]*(0, 0, 1)

        ``c1`` .. ``c5`` and ``h`` are module-level names.
        """
        # The class's own dot product keeps this independent of whatever
        # ``inner`` a star-import may or may not provide.
        coef1 = c1*self.innerproduct(vec2, vec2)
        term1 = self.smultiply(coef1, vec2)
        term2 = self.smultiply(c2, self.crossproduct(vec2, vec1))
        term3 = self.smultiply(c3, (0, 0, 1))
        term4 = self.smultiply(c4*vec1[2], vec1)
        term5 = self.smultiply(c5*vec1[2], (0, 0, 1))
        first_four = self.addfour(term1, term2, term3, term4)
        total = self.add(first_four, term5)
        return self.smultiply(h, total)

    def rungekutta1(self, vec1, vec):
        """Return ``vec1`` advanced by a four-stage weighted increment.

        The stages are built from ``vec`` alone: each stage is ``h`` times
        ``vec`` shifted by half (stages 2 and 3) or all (stage 4) of the
        previous stage.  The result is
        ``vec1 + (k1 + 2*k2 + 2*k3 + k4) / 6``.
        """
        k11 = self.smultiply(h, vec)
        y12 = self.add(vec, self.smultiply(.5, k11))
        k12 = self.smultiply(h, y12)
        y13 = self.add(vec, self.smultiply(.5, k12))
        k13 = self.smultiply(h, y13)
        y14 = self.add(vec, k13)
        k14 = self.smultiply(h, y14)
        k1s = self.addfour(k11,
                           self.smultiply(2.0, k12),
                           self.smultiply(2.0, k13),
                           k14)
        return self.add(vec1, self.smultiply(1.0/6, k1s))

    def rungekutta2(self, vec1, vec2):
        """Return ``vec2`` advanced by a four-stage weighted increment.

        Each stage evaluates ``function`` at ``vec1`` offset by
        ``0.5*h*(1, 1, 1)`` (stages 2 and 3) or ``h*(1, 1, 1)`` (stage 4) and
        at ``vec2`` offset by half (stages 2 and 3) or all (stage 4) of the
        previous stage.  The result is
        ``vec2 + (k1 + 2*k2 + 2*k3 + k4) / 6``.
        """
        k21 = self.function(vec1, vec2)
        n22 = self.add(vec1, self.smultiply(.5*h, (1, 1, 1)))
        d_n22 = self.add(vec2, self.smultiply(.5, k21))
        k22 = self.function(n22, d_n22)
        n23 = self.add(vec1, self.smultiply(.5*h, (1, 1, 1)))
        d_n23 = self.add(vec2, self.smultiply(.5, k22))
        k23 = self.function(n23, d_n23)
        n24 = self.add(vec1, self.smultiply(h, (1, 1, 1)))
        d_n24 = self.add(vec2, k23)
        k24 = self.function(n24, d_n24)
        k2s = self.addfour(k21,
                           self.smultiply(2.0, k22),
                           self.smultiply(2.0, k23),
                           k24)
        return self.add(vec2, self.smultiply(1.0/6, k2s))
    
def plotlines(x, y, z):
    """Draw a segment from the point (x, y, z) to the origin and show it.

    Draws on the module-level 3D axes ``ax``; the three axis labels are
    applied again on every call.
    """
    endpoints = ([x, 0], [y, 0], [z, 0])
    ax.plot(*endpoints, color='#0000A0', marker='.')
    for apply_label, text in ((ax.set_xlabel, 'X-axis'),
                              (ax.set_ylabel, 'Y-axis'),
                              (ax.set_zlabel, 'Z-axis')):
        apply_label(text)
    matplotlib.pyplot.show()
        



def plotdots(x, y, z):
    """Scatter the point (x, y, z) on the module-level axes ``ax`` and redraw.

    If the axes currently holds any line artist, the first one is removed
    before the redraw.
    """
    ax.scatter(x, y, z, c='#387C44', marker='.')
    # scatter() adds a collection, not a line, so ax.lines may be empty here;
    # only remove a line when one exists.  Artist.remove() is used instead of
    # deleting from ax.lines, which newer matplotlib exposes as read-only.
    if ax.lines:
        ax.lines[0].remove()
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    ax.set_zlabel('Z-axis')
    draw()
        

def plotvel(dx, dy, dz):
    """Print (dx, dy, dz), scatter it on the module-level axes ``ax``, and show.

    The three values are printed space-separated on one line.
    """
    # A single pre-formatted string prints identically whether ``print`` is
    # the Python 2 statement or the Python 3 function.
    print("%s %s %s" % (dx, dy, dz))
    ax.scatter(dx, dy, dz, c='m')
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    ax.set_zlabel('Z-axis')
    matplotlib.pyplot.show()

def xyproj(xs, ys):
    """Print xs and ys, then scatter them in the z = 0 plane and show.

    The two values are printed space-separated on one line.
    """
    # A single pre-formatted string prints identically whether ``print`` is
    # the Python 2 statement or the Python 3 function.
    print("%s %s" % (xs, ys))
    # NOTE(review): the original called ``plotpos``, which is not defined in
    # this file; ``plotdots`` is the position-scatter helper that is defined
    # here and is presumably what was meant -- confirm.
    plotdots(xs, ys, 0)
    show()

def drawstaticplot(m, n, d_n):
    """Run m integration steps, record them, then plot every 20th sample.

    Each step updates ``n`` with ``vector.rungekutta1`` and then ``d_n`` with
    ``vector.rungekutta2`` (using the already-updated ``n``).  The components
    of ``n`` are appended to the module-level lists ``xarray``, ``yarray`` and
    ``zarray``.  Afterwards, samples 20 apart are joined by segments on the
    module-level axes ``ax``, with ``show()`` called after each segment.
    """
    for _ in range(m):
        n = vector.rungekutta1(n, d_n)
        d_n = vector.rungekutta2(n, d_n)
        xarray.append(n[0])
        yarray.append(n[1])
        zarray.append(n[2])

    stride = 20
    # Start indices 0, 20, 40, ... strictly below m - 20, so that the end
    # index (start + 20) is always a recorded sample.
    for start in range(0, m - stride, stride):
        stop = start + stride
        ax.plot([xarray[start], xarray[stop]],
                [yarray[start], yarray[stop]],
                [zarray[start], zarray[stop]],
                color='#817339', marker='.')
        matplotlib.pyplot.show()

        
def drawdynamicplot(m, n, d_n):
    """Run m integration steps, printing each and plotting every 20th.

    Each step updates ``n`` with ``vector.rungekutta1`` and then ``d_n`` with
    ``vector.rungekutta2`` (using the already-updated ``n``).  The components
    of ``n`` are printed space-separated on every step; on steps whose index
    is a multiple of 20 a separator line is printed and ``plotlines`` is
    called with the current components.
    """
    for i in range(0, m):
        n = vector.rungekutta1(n, d_n)
        d_n = vector.rungekutta2(n, d_n)
        x1, y1, z1 = n[0], n[1], n[2]
        # Single pre-formatted strings print identically whether ``print`` is
        # the Python 2 statement or the Python 3 function.
        print("%s %s %s" % (x1, y1, z1))
        if i % 20 == 0:
            print("-----------------------------")
            plotlines(x1, y1, z1)
                
def realtimeplotline(m, n, d_n):
    """Run m integration steps, drawing a segment for alternate 20-step spans.

    Each step updates ``n`` with ``vector.rungekutta1`` and then ``d_n`` with
    ``vector.rungekutta2`` (using the already-updated ``n``), and prints the
    components of ``n`` space-separated.  On steps whose index is a multiple
    of 20 a marker line is printed; even multiples (0, 40, ...) remember the
    current point, odd multiples (20, 60, ...) draw a segment from the
    remembered point to the current one on the module-level axes ``ax`` and
    call ``show()``.
    """
    for i in range(0, m):
        n = vector.rungekutta1(n, d_n)
        d_n = vector.rungekutta2(n, d_n)
        x1, y1, z1 = n[0], n[1], n[2]
        # Single pre-formatted strings print identically whether ``print`` is
        # the Python 2 statement or the Python 3 function.
        print("%s %s %s" % (x1, y1, z1))
        if i % 20 != 0:
            continue
        print("--------------------mark------------------")
        # Floor division keeps the parity test on integers under Python 3 too.
        if (i // 20) % 2 == 0:
            # Segment start; i == 0 always lands here first, so the start
            # point is set before any segment is drawn.
            x2, y2, z2 = x1, y1, z1
        else:
            ax.plot([x2, x1], [y2, y1], [z2, z1],
                    color='#817339', marker='.')
            matplotlib.pyplot.show()
 


# Figure and 3D axes shared by all the plotting helpers above.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Not referenced by any function in this file.
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)

# Coefficients read by vec.function, and the step size h read by
# vec.function, vec.rungekutta1 and vec.rungekutta2.
c1, c2, c3, c4, c5 = -1.0, -1.0, 1.0, -1.0, 0
h = .004

# Instance through which the draw* functions call the stepping methods.
vector = vec(1, 2, 3)
allone = vec(1.0, 1.0, 1.0)

# Initial state handed to the selected draw routine below.
n = (1.0, 1.0, 1.0)
d_n = (1.0, -1.0, 0.0)

# Interactive mode, switched on before any plotting call.
ion()
x, y, z = n[0], n[1], n[2]

# Number of integration steps.
m = 10000
# Sample buffers filled by drawstaticplot.
xarray, yarray, zarray = [], [], []

# Exactly one driver is active; the alternatives are kept for switching.
#drawstaticplot(m,n, d_n)
drawdynamicplot(m, n, d_n)
#realtimeplotline(m,n, d_n)
