def get_sorted_data_from_matrix(matrix):
    """Return every element of *matrix* as one flat list in ascending order.

    Args:
        matrix: An iterable of rows, where each row is itself an iterable
            of mutually comparable values. Rows may differ in length.

    Returns:
        A new sorted list containing all elements of all rows. The input
        is not modified.
    """
    # Flatten row by row and let sorted() build and order the list in one step.
    return sorted(value for row in matrix for value in row)


def get_matrix_size(matrix):
    """Return the dimensions of *matrix* as a ``(height, width)`` tuple.

    The width is read from the first row only; other rows are not
    inspected, so a rectangular matrix is assumed.

    Args:
        matrix: A sequence of rows, each row a sized sequence.

    Returns:
        ``(height, width)``. An empty matrix yields ``(0, 0)``.
    """
    height = len(matrix)
    # An empty matrix has no first row to measure; report zero width
    # rather than failing with IndexError on matrix[0].
    width = len(matrix[0]) if height else 0

    return height, width


def create_sorted_matrix(data, height, width):
    """Arrange the leading items of *data* into a ``height`` x ``width`` matrix.

    Rows are filled left to right, top to bottom, in the order the items
    appear in *data*. Items beyond the first ``height * width`` are ignored.
    *data* itself is left unmodified.

    Args:
        data: A sequence of values (typically already sorted).
        height: Number of rows to produce; non-positive yields no rows.
        width: Number of items per row; non-positive yields empty rows.

    Returns:
        A new list of ``height`` lists, each holding ``width`` items.

    Raises:
        IndexError: If *data* holds fewer than ``height * width`` items.
    """
    # Clamp so negative sizes behave like range() over a negative bound:
    # no rows for a negative height, empty rows for a negative width.
    rows = max(height, 0)
    cols = max(width, 0)

    if len(data) < rows * cols:
        raise IndexError(
            "not enough data to fill a {}x{} matrix".format(rows, cols)
        )

    # Slice each row out directly instead of repeatedly popping from the
    # front of the list, which is O(n) per pop and empties the caller's data.
    return [list(data[r * cols:(r + 1) * cols]) for r in range(rows)]


def matrix_processing(matrix):
    """Return a new matrix with the values of *matrix* in ascending order.

    All elements are collected, sorted, and laid back out row by row into
    a matrix whose height is the number of rows and whose width is the
    length of the first row of *matrix*.

    Args:
        matrix: A sequence of rows of mutually comparable values.

    Returns:
        A new list of lists holding the sorted values; ``[]`` when
        *matrix* has no rows.
    """
    # With no rows there is no first row to take a width from and nothing
    # to sort. len() is used (not truthiness) so array-like inputs work too.
    if len(matrix) == 0:
        return []

    data = get_sorted_data_from_matrix(matrix)
    height, width = get_matrix_size(matrix)
    return create_sorted_matrix(data, height, width)


def print_matrix(matrix):
    """Print each row of *matrix* on its own line."""
    for row in matrix:
        print(row)


# Demo: sort the values of a sample 2x3 matrix and print the result row by row.
a = [
    [3, 2, 1],
    [4, 1, 2],
]

print_matrix(matrix_processing(a))