import numpy as np
from scipy.ndimage import median

# Small demo input: a 2x2 grid whose last axis holds 3 values each (0..11),
# and a label map that puts the whole top row under one label while each
# bottom cell gets a label of its own.
dat = np.arange(2 * 2 * 3).reshape((2, 2, 3))
idx = np.array([0, 0, 1, 2]).reshape((2, 2))

def summarize(dat, idx):
    """Return the median of ``dat`` for every (label, channel) pair.

    Parameters
    ----------
    dat : array_like
        Data whose last axis holds channels; the leading axes
        (``dat.shape[:-1]``) are the positions being labelled.
    idx : array_like
        Label for each position, broadcastable against ``dat.shape[:-1]``.
        Labels may be any sortable values; they are relabelled to
        ``0 .. n_labels - 1`` following the sorted order of the distinct
        values.

    Returns
    -------
    ndarray, shape (n_labels, n_channels)
        Row ``k`` holds the medians for the ``k``-th smallest distinct label
        (i.e. ``np.unique(idx)[k]``); column ``c`` is channel ``c``.
        An empty label map or zero channels yields an empty array of that
        shape.
    """
    dat = np.asarray(dat)
    idx = np.asarray(idx)
    uniq, inverse = np.unique(idx, return_inverse=True)
    # Depending on the NumPy version, ``inverse`` is flat or already shaped
    # like ``idx``; reshaping handles both.
    labels = inverse.reshape(idx.shape)
    n_labels = uniq.size
    chan = dat.shape[-1]
    if n_labels == 0 or chan == 0:
        # Nothing to compute; return the empty result directly instead of
        # handing an empty label set to scipy.
        return np.empty((n_labels, chan))
    # Give each (label, channel) pair its own id: channel c of label k maps
    # to c * n_labels + k, so one scipy call covers all channels at once.
    channel_labels = labels[..., np.newaxis] + n_labels * np.arange(chan)
    flat = median(dat, labels=channel_labels, index=np.arange(n_labels * chan))
    # Ids are channel-major, so reshape to (chan, n_labels) and transpose.
    return np.asarray(flat).reshape(chan, n_labels).T

if __name__ == "__main__":
    # Demo: print the per-label channel medians of the example data above,
    # only when run as a script (not on import).
    print(summarize(dat, idx))