fork download
  1. import tensorflow as tf
  2. import numpy as np
  3.  
  4. # Create a session
  5. sess = tf.compat.v1.Session()
  6.  
  7. # Initialize all variables
  8. sess.run(tf.compat.v1.global_variables_initializer())
  9.  
  10. tensor = [[0, 1, 2, 10],
  11. [3, 4, 5, 11],
  12. [6, 7, 8, 12]]
  13.  
  14. print(tensor)
  15.  
  16. shift, scale = tf.keras.layers.Lambda(
  17. lambda x: tf.split(x, num_or_size_splits=2, axis=-1)
  18. )(tensor)
  19.  
  20. # Run the session to get the value of the tensor
  21. tensor_shift = sess.run(shift)
  22. tensor_scale = sess.run(scale)
  23.  
  24. # Print the tensor value
  25. print("shift = ", tensor_shift)
  26. print("scale = ", tensor_scale)
  27.  
  28. # Close the session
  29. sess.close()
  30.  
Success #stdin #stdout 1.44s 205524KB
stdin
Standard input is empty
stdout
[[0, 1, 2, 10], [3, 4, 5, 11], [6, 7, 8, 12]]
('shift = ', array([[0, 1],
       [3, 4],
       [6, 7]], dtype=int32))
('scale = ', array([[ 2, 10],
       [ 5, 11],
       [ 8, 12]], dtype=int32))