import numpy as np
# Demo script: build a random 2x5 "error" matrix and a one-hot "label" matrix,
# subtract 1 from each row's error at the labelled column, and sum those
# adjusted entries into per_ts_loss.
np.random.seed(0)
error = np.random.rand(2, 5)
label = np.zeros_like(error)
label[0, 3] = 1
label[1, 1] = 1
print('error: \n', error)
print('label: \n', label)

# Column holding the 1 in each label row. NOTE: argmax returns the first
# maximum, so a row with no 1 at all would silently map to column 0 -- label
# is expected to be one-hot per row.
indexes = np.argmax(label, axis=1)
rows = np.arange(len(indexes))

# Vectorised equivalent of:
#     for i, idx in enumerate(indexes):
#         error[i, idx] -= 1
#         per_ts_loss += error[i, idx]
# Each (row, column) pair is unique, so the in-place fancy-index update hits
# every labelled entry exactly once.
error[rows, indexes] -= 1
per_ts_loss = error[rows, indexes].sum()

# Report every changed column rather than assuming exactly two rows.
changed = ' and '.join(str(idx) for idx in indexes)
print('\nerror(indexes {} are changed): \n{}'.format(changed, error))
print('per_ts_loss: \n', per_ts_loss)