import itertools
import random
import time


def easy_solve(num1, num2, ans):
    """Fill the 'x' blanks of ``num1 + num2 = ans`` and print a solution.

    Every digit 1-9 may be used at most once across the whole puzzle, and
    each 'x' must be replaced by a digit not already shown. All valid
    completions are enumerated, one is chosen uniformly at random (the same
    distribution the previous rejection-sampling loop produced), printed as
    ``"A + B = C"``, and returned.

    Args:
        num1: first addend as a digit string, 'x' marking unknown digits.
        num2: second addend, same format.
        ans: the sum, same format.

    Returns:
        tuple[int, int, int]: the chosen ``(num1, num2, ans)`` values.

    Raises:
        ValueError: if a shown character is not a digit 1-9, a digit is
            shown more than once, or no completion satisfies the equation
            (the old random loop never terminated in that case).
    """
    puzzle = num1 + num2 + ans
    shown = puzzle.replace('x', '')
    if any(d not in '123456789' for d in shown) or len(set(shown)) != len(shown):
        raise ValueError(f'digits must be distinct and in 1-9: {puzzle!r}')
    unused = [d for d in '123456789' if d not in shown]
    blanks = [i for i, ch in enumerate(puzzle) if ch == 'x']
    # Cut the filled string back into operands using their original widths.
    cut1 = len(num1)
    cut2 = cut1 + len(num2)

    solutions = []
    # Exhaustive search is cheap here: at most 9! orderings of unused digits.
    for fill in itertools.permutations(unused, len(blanks)):
        cells = list(puzzle)
        for pos, digit in zip(blanks, fill):
            cells[pos] = digit
        filled = ''.join(cells)
        a, b, c = int(filled[:cut1]), int(filled[cut1:cut2]), int(filled[cut2:])
        if a + b == c:
            solutions.append((a, b, c))

    if not solutions:
        raise ValueError(f'no solution for {num1} + {num2} = {ans}')
    a, b, c = random.choice(solutions)
    print(str(a) + ' + ' + str(b) + ' = ' + str(c))
    return a, b, c


def hard_solve(problem, max_tries=None):
    """Solve a multi-term digit puzzle ``'a + b + ... = c + d + ...'``.

    Terms are digit strings in which 'x' marks an unknown digit. The digit
    pool holds each digit 1-9 ``len(cells) // 9`` times, where ``cells`` is
    every digit/'x' in the puzzle; shown digits are removed from the pool and
    the blanks are filled with the remaining digits in a uniformly random
    order until both sides have the same total. The solved equation is
    printed as ``"a + b + ... = c + d + ..."``.

    Random search is used rather than enumeration because large puzzles have
    far too many arrangements to enumerate.

    Args:
        problem: the puzzle text, e.g. ``'xxx + xxx + 5x3 + 123 = xxx + 795'``.
            Spaces are ignored.
        max_tries: optional positive cap on the number of random fillings to
            try; ``None`` (the default) searches until a solution is found.

    Returns:
        tuple[list[int], list[int]]: the left-hand and right-hand terms of the
        solution, in their original order.

    Raises:
        ValueError: if the problem does not contain exactly one '=', a shown
            digit is not available in the pool, there are more blanks than
            unused digits, or there are no blanks and the sides do not balance.
        RuntimeError: if ``max_tries`` fillings were tried without success.
    """
    sides = problem.replace(' ', '').split('=')
    if len(sides) != 2:
        raise ValueError(f'expected exactly one "=" in {problem!r}')
    left_terms = sides[0].split('+')
    right_terms = sides[1].split('+')
    terms = left_terms + right_terms
    cells = ''.join(terms)

    # Each full set of nine cells uses every digit 1-9 exactly once.
    sets = len(cells) // 9
    pool = [d for d in '123456789' for _ in range(sets)]
    for digit in cells.replace('x', ''):
        if digit not in pool:
            raise ValueError(f'digit {digit!r} is not available in {problem!r}')
        pool.remove(digit)

    blanks = [i for i, ch in enumerate(cells) if ch == 'x']
    if len(blanks) > len(pool):
        raise ValueError(f'more blanks than unused digits in {problem!r}')

    # (start, end) slice of each term inside ``cells``; supports any width.
    bounds = []
    offset = 0
    for term in terms:
        bounds.append((offset, offset + len(term)))
        offset += len(term)
    n_left = len(left_terms)

    attempts = 0
    while True:
        attempts += 1
        filled = list(cells)
        # random.sample is a uniformly random ordered draw without
        # replacement, matching the old one-digit-at-a-time randrange picks.
        for pos, digit in zip(blanks, random.sample(pool, len(blanks))):
            filled[pos] = digit
        values = [int(''.join(filled[start:end])) for start, end in bounds]
        sums, ans = values[:n_left], values[n_left:]
        if sum(sums) == sum(ans):
            break
        if not blanks:
            # Nothing to vary: retrying would loop forever.
            raise ValueError(f'{problem!r} has no blanks and does not balance')
        if max_tries is not None and attempts >= max_tries:
            raise RuntimeError(
                f'no solution found in {max_tries} tries for {problem!r}'
            )

    sums_string = ' + '.join(str(i) for i in sums)
    ans_string = ' + '.join(str(i) for i in ans)
    print(sums_string + " = " + ans_string)
    return sums, ans


def find_split(problem):
    """Return the index of the first '=' after removing spaces and '+'.

    With the operators stripped, the left-hand side becomes one run of digit
    cells, so the returned index equals the number of cells before the '='.
    Returns None when *problem* contains no '='.

    >>> find_split('xxx + xxx = xxx')
    6
    """
    compact = problem.replace(' ', '').replace('+', '')
    index = compact.find('=')
    return index if index >= 0 else None

if __name__ == "__main__":
    # Demo: solve the sample puzzles and report the total elapsed time.
    # Guarded so that importing this module does not run the (random,
    # potentially slow) searches as a side effect.
    start = time.perf_counter()  # monotonic clock intended for intervals

    easy_solve('1xx', 'xxx', '468')
    easy_solve('xxx', 'x81', '9x4')
    easy_solve('xxx', '39x', 'x75')
    easy_solve('xxx', '5x1', '86x')
    hard_solve('xxx + xxx + 5x3 + 123 = xxx + 795')
    hard_solve('xxx + xxx + 23x + 571 = xxx + x82')
    hard_solve('xxx + xxx + xx7 + 212 = xxx + 889')
    hard_solve('xxx + xxx + 1x6 + 142 = xxx + 553')
    hard_solve('xxx + xxx + xxx + x29 + 821 = xxx + xxx + 8xx + 867')
    hard_solve('xxx + xxx + xxx + 4x1 + 689 = xxx + xxx + x5x + 957')
    hard_solve('xxx + xxx + xxx + 64x + 581 = xxx + xxx + xx2 + 623')
    hard_solve('xxx + xxx + xxx + x81 + 759 = xxx + xxx + 8xx + 462')
    hard_solve('xxx + xxx + xxx + 6x3 + 299 = xxx + xxx + x8x + 423')
    hard_solve('xxx + xxx + xxx + 58x + 561 = xxx + xxx + xx7 + 993')

    print(f"--- {time.perf_counter() - start} seconds ---")