import tensorflow as tf
import numpy as np
# Split a 3x4 matrix into two equal halves along its last axis using a Keras
# ``Lambda`` layer that wraps ``tf.split``, evaluate both halves in a
# TF1-style session, and print them.
#
# ``tf.compat.v1.Session.run`` evaluates graph tensors, so every op is built
# inside an explicit ``tf.Graph``. Within ``graph.as_default()`` op
# construction stays in graph mode even when eager execution is the global
# default (TF2); built at top level under TF2, these would be eager tensors
# that ``sess.run`` cannot fetch.
# NOTE(review): confirm the Lambda layer still accepts a nested Python list
# in graph mode under the Keras version in use (Keras 3 may behave differently).
graph = tf.Graph()
with graph.as_default():
    tensor = [[0, 1, 2, 10],
              [3, 4, 5, 11],
              [6, 7, 8, 12]]
    print(tensor)

    # Cut the last axis (4 columns) into 2 equal pieces of 2 columns each:
    # ``shift`` gets columns 0-1, ``scale`` gets columns 2-3.
    shift, scale = tf.keras.layers.Lambda(
        lambda x: tf.split(x, num_or_size_splits=2, axis=-1)
    )(tensor)

    # Nothing above creates a variable (a stateless split only), so no
    # variable initializer is needed. The context manager closes the
    # session even if the fetch raises.
    with tf.compat.v1.Session(graph=graph) as sess:
        # Fetch both halves in a single graph execution.
        tensor_shift, tensor_scale = sess.run([shift, scale])

# Print the tensor values
print("shift = ", tensor_shift)
print("scale = ", tensor_scale)
# Base64-encoded source of this script (kept for reference; not executed):
# aW1wb3J0IHRlbnNvcmZsb3cgYXMgdGYKaW1wb3J0IG51bXB5IGFzIG5wCgojIENyZWF0ZSBhIHNlc3Npb24Kc2VzcyA9IHRmLmNvbXBhdC52MS5TZXNzaW9uKCkKCiMgSW5pdGlhbGl6ZSBhbGwgdmFyaWFibGVzCnNlc3MucnVuKHRmLmNvbXBhdC52MS5nbG9iYWxfdmFyaWFibGVzX2luaXRpYWxpemVyKCkpCgp0ZW5zb3IgPSBbWzAsIDEsIDIsIDEwXSwKICAgICAgICAgIFszLCA0LCA1LCAxMV0sCiAgICAgICAgICBbNiwgNywgOCwgMTJdXQoKcHJpbnQodGVuc29yKQoKc2hpZnQsIHNjYWxlID0gdGYua2VyYXMubGF5ZXJzLkxhbWJkYSgKICAgIGxhbWJkYSB4OiB0Zi5zcGxpdCh4LCBudW1fb3Jfc2l6ZV9zcGxpdHM9MiwgYXhpcz0tMSkKKSh0ZW5zb3IpCgojIFJ1biB0aGUgc2Vzc2lvbiB0byBnZXQgdGhlIHZhbHVlIG9mIHRoZSB0ZW5zb3IKdGVuc29yX3NoaWZ0ID0gc2Vzcy5ydW4oc2hpZnQpCnRlbnNvcl9zY2FsZSA9IHNlc3MucnVuKHNjYWxlKQoKIyBQcmludCB0aGUgdGVuc29yIHZhbHVlCnByaW50KCJzaGlmdCA9ICIsIHRlbnNvcl9zaGlmdCkKcHJpbnQoInNjYWxlID0gIiwgdGVuc29yX3NjYWxlKQoKIyBDbG9zZSB0aGUgc2Vzc2lvbgpzZXNzLmNsb3NlKCkK
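# Captured output (Python 2 print formatting; under Python 3 each result
# prints as "shift =  <array>" without the tuple wrapper):
# [[0, 1, 2, 10], [3, 4, 5, 11], [6, 7, 8, 12]]
# ('shift = ', array([[0, 1],
#        [3, 4],
#        [6, 7]], dtype=int32))
# ('scale = ', array([[ 2, 10],
#        [ 5, 11],
#        [ 8, 12]], dtype=int32))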