import numpy as np
def advindex_allaxes(ind, axis):
    """Build an advanced-indexing tuple that applies ``ind`` along ``axis``.

    The result has one entry per dimension of ``ind``. Every entry except
    the one at ``axis`` is an open-grid index array covering
    ``range(ind.shape[k])``; the entry at ``axis`` is ``ind`` itself.
    Indexing an array with the returned tuple therefore picks, for each
    position, the element selected by ``ind`` along ``axis``.

    Parameters
    ----------
    ind : ndarray
        Integer index array (for example the output of ``argsort``).
    axis : int
        Axis along which ``ind`` indexes. Negative values count from the
        last axis.

    Returns
    -------
    tuple of ndarray
        Index tuple usable as ``arr[result]``.

    Raises
    ------
    IndexError
        If ``axis`` is outside ``[-ind.ndim, ind.ndim)``.
    """
    n = ind.ndim

    # Error checks for axis for out-of-bounds
    if not -n <= axis < n:
        raise IndexError('axis out of range')
    if axis < 0:
        axis += n

    # np.ogrid returns a list on older NumPy releases and a tuple on newer
    # ones; copy into a list so the item assignment below works on both.
    idx = list(np.ogrid[tuple(map(slice, ind.shape))])
    idx[axis] = ind
    return tuple(idx)
def take_along_axis(arr, ind, axis):
    """Index ``arr`` with ``ind`` applied along ``axis``.

    Builds the full index tuple with ``advindex_allaxes(ind, axis)`` and
    returns ``arr`` indexed by it.
    """
    selector = advindex_allaxes(ind, axis)
    return arr[selector]
# Self-check: for every valid axis (negative and non-negative), gathering A
# with its argsort indices along that axis must reproduce np.sort along it.
np.random.seed(0)  # fixed seed so the random test array is reproducible
A = np.random.randint(0, 9, (3, 4, 5, 6, 7))

n = A.ndim
for axis in range(-n, n):
    order = A.argsort(axis=axis)
    gathered = take_along_axis(A, order, axis=axis)
    expected = np.sort(A, axis)
    print(np.allclose(gathered, expected))
aW1wb3J0IG51bXB5IGFzIG5wCgpkZWYgYWR2aW5kZXhfYWxsYXhlcyhpbmQsIGF4aXMpOiAgICAKICAgICMgRXJyb3IgY2hlY2tzIGZvciBheGlzIGZvciBvdXQtb2YtYm91bmRzCiAgICBuID0gaW5kLm5kaW0KICAgIGlmIGF4aXMgPCAwOgogICAgICAgIGlmIGF4aXMgPj0gLW46CiAgICAgICAgICAgIGF4aXMgKz0gbgogICAgICAgIGVsc2U6CiAgICAgICAgICAgIHJhaXNlIEluZGV4RXJyb3IoJ2F4aXMgb3V0IG9mIHJhbmdlJykKICAgIGVsaWYgYXhpcyA+PSBuOgogICAgICAgICAgICByYWlzZSBJbmRleEVycm9yKCdheGlzIG91dCBvZiByYW5nZScpCiAgICAgICAgICAgIAogICAgaWR4ID0gbnAub2dyaWRbdHVwbGUobWFwKHNsaWNlLCBpbmQuc2hhcGUpKV0KICAgIGlkeFtheGlzXSA9IGluZAogICAgcmV0dXJuIHR1cGxlKGlkeCkKCmRlZiB0YWtlX2Fsb25nX2F4aXMoYXJyLCBpbmQsIGF4aXMpOgogICAgcmV0dXJuIGFyclthZHZpbmRleF9hbGxheGVzKGluZCwgYXhpcyldCgpucC5yYW5kb20uc2VlZCgwKQpBID0gbnAucmFuZG9tLnJhbmRpbnQoMCw5LCgzLDQsNSw2LDcpKQoKbiA9IEEubmRpbQpmb3IgaSBpbiByYW5nZSgtbixuKToKICAgIHByaW50KG5wLmFsbGNsb3NlKHRha2VfYWxvbmdfYXhpcyhBLEEuYXJnc29ydChheGlzPWkpLGF4aXM9aSksIG5wLnNvcnQoQSxpKSkpCg==