import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Toy regression data: 100 evenly spaced inputs on [-4, 4], with targets equal
# to the input plus Gaussian noise (mean 0, standard deviation 0.1).
x = np.linspace(start=-4.0, stop=4.0, num=100)
y = x + np.random.normal(loc=0, scale=0.1, size=100)

# Graph inputs: X is a float column of shape (batch, 1); Y is a flat float
# vector of shape (batch,). NOTE: Y is rank 1, so anything compared against it
# must be flattened to (batch,) as well.
X = tf.placeholder(dtype=tf.float32, shape=[None, 1])
Y = tf.placeholder(dtype=tf.float32, shape=[None])
def linear(X, n_input, n_output, activation=None, scope=None):
    """Build one fully connected layer: ``activation(X @ W + b)``.

    Creates a weight matrix ``W`` of shape ``[n_input, n_output]`` and a bias
    ``b`` of shape ``[n_output]``. Both are drawn from N(0, 0.1^2) and created
    inside the variable scope ``scope`` (default ``'Linear'``).

    Args:
        X: input tensor whose last dimension is ``n_input``.
        n_input: number of input features.
        n_output: number of output units.
        activation: optional callable applied to the affine output; when
            None the layer stays linear.
        scope: variable-scope name. NOTE(review): the variables are created
            with ``tf.get_variable``, so presumably each layer needs a
            distinct scope unless the caller enables reuse.

    Returns:
        The layer output tensor (affine result, optionally activated).
    """
    with tf.variable_scope(scope or 'Linear'):
        weights = tf.get_variable(
            name='W',
            shape=[n_input, n_output],
            initializer=tf.random_normal_initializer(mean=0, stddev=0.1))
        bias = tf.get_variable(
            name='b',
            shape=[n_output],
            initializer=tf.random_normal_initializer(mean=0, stddev=0.1))
        pre_activation = tf.matmul(X, weights) + bias
        return pre_activation if activation is None else activation(pre_activation)
# Layer widths from input to output: 1 -> 20 (hidden, sigmoid) -> 1 (linear).
n_neurons = [1, 20, 1]

# Chain the layers: each layer consumes the previous layer's output. Every
# layer except the final one gets a sigmoid; the output layer stays linear.
current_input = X
for layer, (fan_in, fan_out) in enumerate(zip(n_neurons[:-1], n_neurons[1:]), start=1):
    is_output_layer = layer == len(n_neurons) - 1
    current_input = linear(
        X=current_input,
        n_input=fan_in,
        n_output=fan_out,
        activation=None if is_output_layer else tf.sigmoid,
        scope='layer_' + str(layer))

# Network prediction, shape (batch, n_neurons[-1]).
Y_pred = current_input
# Y_pred comes out of the final `linear` layer with shape (batch, 1), while the
# Y placeholder is a flat (batch,) vector. Subtracting them directly broadcasts
# to a (batch, batch) matrix of every prediction/target pair. The mean of that
# matrix is minimized by a constant prediction near mean(y), not by a fit of
# y(x). Flatten the prediction first so the squared error is element-wise.
cost = tf.reduce_mean(tf.square(tf.reshape(Y_pred, [-1]) - Y))

# Plain gradient descent on the mean squared error.
learning_rate = 0.5
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
n_iterations = 1000
batch_size = 100

# plt.savefig below writes into ./graphs, but it does not create missing
# directories; without this the first snapshot raises.
if not os.path.isdir('./graphs'):
    os.makedirs('./graphs')

# The full-data feed for X never changes, so reshape it once. The X
# placeholder takes an (n_samples, 1) column.
x_column = np.reshape(x, [len(x), 1])

with tf.Session() as sess:
    # tf.initialize_all_variables is deprecated; this is its replacement.
    sess.run(tf.global_variables_initializer())
    for iteration in range(n_iterations):
        # Fresh shuffle of the sample indices on every pass over the data.
        idxs = np.random.permutation(len(x))
        # Stepping by batch_size keeps a trailing partial batch. It also still
        # trains when batch_size > len(x), where `len(idxs) // batch_size`
        # would have been 0 batches.
        for start in range(0, len(idxs), batch_size):
            idxs_i = idxs[start:start + batch_size]
            sess.run(optimizer, feed_dict={
                X: np.reshape(x[idxs_i], [len(idxs_i), 1]),
                Y: y[idxs_i]})
        # Cost and predictions over the whole data set after this pass.
        training_cost = sess.run(cost, feed_dict={X: x_column, Y: y})

        training_xs = x
        training_ys = sess.run(Y_pred, feed_dict={X: x_column})
        if iteration % 20 == 0:
            # Every 20 iterations: save the noisy data and the current fit,
            # then log the cost.
            plt.plot(x, y)
            plt.plot(training_xs, training_ys)
            plt.savefig('./graphs/graph' + str(iteration))
            plt.close()
            print(training_cost)
aW1wb3J0IHRlbnNvcmZsb3cgYXMgdGYKaW1wb3J0IG51bXB5IGFzIG5wCmltcG9ydCBtYXRwbG90bGliLnB5cGxvdCBhcyBwbHQgCgp4ID0gbnAubGluc3BhY2UoLTQuLDQuLDEwMCkKeSA9IHggKyBucC5yYW5kb20ubm9ybWFsKDAsMC4xLDEwMCkKClggPSB0Zi5wbGFjZWhvbGRlcih0Zi5mbG9hdDMyLCBzaGFwZT1bTm9uZSwxXSkKWSA9IHRmLnBsYWNlaG9sZGVyKHRmLmZsb2F0MzIsIHNoYXBlPVtOb25lXSkKCmRlZiBsaW5lYXIoWCwgbl9pbnB1dCwgbl9vdXRwdXQsYWN0aXZhdGlvbj1Ob25lLHNjb3BlPU5vbmUpOgogICAgd2l0aCB0Zi52YXJpYWJsZV9zY29wZShzY29wZSBvciAnTGluZWFyJyk6CiAgICAgICAgVyA9IHRmLmdldF92YXJpYWJsZSgKICAgICAgICAgICAgICAgIG5hbWU9J1cnLAogICAgICAgICAgICAgICAgc2hhcGU9W25faW5wdXQsbl9vdXRwdXRdLAogICAgICAgICAgICAgICAgaW5pdGlhbGl6ZXI9dGYucmFuZG9tX25vcm1hbF9pbml0aWFsaXplcihtZWFuPTAsc3RkZGV2PTAuMSkpCiAgICAgICAgYiA9IHRmLmdldF92YXJpYWJsZSgKICAgICAgICAgICAgICAgIG5hbWU9J2InLAogICAgICAgICAgICAgICAgc2hhcGU9W25fb3V0cHV0XSwKICAgICAgICAgICAgICAgIGluaXRpYWxpemVyPXRmLnJhbmRvbV9ub3JtYWxfaW5pdGlhbGl6ZXIobWVhbj0wLHN0ZGRldj0wLjEpKQogICAgICAgIGggPSB0Zi5tYXRtdWwoWCxXKSArIGIKICAgICAgICBpZiBhY3RpdmF0aW9uIGlzIG5vdCBOb25lOgogICAgICAgICAgICBoID0gYWN0aXZhdGlvbihoKQogICAgICAgIHJldHVybiBoCgpuX25ldXJvbnMgPSBbMSwyMCwxXQoKY3VycmVudF9pbnB1dCA9IFgKCmZvciBsYXllciBpbiByYW5nZSgxLGxlbihuX25ldXJvbnMpKToKICAgIGN1cnJlbnRfaW5wdXQgPSBsaW5lYXIoCiAgICAgICAgICAgIFg9Y3VycmVudF9pbnB1dCwKICAgICAgICAgICAgbl9pbnB1dCA9IG5fbmV1cm9uc1tsYXllci0xXSwKICAgICAgICAgICAgbl9vdXRwdXQgPSBuX25ldXJvbnNbbGF5ZXJdLAogICAgICAgICAgICBhY3RpdmF0aW9uPXRmLnNpZ21vaWQgaWYgKGxheWVyKzEpIDwgbGVuKG5fbmV1cm9ucykgZWxzZSBOb25lLAogICAgICAgICAgICBzY29wZT0nbGF5ZXJfJyArIHN0cihsYXllcikpCgpZX3ByZWQgPSBjdXJyZW50X2lucHV0Cgpjb3N0ID0gdGYucmVkdWNlX21lYW4odGYuc3F1YXJlKFlfcHJlZCAtIFkpKQoKb3B0aW1pemVyID0gdGYudHJhaW4uR3JhZGllbnREZXNjZW50T3B0aW1pemVyKDAuNSkubWluaW1pemUoY29zdCkKCm5faXRlcmF0aW9ucyA9IDEwMDAKYmF0Y2hfc2l6ZSA9IDEwMAoKd2l0aCB0Zi5TZXNzaW9uKCkgYXMgc2VzczoKICAgIHNlc3MucnVuKHRmLmluaXRpYWxpemVfYWxsX3ZhcmlhYmxlcygpKQogICAgZm9yIGl0ZXJhdGlvbiBpbiByYW5nZShuX2l0ZXJhdGlvbnMpOgogICAgICAgIGlkeHMgPSBucC5yYW5kb20ucGVybXV0YXRpb24ocmFuZ2UobGVuKHgpKSkKICAgICAgICBuX2JhdGNoZXMgPSBsZW4oaWR4
cykgLy8gYmF0Y2hfc2l6ZQogICAgICAgIGZvciBiYXRjaCBpbiByYW5nZShuX2JhdGNoZXMpOgogICAgICAgICAgICBpZHhzX2kgPSBpZHhzW2JhdGNoKmJhdGNoX3NpemU6KGJhdGNoKzEpKmJhdGNoX3NpemVdCiAgICAgICAgICAgIHNlc3MucnVuKG9wdGltaXplciwgZmVlZF9kaWN0PXtYOm5wLnJlc2hhcGUoeFtpZHhzX2ldLCBbbGVuKGlkeHNfaSksMV0pLFk6eVtpZHhzX2ldfSkgCiAgICAgICAgdHJhaW5pbmdfY29zdCA9IHNlc3MucnVuKGNvc3QsZmVlZF9kaWN0PXtYOm5wLnJlc2hhcGUoeCwgW2xlbih4KSwxXSksWTp5fSkKCiAgICAgICAgdHJhaW5pbmdfeHMgPSB4CiAgICAgICAgdHJhaW5pbmdfeXMgPSBzZXNzLnJ1bihZX3ByZWQsZmVlZF9kaWN0PXtYOm5wLnJlc2hhcGUodHJhaW5pbmdfeHMsIFtsZW4odHJhaW5pbmdfeHMpLDFdKX0pCiAgICAgICAgaWYgaXRlcmF0aW9uICUgMjAgPT0gMCA6IAogICAgICAgICAgICBwbHQucGxvdCh4LHkpCiAgICAgICAgICAgIHBsdC5wbG90KHRyYWluaW5nX3hzLHRyYWluaW5nX3lzKQogICAgICAgICAgICBwbHQuc2F2ZWZpZygnLi9ncmFwaHMvZ3JhcGgnK3N0cihpdGVyYXRpb24pKQogICAgICAgICAgICBwbHQuY2xvc2UoKQogICAgICAgICAgICBwcmludCh0cmFpbmluZ19jb3N0KQoKCgoKCgoK