import py2048
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Training and network hyper-parameters shared by Net and the training script.
epochs = 5000  # number of outer training iterations (one agent.train() call each)
input_size = 256  # network input width: 16 one-hot slots per cell x 16 cells of the 4x4 board
hidden_layer_1 = 450  # width of the first hidden layer
hidden_layer_2 = 400  # width of the second hidden layer
hidden_layer_3 = 200  # width of the third hidden layer
output_size = 4  # one output probability per action
batch_size = 32  # games played per training iteration
min_prob = 0.000001  # NOTE(review): not referenced anywhere in the visible code


class Net(nn.Module):
    """Fully connected policy network mapping an encoded board to action probabilities.

    Layer widths come from the module-level constants ``input_size``,
    ``hidden_layer_1`` .. ``hidden_layer_3`` and ``output_size``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_layer_1)
        self.fc2 = nn.Linear(hidden_layer_1, hidden_layer_2)
        self.fc3 = nn.Linear(hidden_layer_2, hidden_layer_3)
        self.fc4 = nn.Linear(hidden_layer_3, output_size)

    def forward(self, x):
        """Return a probability distribution over actions.

        Args:
            x: Tensor whose last dimension has ``input_size`` features; either
                a single encoded board or a batch of them.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``output_size`` that sums to one.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        # Normalise over the action axis.  For a 1-D input this equals dim=0,
        # but it stays correct for batched input, where dim=0 would mix samples.
        return F.softmax(self.fc4(x), dim=-1)


class PNAgent:
    """Policy-network agent trained with a REINFORCE-style policy gradient.

    Episodes are buffered with ``add_episode`` and consumed by ``train``,
    which weights the log-probability of every taken action by the
    batch-normalised reward of its episode.
    """

    # Number of one-hot slots used per board cell by ``encode``.
    _TILE_CLASSES = 16
    # Floor applied to probabilities before log() so log(0) cannot occur.
    _MIN_PROB = 1e-6
    # Reward spreads at or below this are treated as "all episodes equal".
    _MIN_STD = 1e-8

    def __init__(self, learning_rate, gamma):
        """Create the network and optimiser.

        Args:
            learning_rate: Step size passed to the RMSprop optimiser.
            gamma: Decay factor used by ``discount_rewards``.
        """
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.rewards = np.array([])
        self.states = []
        self.actions = []
        self.probs = []
        self.model = Net().double()
        # Kept for backward compatibility; ``train`` builds its loss directly.
        self.loss = torch.nn.NLLLoss()
        self.opt = torch.optim.RMSprop(self.model.parameters(), lr=learning_rate)

    def add_episode(self, states, actions, probs, rewards):
        """Buffer one finished episode for the next ``train`` call.

        Args:
            states: Sequence of boards visited during the episode.
            actions: Sequence of one-hot action vectors, aligned with ``states``.
            probs: Sequence of network outputs, aligned with ``states``.
            rewards: Reward of the episode (appended to the flat reward buffer).
        """
        self.states.append(states)
        self.actions.append(actions)
        self.probs.append(probs)
        self.rewards = np.append(self.rewards, rewards)

    def _to_features(self, state):
        """Encode a board and return it as a double-precision tensor."""
        return torch.from_numpy(self.encode(state.flatten())).double()

    def act(self, state):
        """Sample an action from the policy for ``state``.

        Args:
            state: Board array; it is flattened and one-hot encoded.

        Returns:
            Tuple ``(action, out)`` of the sampled action index and the
            network's probability tensor for this state.
        """
        out = self.model(self._to_features(state))
        probabilities = out.detach().numpy()
        # Sample over however many actions the network emits instead of a
        # hard-coded count.
        action = np.random.choice(len(probabilities), 1, p=probabilities)[0]
        return action, out

    # Adapted from Karpathy's policy-gradient gist:
    # https://gist.github.com/karpathy/a4166c7fe253700972fcbc77e4ea32c5
    def discount_rewards(self, rewards, size):
        """Scale each reward by ``gamma`` raised to its distance from the end.

        Unlike a cumulative discounted return, no running sum is kept: element
        ``t`` of the result is ``rewards[t] * gamma ** (len(rewards) - 1 - t)``.

        Args:
            rewards: 1-D sequence of per-step rewards.
            size: Unused; kept for interface compatibility.

        Returns:
            Float array with the same length as ``rewards``.
        """
        # Work in float so integer rewards are not truncated after scaling.
        rewards = np.asarray(rewards, dtype=float)
        steps_from_end = np.arange(rewards.size - 1, -1, -1)
        return rewards * self.gamma ** steps_from_end

    def train(self):
        """Run one policy-gradient pass over the buffered episodes, then clear them.

        Episode rewards are normalised to zero mean and unit variance across
        the batch.  For every recorded step the loss is
        ``-advantage * log pi(action | state)`` and one optimiser step is taken.

        Does nothing when no episodes are buffered.  When all episode rewards
        are (numerically) equal there is no learning signal, so no optimiser
        step is taken and the buffers are simply cleared.
        """
        if not self.states:
            return

        rewards = np.asarray(self.rewards, dtype=np.float64).reshape(-1)
        centered = rewards - rewards.mean()
        spread = centered.std()

        # Dividing by a zero spread would produce NaN/inf and corrupt the weights.
        if spread > self._MIN_STD:
            advantages = centered / spread
            for boards, taken, advantage in zip(self.states, self.actions, advantages):
                weight = float(advantage)
                for board, one_hot in zip(boards, taken):
                    mask = torch.tensor(np.asarray(one_hot), dtype=torch.double)

                    self.opt.zero_grad()
                    probabilities = self.model(self._to_features(board))
                    log_probs = torch.log(torch.clamp(probabilities, min=self._MIN_PROB))
                    # REINFORCE loss.  A plain sum over the softmax output is
                    # constant in the parameters (it always sums to one) and
                    # therefore has zero gradient.
                    loss = -(mask * log_probs).sum() * weight
                    loss.backward()
                    self.opt.step()

        self.states, self.probs, self.actions, self.rewards = [], [], [], np.array([])

    def load(self, name):
        """Load network weights previously written by ``save`` from path ``name``."""
        # nn.Module has no load_weights(); restore through the state dict.
        self.model.load_state_dict(torch.load(name))

    def save(self, name):
        """Write the network weights to path ``name``."""
        # nn.Module has no save_weights(); persist the state dict instead.
        torch.save(self.model.state_dict(), name)

    def encode(self, state):
        """One-hot encode a flat board.

        Each cell value ``v`` becomes a 16-slot vector with slot ``int(v)`` set
        to 1; a cell equal to 0 becomes all zeros.

        Args:
            state: 1-D sequence of cell values (multi-dimensional input is
                flattened).

        Returns:
            Float array of length ``16 * number_of_cells``.

        Raises:
            IndexError: If a cell value does not fit in the 16 slots.
        """
        cells = np.asarray(state).reshape(-1)
        encoded = np.zeros((cells.size, self._TILE_CLASSES))
        occupied = cells != 0
        # Preallocate and fill in one shot instead of growing with np.append.
        encoded[np.flatnonzero(occupied), cells[occupied].astype(int)] = 1
        return encoded.reshape(-1)


if __name__ == "__main__":
    # Train the agent for `epochs` iterations of `batch_size` games each and
    # plot the mean score per iteration.
    agent = PNAgent(0.01, 0.95)
    env = py2048.GameBoard(4, 4)
    avg_scores = []

    for epoch in range(epochs):
        scores = []
        for _ in range(batch_size):
            states = []
            actions = []
            probs = []

            done = False
            while not done:
                # Snapshot the board *before* stepping so every action is
                # paired with the state it was chosen in.  The copy also guards
                # against the environment mutating its board in place
                # (assumes env.board is array-like -- TODO confirm).
                board = np.array(env.board)
                action, prob = agent.act(board)
                _, _, done = env.step(action)

                one_hot = [0] * 4
                one_hot[action] = 1
                states.append(board)
                actions.append(one_hot)
                probs.append(prob)

            scores.append(env.score)
            # The episode reward is the sum of env.exponentiate() at game end.
            agent.add_episode(states, actions, probs, np.sum(env.exponentiate()))
            env.reset()

        agent.train()
        avg_scores.append(np.mean(scores))
        print("epoch:", epoch, "mean_score:", avg_scores[-1])

    plt.plot(avg_scores)
    plt.xlabel("Batch")
    plt.ylabel("Average terminal sum of tiles")
    # Save before show(): once the window is closed the figure may be empty,
    # and savefig() would then write a blank image.
    plt.savefig("avg_board.png")
    plt.show()
