from collections import Counter
from functools import wraps
from random import randint

from matplotlib import pyplot


def call_counter(func):
    """Decorate *func* so that the number of calls made to it is recorded.

    The returned wrapper forwards all positional and keyword arguments to
    *func* and returns its result unchanged.  The running total is stored in
    the wrapper's ``calls`` attribute, which starts at 0 and may be reset by
    assigning to it.
    """
    # wraps() keeps the wrapped function's __name__/__doc__ on the wrapper so
    # introspection and debugging output still refer to the real function.
    @wraps(func)
    def helper(*args, **kwargs):
        helper.calls += 1
        return func(*args, **kwargs)

    # Set after wraps() so the counter always starts from zero.
    helper.calls = 0
    return helper


def random_guess(options):
    """Return a list holding one random value per entry of *options*.

    Each entry of *options* is the largest value allowed at that position;
    the value drawn for it lies in ``0..entry`` (both ends inclusive).
    """
    return [randint(0, upper) for upper in options]


@call_counter
def correct_answers(guess, solution):
    """Return how many positions of *guess* hold the same value as *solution*.

    The two sequences are compared pairwise; comparison stops at the end of
    the shorter one.  Every call is tallied by the ``call_counter`` decorator.
    """
    return sum(1 for guessed, actual in zip(guess, solution) if guessed == actual)


def play_game(possible_responses, solution, strategy):
    """Run *strategy* once and return how many guesses it checked.

    The call counter on ``correct_answers`` is reset to zero first, then
    ``strategy(possible_responses, solution)`` is executed; the counter value
    afterwards is the result.
    """
    scorer = correct_answers
    scorer.calls = 0  # start every game with a fresh tally
    strategy(possible_responses, solution)
    return scorer.calls


def xurdones_strategy(possible_responses, solution):
    """Solve the game one position at a time using score feedback.

    Starts from a random guess and then walks the positions left to right.
    At the current position the value is advanced by one (wrapping around
    within ``0..possible_responses[i]``) and the guess is re-scored:

    * score dropped  -> the previous value was right; restore it, move on;
    * score rose     -> the new value is right; keep it, move on;
    * score the same -> both values are wrong; try the next value.

    The function returns ``None`` once the score equals ``len(solution)``;
    the number of scoring calls is tracked on ``correct_answers.calls``.
    """
    guess = random_guess(possible_responses)
    change_idx = 0
    old_correct = correct_answers(guess, solution)
    while old_correct != len(solution):
        num_values = possible_responses[change_idx] + 1
        if num_values == 1:
            # Only one value is possible here, so changing it is a no-op and
            # the score could never move; skip instead of looping forever.
            change_idx += 1
            continue

        old_guess = guess[change_idx]
        guess[change_idx] = (old_guess + 1) % num_values
        new_correct = correct_answers(guess, solution)

        if new_correct < old_correct:
            # We just broke a correct position: put it back and advance.
            guess[change_idx] = old_guess
            change_idx += 1
        elif new_correct > old_correct:
            # Found the right value.  Advance immediately; staying on this
            # position would waste a round breaking and restoring it.
            old_correct = new_correct
            change_idx += 1
        # Equal scores: old and new values are both wrong, so stay on this
        # position and try its next value on the following iteration.


def random_strategy(possible_responses, solution):
    """Keep scoring fresh random guesses until one matches every position.

    Returns ``None``; the number of attempts is tracked on
    ``correct_answers.calls``.
    """
    while True:
        attempt = random_guess(possible_responses)
        if correct_answers(attempt, solution) == len(solution):
            return


def run_tests(n):
    """Play *n* games with each strategy, print the tallies and plot them.

    For every round both ``xurdones_strategy`` and ``random_strategy`` are
    played against the same fixed puzzle.  The distribution of guess counts
    is printed as a ``Counter`` per strategy and drawn as two stacked
    histograms (xurdones on top, random below).
    """
    possible_responses = [1, 2, 1, 2, 1, 2]
    solution = [0, 2, 0, 1, 1, 0]

    figure = pyplot.figure()

    # Guess counts per game, one list per strategy.
    xur_rounds, rand_rounds = [], []
    for _ in range(n):
        xur_rounds.append(
            play_game(possible_responses, solution, xurdones_strategy)
        )
        rand_rounds.append(
            play_game(possible_responses, solution, random_strategy)
        )

    for rounds in (xur_rounds, rand_rounds):
        print(Counter(rounds))

    # 211 / 212: two rows, one column; top then bottom panel.
    for position, rounds in ((211, xur_rounds), (212, rand_rounds)):
        figure.add_subplot(position).hist(rounds)
    pyplot.show()


# Script entry point: when executed directly (not imported), run the
# comparison with one million rounds.
if __name__ == '__main__':
    run_tests(1000000)
