fork download
  1. import tensorflow as tf
  2. import numpy as np
  3. import time
  4. import os
  5. from enum import Enum
  6.  
  7. class MappingType(Enum):
  8. Identity = 0
  9. Linear = 1
  10. Affine = 2
  11.  
  12. class ODESolver(Enum):
  13. SemiImplicit = 0
  14. Explicit = 1
  15. RungeKutta = 2
  16.  
  17. class LTCCell(tf.nn.rnn_cell.RNNCell):
  18.  
  19. def __init__(self, num_units):
  20.  
  21. self._input_size = -1
  22. self._num_units = num_units
  23. self._is_built = False
  24.  
  25. # Number of ODE solver steps in one RNN step
  26. self._ode_solver_unfolds = 6
  27. self._solver = ODESolver.SemiImplicit
  28.  
  29. self._input_mapping = MappingType.Affine
  30.  
  31. self._erev_init_factor = 1
  32.  
  33. self._w_init_max = 1.0
  34. self._w_init_min = 0.01
  35. self._cm_init_min = 0.5
  36. self._cm_init_max = 0.5
  37. self._gleak_init_min = 1
  38. self._gleak_init_max = 1
  39.  
  40. self._w_min_value = 0.00001
  41. self._w_max_value = 1000
  42. self._gleak_min_value = 0.00001
  43. self._gleak_max_value = 1000
  44. self._cm_t_min_value = 0.000001
  45. self._cm_t_max_value = 1000
  46.  
  47. self._fix_cm = None
  48. self._fix_gleak = None
  49. self._fix_vleak = None
  50.  
  51. @property
  52. def state_size(self):
  53. return self._num_units
  54.  
  55. @property
  56. def output_size(self):
  57. return self._num_units
  58.  
  59. def _map_inputs(self,inputs,resuse_scope=False):
  60. varscope = "sensory_mapping"
  61. reuse = tf.AUTO_REUSE
  62. if(resuse_scope):
  63. varscope = self._sensory_varscope
  64. reuse = True
  65.  
  66. with tf.variable_scope(varscope,reuse=reuse) as scope:
  67. self._sensory_varscope = scope
  68. if(self._input_mapping == MappingType.Affine or self._input_mapping == MappingType.Linear):
  69. w = tf.get_variable(name='input_w',shape=[self._input_size],trainable=True,initializer=tf.initializers.constant(1))
  70. inputs = inputs * w
  71. if(self._input_mapping == MappingType.Affine):
  72. b = tf.get_variable(name='input_b',shape=[self._input_size],trainable=True,initializer=tf.initializers.constant(0))
  73. inputs = inputs + b
  74. return inputs
  75.  
  76. # TODO: Implement RNNLayer properly,i.e, allocate variables here
  77. def build(self,input_shape):
  78. pass
  79.  
  80. def __call__(self, inputs, state, scope=None):
  81. with tf.variable_scope("ltc"):
  82. if(not self._is_built):
  83. # TODO: Move this part into the build method inherited form tf.Layers
  84. self._is_built = True
  85. self._input_size = int(inputs.shape[-1])
  86.  
  87. self._get_variables()
  88.  
  89. elif(self._input_size != int(inputs.shape[-1])):
  90. raise ValueError("You first feed an input with {} features and now one with {} features, that is not possible".format(
  91. self._input_size,
  92. int(inputs[-1])
  93. ))
  94.  
  95. inputs = self._map_inputs(inputs)
  96.  
  97. if(self._solver == ODESolver.Explicit):
  98. next_state = self._ode_step_explicit(inputs,state,_ode_solver_unfolds=self._ode_solver_unfolds)
  99. elif(self._solver == ODESolver.SemiImplicit):
  100. next_state = self._ode_step(inputs,state)
  101. elif(self._solver == ODESolver.RungeKutta):
  102. next_state = self._ode_step_runge_kutta(inputs,state)
  103. else:
  104. raise ValueError("Unknown ODE solver '{}'".format(str(self._solver)))
  105.  
  106. outputs = next_state
  107.  
  108. return outputs, next_state
  109.  
  110. # Create tf variables
  111. def _get_variables(self):
  112. self.sensory_mu = tf.get_variable(name='sensory_mu',shape=[self._input_size,self._num_units],trainable=True,initializer=tf.initializers.random_uniform(minval=0.3,maxval=0.8))
  113. self.sensory_sigma = tf.get_variable(name='sensory_sigma',shape=[self._input_size,self._num_units],trainable=True,initializer=tf.initializers.random_uniform(minval=3.0,maxval=8.0))
  114. self.sensory_W = tf.get_variable(name='sensory_W',shape=[self._input_size,self._num_units],trainable=True,initializer=tf.initializers.constant(np.random.uniform(low=self._w_init_min,high=self._w_init_max,size=[self._input_size,self._num_units])))
  115. sensory_erev_init = 2*np.random.randint(low=0,high=2,size=[self._input_size,self._num_units])-1
  116. self.sensory_erev = tf.get_variable(name='sensory_erev',shape=[self._input_size,self._num_units],trainable=True,initializer=tf.initializers.constant(sensory_erev_init*self._erev_init_factor))
  117.  
  118. self.mu = tf.get_variable(name='mu',shape=[self._num_units,self._num_units],trainable=True,initializer=tf.initializers.random_uniform(minval=0.3,maxval=0.8))
  119. self.sigma = tf.get_variable(name='sigma',shape=[self._num_units,self._num_units],trainable=True,initializer=tf.initializers.random_uniform(minval=3.0,maxval=8.0))
  120. self.W = tf.get_variable(name='W',shape=[self._num_units,self._num_units],trainable=True,initializer=tf.initializers.constant(np.random.uniform(low=self._w_init_min,high=self._w_init_max,size=[self._num_units,self._num_units])))
  121.  
  122. erev_init = 2*np.random.randint(low=0,high=2,size=[self._num_units,self._num_units])-1
  123. self.erev = tf.get_variable(name='erev',shape=[self._num_units,self._num_units],trainable=True,initializer=tf.initializers.constant(erev_init*self._erev_init_factor))
  124.  
  125. if(self._fix_vleak is None):
  126. self.vleak = tf.get_variable(name='vleak',shape=[self._num_units],trainable=True,initializer=tf.initializers.random_uniform(minval=-0.2,maxval=0.2))
  127. else:
  128. self.vleak = tf.get_variable(name='vleak',shape=[self._num_units],trainable=False,initializer=tf.initializers.constant(self._fix_vleak))
  129.  
  130. if(self._fix_gleak is None):
  131. initializer=tf.initializers.constant(self._gleak_init_min)
  132. if(self._gleak_init_max > self._gleak_init_min):
  133. initializer = tf.initializers.random_uniform(minval= self._gleak_init_min,maxval = self._gleak_init_max)
  134. self.gleak = tf.get_variable(name='gleak',shape=[self._num_units],trainable=True,initializer=initializer)
  135. else:
  136. self.gleak = tf.get_variable(name='gleak',shape=[self._num_units],trainable=False,initializer=tf.initializers.constant(self._fix_gleak))
  137.  
  138. if(self._fix_cm is None):
  139. initializer=tf.initializers.constant(self._cm_init_min)
  140. if(self._cm_init_max > self._cm_init_min):
  141. initializer = tf.initializers.random_uniform(minval= self._cm_init_min,maxval = self._cm_init_max)
  142. self.cm_t = tf.get_variable(name='cm_t',shape=[self._num_units],trainable=True,initializer=initializer)
  143. else:
  144. self.cm_t = tf.get_variable(name='cm_t',shape=[self._num_units],trainable=False,initializer=tf.initializers.constant(self._fix_cm))
  145.  
  146. # Hybrid euler method
  147. def _ode_step(self,inputs,state):
  148. v_pre = state
  149.  
  150. sensory_w_activation = self.sensory_W*self._sigmoid(inputs,self.sensory_mu,self.sensory_sigma)
  151. sensory_rev_activation = sensory_w_activation*self.sensory_erev
  152.  
  153. w_numerator_sensory = tf.reduce_sum(sensory_rev_activation,axis=1)
  154. w_denominator_sensory = tf.reduce_sum(sensory_w_activation,axis=1)
  155.  
  156. for t in range(self._ode_solver_unfolds):
  157. w_activation = self.W*self._sigmoid(v_pre,self.mu,self.sigma)
  158.  
  159. rev_activation = w_activation*self.erev
  160.  
  161. w_numerator = tf.reduce_sum(rev_activation,axis=1) + w_numerator_sensory
  162. w_denominator = tf.reduce_sum(w_activation,axis=1) + w_denominator_sensory
  163.  
  164. numerator = self.cm_t * v_pre + self.gleak*self.vleak + w_numerator
  165. denominator = self.cm_t + self.gleak + w_denominator
  166.  
  167. v_pre = numerator/denominator
  168.  
  169. return v_pre
  170.  
  171. def _f_prime(self,inputs,state):
  172. v_pre = state
  173.  
  174. # We can pre-compute the effects of the sensory neurons here
  175. sensory_w_activation = self.sensory_W*self._sigmoid(inputs,self.sensory_mu,self.sensory_sigma)
  176. w_reduced_sensory = tf.reduce_sum(sensory_w_activation,axis=1)
  177.  
  178. # Unfold the mutliply ODE multiple times into one RNN step
  179. w_activation = self.W*self._sigmoid(v_pre,self.mu,self.sigma)
  180.  
  181. w_reduced_synapse = tf.reduce_sum(w_activation,axis=1)
  182.  
  183. sensory_in = self.sensory_erev * sensory_w_activation
  184. synapse_in = self.erev * w_activation
  185.  
  186. sum_in = tf.reduce_sum(sensory_in,axis=1) - v_pre*w_reduced_synapse + tf.reduce_sum(synapse_in,axis=1) - v_pre * w_reduced_sensory
  187.  
  188. f_prime = 1/self.cm_t * (self.gleak * (self.vleak-v_pre) + sum_in)
  189.  
  190. return f_prime
  191.  
  192. def _ode_step_runge_kutta(self,inputs,state):
  193.  
  194. h = 0.1
  195. for i in range(self._ode_solver_unfolds):
  196. k1 = h*self._f_prime(inputs,state)
  197. k2 = h*self._f_prime(inputs,state+k1*0.5)
  198. k3 = h*self._f_prime(inputs,state+k2*0.5)
  199. k4 = h*self._f_prime(inputs,state+k3)
  200.  
  201. state = state + 1.0/6*(k1+2*k2+2*k3+k4)
  202.  
  203. return state
  204.  
  205. def _ode_step_explicit(self,inputs,state,_ode_solver_unfolds):
  206. v_pre = state
  207.  
  208. # We can pre-compute the effects of the sensory neurons here
  209. sensory_w_activation = self.sensory_W*self._sigmoid(inputs,self.sensory_mu,self.sensory_sigma)
  210. w_reduced_sensory = tf.reduce_sum(sensory_w_activation,axis=1)
  211.  
  212.  
  213. # Unfold the mutliply ODE multiple times into one RNN step
  214. for t in range(_ode_solver_unfolds):
  215. w_activation = self.W*self._sigmoid(v_pre,self.mu,self.sigma)
  216.  
  217. w_reduced_synapse = tf.reduce_sum(w_activation,axis=1)
  218.  
  219. sensory_in = self.sensory_erev * sensory_w_activation
  220. synapse_in = self.erev * w_activation
  221.  
  222. sum_in = tf.reduce_sum(sensory_in,axis=1) - v_pre*w_reduced_synapse + tf.reduce_sum(synapse_in,axis=1) - v_pre * w_reduced_sensory
  223.  
  224. f_prime = 1/self.cm_t * (self.gleak * (self.vleak-v_pre) + sum_in)
  225.  
  226. v_pre = v_pre + 0.1 * f_prime
  227.  
  228. return v_pre
  229.  
  230. def _sigmoid(self,v_pre,mu,sigma):
  231. v_pre = tf.reshape(v_pre,[-1,v_pre.shape[-1],1])
  232. mues = v_pre - mu
  233. x = sigma*mues
  234. return tf.nn.sigmoid(x)
  235.  
  236. def get_param_constrain_op(self):
  237.  
  238. cm_clipping_op = tf.assign(self.cm_t,tf.clip_by_value(self.cm_t, self._cm_t_min_value, self._cm_t_max_value))
  239. gleak_clipping_op = tf.assign(self.gleak,tf.clip_by_value(self.gleak, self._gleak_min_value, self._gleak_max_value))
  240. w_clipping_op = tf.assign(self.W,tf.clip_by_value(self.W, self._w_min_value, self._w_max_value))
  241. sensory_w_clipping_op = tf.assign(self.sensory_W ,tf.clip_by_value(self.sensory_W, self._w_min_value, self._w_max_value))
  242.  
  243. return [cm_clipping_op,gleak_clipping_op,w_clipping_op,sensory_w_clipping_op]
  244.  
  245. def export_weights(self,dirname,sess,output_weights=None):
  246. os.makedirs(dirname,exist_ok=True)
  247. w,erev,mu,sigma = sess.run([self.W,self.erev,self.mu,self.sigma])
  248. sensory_w,sensory_erev,sensory_mu,sensory_sigma = sess.run([self.sensory_W,self.sensory_erev,self.sensory_mu,self.sensory_sigma])
  249. vleak,gleak,cm = sess.run([self.vleak,self.gleak,self.cm_t])
  250.  
  251. if(not output_weights is None):
  252. output_w,output_b = sess.run(output_weights)
  253. np.savetxt(os.path.join(dirname,"output_w.csv"),output_w)
  254. np.savetxt(os.path.join(dirname,"output_b.csv"),output_b)
  255. np.savetxt(os.path.join(dirname,"w.csv"),w)
  256. np.savetxt(os.path.join(dirname,"erev.csv"),erev)
  257. np.savetxt(os.path.join(dirname,"mu.csv"),mu)
  258. np.savetxt(os.path.join(dirname,"sigma.csv"),sigma)
  259. np.savetxt(os.path.join(dirname,"sensory_w.csv"),sensory_w)
  260. np.savetxt(os.path.join(dirname,"sensory_erev.csv"),sensory_erev)
  261. np.savetxt(os.path.join(dirname,"sensory_mu.csv"),sensory_mu)
  262. np.savetxt(os.path.join(dirname,"sensory_sigma.csv"),sensory_sigma)
  263. np.savetxt(os.path.join(dirname,"vleak.csv"),vleak)
  264. np.savetxt(os.path.join(dirname,"gleak.csv"),gleak)
  265. np.savetxt(os.path.join(dirname,"cm.csv"),cm)
Success #stdin #stdout 1.03s 193516KB
stdin
Standard input is empty
stdout
Standard output is empty