import numpy as np
from time import *
from numba import njit as njit
import numba.types

def gen(n,size):
    return np.array([np.concatenate((np.zeros(x,dtype=np.int32),np.random.randint(0,9,size-x))) for x in np.random.randint(0,size,n)])

def np_trim_zeros(arr):
    N=arr.shape[1]
    for i,x in enumerate(arr):
        m=np.trim_zeros(x)
        arr[i]=np.concatenate((m,np.zeros(N-len(m),dtype=np.int32)))
    return arr



def np_slice(arr):
    for i,x in enumerate(arr):
        t=0
        for c in x:
            if c==0:
                t+=1
            else:
                break
        arr[i]=np.concatenate((x[t:],x[:t])) #Вот эта фигня лишние переменные в памяти держит
    return arr

def np_slice_no_mem(arr):
    N=arr.shape[1]
    for i,x in enumerate(arr):
        t=0
        for c in x:
            if c==0:
                t+=1
            else:
                break
        arr[i][0:N-t]=x[t:N] #Вот эта возможно тоже, но меньше
        arr[i][N-t:]=0
    return arr

def np_musor(arr): #Я без комментариев это оставлю
    N=arr.shape[1]#Количество элементов в массиве
    #print(arr)
    masks=arr==0
    #print(masks)
    for i in range(N-1):
        masks[:,i+1]&=masks[:,i]
        if not masks[:,i+1].any():
            break
    i+=1
    sm=np.sum(masks,axis=1)
    #print(sm)
    for d in range(1,1+i):
        m=sm==d
        arr[m]=np.roll(arr[m],-d,axis=1)
    #print(arr)


    return arr


#np_slice_no_mem_numba=njit(numba.types.void(numba.types.int32[:,:]))(np_slice_no_mem)
@njit(numba.types.int32[:,:](numba.types.int32[:,:]))
def np_numba(arr):
    N=arr.shape[1]
    for i,x in enumerate(arr):
        t=0
        for c in x:
            if c==0:
                t+=1
            else:
                break
        arr[i][0:N-t]=x[t:N]
        arr[i][N-t:]=0
    return arr



#arr=gen(1000,200) #1000 строк по 200 элементов
#№arr=gen(107,10)
#print(arr.shape)
def prn(arr,n=3):
    for c in range(n):
        print(f'{arr[c][:20]} ...')

show=0 #False

def test(n,length,func):
    print(f'test {n}x{length}')
    arr = gen(n, length)
    if show:
        prn(arr)
    for fun in func:
        brr = arr.copy()
        t1 = time()
        brr = np.array(brr)
        brr = fun(brr)
        t2 = time()
        print(f"   {fun.__name__} [{(t2-t1)*1000:.3f} ms]")
        if show:
            prn(brr, 1)


test(1000,200,[np_trim_zeros,
            np_slice,
            np_slice_no_mem,
            np_musor,
            np_numba,]
        )

#exit()

test(10000,2000,[
            np_musor,
            np_numba,]
        )

