import numpy as np

def advindex_allaxes(ind, axis):
    """Build an advanced-indexing tuple that selects ``ind`` along ``axis``.

    The returned tuple holds one open-grid (broadcastable) ``arange`` index
    for every axis of ``ind``, except at position ``axis``, where ``ind``
    itself is placed.  Indexing an array ``arr`` with the result gives an
    output of shape ``ind.shape`` with

        out[i0, ..., ik, ...] = arr[i0, ..., ind[i0, ..., ik, ...], ...]

    which matches ``np.take_along_axis(arr, ind, axis)`` when ``arr`` and
    ``ind`` have the same shape except along ``axis`` (no broadcasting of
    ``ind`` against ``arr`` is performed here).

    Parameters
    ----------
    ind : array_like of int
        Indices to pick along ``axis``; its shape is the shape of the
        selection.
    axis : int
        Axis of ``ind`` along which the indices apply; negative values
        count from the end.

    Returns
    -------
    tuple of ndarray
        Index tuple of length ``ind.ndim``, suitable for ``arr[...]``.

    Raises
    ------
    IndexError
        If ``axis`` is outside ``[-ind.ndim, ind.ndim)`` (always the case
        for a 0-d ``ind``).
    """
    ind = np.asarray(ind)
    n = ind.ndim
    if not -n <= axis < n:
        raise IndexError('axis out of range')
    if axis < 0:
        axis += n

    # np.ogrid returns a tuple on recent NumPy (a list on older releases);
    # copy into a list so the slot at ``axis`` can be replaced by ``ind``.
    idx = list(np.ogrid[tuple(map(slice, ind.shape))])
    idx[axis] = ind
    return tuple(idx)

def take_along_axis(arr, ind, axis):
    """Gather entries of ``arr`` using ``ind`` as indices along ``axis``.

    The full advanced-indexing tuple comes from
    ``advindex_allaxes(ind, axis)`` and is applied to ``arr`` in a single
    indexing operation; the result has the shape of ``ind``.
    """
    selector = advindex_allaxes(ind, axis)
    return arr[selector]

if __name__ == "__main__":
    # Self-check: gathering with argsort indices along every axis (both the
    # positive and negative spellings) must reproduce np.sort along that axis.
    np.random.seed(0)
    A = np.random.randint(0, 9, (3, 4, 5, 6, 7))

    n = A.ndim
    for i in range(-n, n):
        # Integer data: compare exactly (shape and values) instead of with a
        # floating-point tolerance that would also silently broadcast shapes.
        print(np.array_equal(take_along_axis(A, A.argsort(axis=i), axis=i),
                             np.sort(A, i)))
