fork(4) download
  1. import numpy as np
  2. np.random.seed(0)
  3. error = np.random.rand(2,5)
  4. label = np.zeros_like(error)
  5. label[0,3] = 1
  6. label[1,1] = 1
  7. print('error: \n' ,error)
  8. print('label: \n', label)
  9.  
  10. per_ts_loss = 0
  11. indexes = np.argmax(label,axis=1)
  12. for i, idx in enumerate(indexes):
  13. error[i,idx] -=1
  14. per_ts_loss += error[i,idx]
  15.  
  16. print('\nerror(indexes {} and {} are changed): \n{}'.format(indexes[0], indexes[1], error))
  17. print('per_ts_loss: \n', per_ts_loss)
Success #stdin #stdout 0.08s 92224KB
stdin
Standard input is empty
stdout
error: 
 [[ 0.5488135   0.71518937  0.60276338  0.54488318  0.4236548 ]
 [ 0.64589411  0.43758721  0.891773    0.96366276  0.38344152]]
label: 
 [[ 0.  0.  0.  1.  0.]
 [ 0.  1.  0.  0.  0.]]

error(indexes 3 and 1 are changed): 
[[ 0.5488135   0.71518937  0.60276338 -0.45511682  0.4236548 ]
 [ 0.64589411 -0.56241279  0.891773    0.96366276  0.38344152]]
per_ts_loss: 
 -1.01752960574