fork download
  1. import numpy as np
  2.  
  def advindex_allaxes(ind, axis):
      """Build an advanced-indexing tuple that gathers ``ind`` along ``axis``.

      Every axis except ``axis`` is indexed by an open-grid ``arange`` (shaped
      so it broadcasts against ``ind``), and ``axis`` itself is indexed by
      ``ind``.  For an array ``arr`` with the same shape as ``ind`` outside
      ``axis``, ``arr[result][j]`` equals ``arr[j']`` where ``j'`` is the
      multi-index ``j`` with its ``axis`` component replaced by ``ind[j]``.

      Parameters
      ----------
      ind : ndarray of integers
          Indices to take along ``axis``.
      axis : int
          Axis along which ``ind`` indexes; negative values count from the
          last dimension, as in NumPy.

      Returns
      -------
      tuple of ndarray
          One index array per dimension of ``ind``.

      Raises
      ------
      IndexError
          If ``axis`` is outside ``[-ind.ndim, ind.ndim)``.
      """
      n = ind.ndim
      if not -n <= axis < n:
          raise IndexError('axis out of range')
      if axis < 0:
          axis += n

      # Depending on the NumPy version, np.ogrid may hand back a tuple rather
      # than a list; copy into a list so the slot for ``axis`` can be replaced.
      idx = list(np.ogrid[tuple(map(slice, ind.shape))])
      idx[axis] = ind
      return tuple(idx)
  17.  
  def take_along_axis(arr, ind, axis):
      """Gather entries of ``arr`` using ``ind`` as indices along ``axis``.

      The index tuple comes from ``advindex_allaxes``, so an out-of-range
      ``axis`` raises ``IndexError``.  Intended for ``ind`` shaped like ``arr``
      outside ``axis``; other shapes follow NumPy advanced-indexing rules.
      """
      selector = advindex_allaxes(ind, axis)
      return arr[selector]
  20.  
  if __name__ == "__main__":
      # Self-check: for every axis (negative and non-negative spellings),
      # gathering A by its own argsort must reproduce np.sort along that axis.
      np.random.seed(0)
      A = np.random.randint(0, 9, (3, 4, 5, 6, 7))

      n = A.ndim
      for i in range(-n, n):
          # array_equal also requires identical shapes, whereas allclose would
          # accept any broadcast-compatible result.
          print(np.array_equal(take_along_axis(A, A.argsort(axis=i), axis=i),
                               np.sort(A, i)))
  27.  
Success #stdin #stdout 0.07s 92224KB
stdin
Standard input is empty
stdout
True
True
True
True
True
True
True
True
True
True