import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Toy regression data: 100 evenly spaced inputs on [-4, 4]; targets are the
# inputs plus Gaussian noise with mean 0 and standard deviation 0.1.
x = np.linspace(-4.0, 4.0, 100)
noise = np.random.normal(0, 0.1, 100)
y = x + noise

# Graph inputs. X is fed a column of inputs, shape (batch, 1); Y is fed a flat
# vector of targets, shape (batch,). NOTE: the two have different ranks, so
# anything comparing a (batch, 1) prediction against Y must flatten one side
# first to avoid broadcasting to (batch, batch).
X = tf.placeholder(tf.float32, shape=[None, 1])
Y = tf.placeholder(tf.float32, shape=[None])

def linear(X, n_input, n_output, activation=None, scope=None):
    """Create a fully connected layer and return its output tensor.

    Computes ``activation(matmul(X, W) + b)``. ``W`` has shape
    ``[n_input, n_output]`` and ``b`` has shape ``[n_output]``. Both are
    created with ``tf.get_variable`` inside the variable scope and are
    initialized from a normal distribution with mean 0 and stddev 0.1.

    Args:
        X: 2-D input tensor whose last dimension is ``n_input``.
        n_input: number of input features.
        n_output: number of output units.
        activation: optional callable applied to the affine output; when
            None the layer is purely linear.
        scope: variable-scope name for ``W`` and ``b``; defaults to 'Linear'.

    Returns:
        The layer output tensor, with one row per row of ``X`` and
        ``n_output`` columns.
    """
    init = tf.random_normal_initializer(mean=0, stddev=0.1)
    with tf.variable_scope(scope or 'Linear'):
        weights = tf.get_variable(name='W', shape=[n_input, n_output], initializer=init)
        bias = tf.get_variable(name='b', shape=[n_output], initializer=init)
        out = tf.matmul(X, weights) + bias
        return out if activation is None else activation(out)

# Layer widths: 1 input feature -> 20 hidden units -> 1 output.
n_neurons = [1, 20, 1]

# Chain the layers. Every hidden layer uses a sigmoid; the final layer has no
# activation so the network can emit unbounded regression values.
current_input = X
n_layers = len(n_neurons) - 1
for layer in range(1, n_layers + 1):
    current_input = linear(
        X=current_input,
        n_input=n_neurons[layer - 1],
        n_output=n_neurons[layer],
        activation=None if layer == n_layers else tf.sigmoid,
        scope='layer_' + str(layer))

# Network output, shape (batch, 1).
Y_pred = current_input

# Mean squared error. Y_pred is (batch, 1) while the Y placeholder is
# (batch,). Subtracting them directly broadcasts to a (batch, batch) matrix,
# which averages the error over every (prediction, target) pair and drives
# every prediction toward the mean of y. Flatten the predictions first so the
# error is taken element-wise.
cost = tf.reduce_mean(tf.square(tf.reshape(Y_pred, [-1]) - Y))

# Plain gradient descent with learning rate 0.5. `optimizer` is the training
# op that is run once per mini-batch.
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cost)

n_iterations = 1000  # number of shuffled passes over the data
batch_size = 100

with tf.Session() as sess:
    # global_variables_initializer replaces the deprecated
    # initialize_all_variables; both run every variable's initializer.
    sess.run(tf.global_variables_initializer())

    # plt.savefig does not create missing directories, so make sure the output
    # folder exists before the first plot is written.
    graph_dir = './graphs'
    if not os.path.isdir(graph_dir):
        os.makedirs(graph_dir)

    # The whole dataset as an (N, 1) column for evaluation; computed once.
    x_column = np.reshape(x, [len(x), 1])

    for iteration in range(n_iterations):
        # Reshuffle every pass. Samples beyond the last full batch are skipped.
        idxs = np.random.permutation(len(x))
        n_batches = len(idxs) // batch_size
        for batch in range(n_batches):
            idxs_i = idxs[batch * batch_size:(batch + 1) * batch_size]
            sess.run(optimizer, feed_dict={X: np.reshape(x[idxs_i], [len(idxs_i), 1]),
                                           Y: y[idxs_i]})

        if iteration % 20 == 0:
            # Evaluate on the full dataset only when logging. These values
            # were previously computed every iteration and then discarded.
            training_cost = sess.run(cost, feed_dict={X: x_column, Y: y})
            training_xs = x
            training_ys = sess.run(Y_pred, feed_dict={X: x_column})
            plt.plot(x, y)
            plt.plot(training_xs, training_ys)
            plt.savefig('./graphs/graph' + str(iteration))
            plt.close()
            print(training_cost)







