import py2048
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# define some parameters
epochs = 5000          # number of training iterations (one batch of games each)
input_size = 256       # 4x4 board, each cell one-hot encoded over 16 tile exponents
hidden_layer_1 = 450   # width of first hidden layer
hidden_layer_2 = 400   # width of second hidden layer
hidden_layer_3 = 200   # width of third hidden layer
output_size = 4        # one logit per move direction
batch_size = 32        # games played per training iteration
min_prob = 0.000001    # numerical floor to keep log-probabilities finite
class Net(nn.Module):
    """Four-layer MLP policy network.

    Maps a 256-dim one-hot encoded board to a probability distribution
    over the 4 move directions.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_layer_1)
        self.fc2 = nn.Linear(hidden_layer_1, hidden_layer_2)
        self.fc3 = nn.Linear(hidden_layer_2, hidden_layer_3)
        self.fc4 = nn.Linear(hidden_layer_3, output_size)

    def forward(self, x):
        """Return action probabilities for input ``x``.

        ``dim=-1`` (instead of the original ``dim=0``) is identical for the
        1-D state vectors used by this script, but also normalizes correctly
        should a 2-D batch ever be passed in.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return F.softmax(self.fc4(x), dim=-1)
class PNAgent:
    """Policy-gradient (REINFORCE-style) agent for the 2048 board game.

    Episodes are buffered via :meth:`add_episode` and consumed by
    :meth:`train`, which performs one update per recorded step.
    """

    def __init__(self, learning_rate, gamma):
        self.learning_rate = learning_rate
        self.gamma = gamma  # discount factor used by discount_rewards
        self.rewards = np.array([])
        self.states = []
        self.actions = []
        self.probs = []
        self.model = Net().double()
        # NOTE(review): this criterion is never used; train() builds its loss inline.
        self.loss = torch.nn.NLLLoss()
        self.opt = torch.optim.RMSprop(self.model.parameters(), lr=learning_rate)

    def add_episode(self, states, actions, probs, rewards):
        """Buffer one finished episode's trajectory and its scalar reward."""
        self.states.append(states)
        self.actions.append(actions)
        self.probs.append(probs)
        self.rewards = np.append(self.rewards, rewards)

    def act(self, state):
        """Sample a move (0-3) from the policy for the given board.

        Returns (action, output_tensor) where output_tensor is the model's
        probability vector (still attached to the graph).
        """
        x = torch.from_numpy(self.encode(state.flatten())).type(torch.DoubleTensor)
        out = self.model(x)
        p = out.detach().numpy()
        # Renormalize so tiny float drift cannot make np.random.choice
        # reject p for not summing exactly to 1.
        p = p / p.sum()
        action = np.random.choice(4, 1, p=p)[0]
        return action, out

    # adapted from: https://g...content-available-to-author-only...b.com/karpathy/a4166c7fe253700972fcbc77e4ea32c5
    def discount_rewards(self, rewards, size):
        """Scale reward at step t by gamma^(T-1-t).

        NOTE(review): unlike Karpathy's gist this keeps no running cumulative
        sum; behavior preserved as-is since it is referenced nowhere else in
        this file. ``size`` is unused (kept for interface compatibility).
        """
        discounted_rewards = np.zeros_like(rewards)
        running_gamma = 1
        for t in reversed(range(0, rewards.size)):
            discounted_rewards[t] = rewards[t] * running_gamma
            running_gamma *= self.gamma
        return discounted_rewards

    def train(self):
        """One REINFORCE pass over all buffered episodes, then clear buffers."""
        R = np.vstack(self.rewards)
        R -= np.mean(R)
        std = np.std(R)
        if std > 0:  # all-equal rewards would otherwise divide by zero
            R = R / std

        for i in range(len(self.states)):
            r = torch.from_numpy(R[i]).type(torch.DoubleTensor)
            for j in range(len(self.states[i])):
                x = torch.from_numpy(self.encode(self.states[i][j].flatten())).type(torch.DoubleTensor)
                action = torch.from_numpy(np.array(self.actions[i][j])).type(torch.DoubleTensor)

                self.opt.zero_grad()
                # Policy-gradient loss: -log pi(a|s) * advantage.
                # The original `torch.sum(action - (model(state)) * r)` had an
                # identically-zero gradient, because softmax outputs always
                # sum to 1 and sum(action) is a constant.
                probs = self.model(x)
                loss = -torch.sum(action * torch.log(probs + min_prob)) * r
                loss.backward()
                self.opt.step()

        self.states, self.probs, self.actions, self.rewards = [], [], [], np.array([])

    def load(self, name):
        """Restore model weights from ``name``.

        The original called the Keras-only ``load_weights``, which does not
        exist on torch modules; this uses the standard state_dict API.
        """
        self.model.load_state_dict(torch.load(name))

    def save(self, name):
        """Persist model weights to ``name`` via the torch state_dict API."""
        torch.save(self.model.state_dict(), name)

    def encode(self, state):
        """One-hot encode a flattened board.

        Each cell's tile exponent (0..15) becomes a 16-wide one-hot vector;
        an empty cell (0) stays all zeros. Returns a flat array of
        16 * len(state) floats.
        """
        new_state = np.array([])
        for block in state:
            one_hot = [0] * 16
            if block != 0:
                one_hot[int(block)] = 1
            new_state = np.append(new_state, one_hot)
        return new_state
if __name__ == "__main__":
    agent = PNAgent(0.01, 0.95)
    env = py2048.GameBoard(4, 4)
    avg_scores = []

    for epoch in range(epochs):
        scores = []
        for game in range(batch_size):
            states = []
            actions = []
            probs = []

            # Play one full game, recording the trajectory step by step.
            while True:
                action, prob = agent.act(env.board)
                state, _, done = env.step(action)
                states.append(state)
                one_hot = [0] * 4
                one_hot[action] = 1
                actions.append(one_hot)
                probs.append(prob)

                if done:
                    scores.append(env.score)
                    # Episode reward: terminal sum of tile values.
                    agent.add_episode(states, actions, probs, np.sum(env.exponentiate()))
                    env.reset()
                    break

        agent.train()
        avg_scores.append(np.mean(scores))
        print("epoch:", epoch, "mean_score:", avg_scores[-1])

    plt.plot(avg_scores)
    plt.xlabel("Batch")
    plt.ylabel("Average terminal sum of tiles")
    # Save BEFORE show(): show() clears the current figure, so the original
    # savefig-after-show order wrote an empty image.
    plt.savefig("avg_board.png")
    plt.show()
aW1wb3J0IHB5MjA0OAppbXBvcnQgbnVtcHkgYXMgbnAKaW1wb3J0IG1hdGgKaW1wb3J0IHRvcmNoCmltcG9ydCB0b3JjaC5ubiBhcyBubgppbXBvcnQgdG9yY2gubm4uZnVuY3Rpb25hbCBhcyBGCmltcG9ydCBtYXRwbG90bGliLnB5cGxvdCBhcyBwbHQKCiMgZGVmaW5lIHNvbWUgcGFyYW1ldGVycwplcG9jaHMgPSA1MDAwCmlucHV0X3NpemUgPSAyNTYKaGlkZGVuX2xheWVyXzEgPSA0NTAKaGlkZGVuX2xheWVyXzIgPSA0MDAKaGlkZGVuX2xheWVyXzMgPSAyMDAKb3V0cHV0X3NpemUgPSA0CmJhdGNoX3NpemUgPSAzMgptaW5fcHJvYiA9IDAuMDAwMDAxCgoKY2xhc3MgTmV0KG5uLk1vZHVsZSk6CiAgICBkZWYgX19pbml0X18oc2VsZik6CiAgICAgICAgc3VwZXIoTmV0LCBzZWxmKS5fX2luaXRfXygpCiAgICAgICAgc2VsZi5mYzEgPSBubi5MaW5lYXIoaW5wdXRfc2l6ZSwgaGlkZGVuX2xheWVyXzEpCiAgICAgICAgc2VsZi5mYzIgPSBubi5MaW5lYXIoaGlkZGVuX2xheWVyXzEsIGhpZGRlbl9sYXllcl8yKQogICAgICAgIHNlbGYuZmMzID0gbm4uTGluZWFyKGhpZGRlbl9sYXllcl8yLCBoaWRkZW5fbGF5ZXJfMykKICAgICAgICBzZWxmLmZjNCA9IG5uLkxpbmVhcihoaWRkZW5fbGF5ZXJfMywgb3V0cHV0X3NpemUpCgoKICAgIGRlZiBmb3J3YXJkKHNlbGYsIHgpOgogICAgICAgIHggPSBGLnJlbHUoc2VsZi5mYzEoeCkpCiAgICAgICAgeCA9IEYucmVsdShzZWxmLmZjMih4KSkKICAgICAgICB4ID0gRi5yZWx1KHNlbGYuZmMzKHgpKQogICAgICAgIHggPSBzZWxmLmZjNCh4KQogICAgICAgIHJldHVybiBGLnNvZnRtYXgoeCwgZGltPTApCgoKY2xhc3MgUE5BZ2VudDoKICAgIGRlZiBfX2luaXRfXyhzZWxmLCBsZWFybmluZ19yYXRlLCBnYW1tYSk6CiAgICAgICAgc2VsZi5sZWFybmluZ19yYXRlID0gbGVhcm5pbmdfcmF0ZQogICAgICAgIHNlbGYuZ2FtbWEgPSBnYW1tYQogICAgICAgIHNlbGYucmV3YXJkcyA9IG5wLmFycmF5KFtdKQogICAgICAgIHNlbGYuc3RhdGVzID0gW10KICAgICAgICBzZWxmLmFjdGlvbnMgPSBbXQogICAgICAgIHNlbGYucHJvYnMgPSBbXQogICAgICAgIHNlbGYubW9kZWwgPSBOZXQoKS5kb3VibGUoKQogICAgICAgIHNlbGYubG9zcyA9IHRvcmNoLm5uLk5MTExvc3MoKQogICAgICAgIHNlbGYub3B0ID0gdG9yY2gub3B0aW0uUk1TcHJvcChzZWxmLm1vZGVsLnBhcmFtZXRlcnMoKSwgbHI9bGVhcm5pbmdfcmF0ZSkKCiAgICBkZWYgYWRkX2VwaXNvZGUoc2VsZiwgc3RhdGVzLCBhY3Rpb25zLCBwcm9icywgcmV3YXJkcyk6CiAgICAgICAgc2VsZi5zdGF0ZXMuYXBwZW5kKHN0YXRlcykKICAgICAgICBzZWxmLmFjdGlvbnMuYXBwZW5kKGFjdGlvbnMpCiAgICAgICAgc2VsZi5wcm9icy5hcHBlbmQocHJvYnMpCiAgICAgICAgc2VsZi5yZXdhcmRzID0gbnAuYXBwZW5kKHNlbGYucmV3YXJkcywgcmV3YXJkcykKCiAgICBkZWYgYWN0KHNlbGYsIHN0YXRlKToKICAgICAgICAjcHJpbnQoc3RhdGUpCiAgICAgICAg
dCA9IHRvcmNoLmZyb21fbnVtcHkoc2VsZi5lbmNvZGUoc3RhdGUuZmxhdHRlbigpKSkudHlwZSh0b3JjaC5Eb3VibGVUZW5zb3IpCiAgICAgICAgc3RhdGUgPSB0b3JjaC5hdXRvZ3JhZC5WYXJpYWJsZSh0LCByZXF1aXJlc19ncmFkPUZhbHNlKQogICAgICAgIG91dCA9IHNlbGYubW9kZWwoc3RhdGUpCiAgICAgICAgYWN0aW9uID0gbnAucmFuZG9tLmNob2ljZSg0LCAxLCBwPShvdXQuZGV0YWNoKCkubnVtcHkoKSkpWzBdCiAgICAgICAgcmV0dXJuIGFjdGlvbiwgb3V0CgogICAgIyB0YWtlbiBmcm9tOiBodHRwczovL2cuLi5jb250ZW50LWF2YWlsYWJsZS10by1hdXRob3Itb25seS4uLmIuY29tL2thcnBhdGh5L2E0MTY2YzdmZTI1MzcwMDk3MmZjYmM3N2U0ZWEzMmM1CiAgICBkZWYgZGlzY291bnRfcmV3YXJkcyhzZWxmLCByZXdhcmRzLCBzaXplKToKICAgICAgICBkaXNjb3VudGVkX3Jld2FyZHMgPSBucC56ZXJvc19saWtlKHJld2FyZHMpCiAgICAgICAgcnVubmluZ19nYW1tYSA9IDEKICAgICAgICBmb3IgdCBpbiByZXZlcnNlZChyYW5nZSgwLCByZXdhcmRzLnNpemUpKToKICAgICAgICAgICAgZGlzY291bnRlZF9yZXdhcmRzW3RdID0gcmV3YXJkc1t0XSAqIHJ1bm5pbmdfZ2FtbWEKICAgICAgICAgICAgcnVubmluZ19nYW1tYSAqPSBzZWxmLmdhbW1hCiAgICAgICAgcmV0dXJuIGRpc2NvdW50ZWRfcmV3YXJkcwoKICAgIGRlZiB0cmFpbihzZWxmKToKICAgICAgICBSID0gbnAudnN0YWNrKHNlbGYucmV3YXJkcykKICAgICAgICBSIC09IG5wLm1lYW4oUikKICAgICAgICBSID0gUiAvIG5wLnN0ZChSKQoKICAgICAgICBmb3IgaSBpbiByYW5nZShsZW4oc2VsZi5zdGF0ZXMpKToKICAgICAgICAgICAgciA9IHRvcmNoLmZyb21fbnVtcHkoUltpXSkudHlwZSh0b3JjaC5Eb3VibGVUZW5zb3IpCiAgICAgICAgICAgIGZvciBqIGluIHJhbmdlKGxlbihzZWxmLnN0YXRlc1tpXSkpOgogICAgICAgICAgICAgICAgeCA9IHRvcmNoLmZyb21fbnVtcHkoc2VsZi5lbmNvZGUoc2VsZi5zdGF0ZXNbaV1bal0uZmxhdHRlbigpKSkudHlwZSh0b3JjaC5Eb3VibGVUZW5zb3IpCiAgICAgICAgICAgICAgICBzdGF0ZSA9IHRvcmNoLmF1dG9ncmFkLlZhcmlhYmxlKHgsIHJlcXVpcmVzX2dyYWQ9RmFsc2UpCiAgICAgICAgICAgICAgICBhY3Rpb24gPSB0b3JjaC5mcm9tX251bXB5KG5wLmFycmF5KHNlbGYuYWN0aW9uc1tpXVtqXSkpLnR5cGUodG9yY2guRG91YmxlVGVuc29yKQoKICAgICAgICAgICAgICAgIHNlbGYub3B0Lnplcm9fZ3JhZCgpCiAgICAgICAgICAgICAgICBsb3NzID0gdG9yY2guc3VtKGFjdGlvbiAtIChzZWxmLm1vZGVsKHN0YXRlKSkgKiByKQogICAgICAgICAgICAgICAgI2lmIGo9PTA6IHByaW50KGxvc3MpCiAgICAgICAgICAgICAgICBsb3NzLmJhY2t3YXJkKCkKICAgICAgICAgICAgICAgIHNlbGYub3B0LnN0ZXAoKQoKICAgICAgICBzZWxmLnN0YXRlcywgc2VsZi5wcm9icywgc2VsZi5hY3Rpb25zLCBzZWxmLnJld2FyZHMg
PSBbXSwgW10sIFtdLCBucC5hcnJheShbXSkKCiAgICBkZWYgbG9hZChzZWxmLCBuYW1lKToKICAgICAgICBzZWxmLm1vZGVsLmxvYWRfd2VpZ2h0cyhuYW1lKQoKICAgIGRlZiBzYXZlKHNlbGYsIG5hbWUpOgogICAgICAgIHNlbGYubW9kZWwuc2F2ZV93ZWlnaHRzKG5hbWUpCgogICAgZGVmIGVuY29kZShzZWxmLCBzdGF0ZSk6CiAgICAgICAgbmV3X3N0YXRlID0gbnAuYXJyYXkoW10pCiAgICAgICAgZm9yIGJsb2NrIGluIHN0YXRlOgogICAgICAgICAgICBvbmVfaG90ID0gWzBdICogMTYKICAgICAgICAgICAgaWYgYmxvY2sgIT0gMDoKICAgICAgICAgICAgICAgIG9uZV9ob3RbaW50KGJsb2NrKV0gPSAxCiAgICAgICAgICAgIG5ld19zdGF0ZSA9IG5wLmFwcGVuZChuZXdfc3RhdGUsIG9uZV9ob3QpCiAgICAgICAgcmV0dXJuIG5ld19zdGF0ZQoKCmlmIF9fbmFtZV9fID09ICJfX21haW5fXyI6CiAgICBhZ2VudCA9IFBOQWdlbnQoMC4wMSwgMC45NSkKICAgIGVudiA9IHB5MjA0OC5HYW1lQm9hcmQoNCwgNCkKICAgIGF2Z19zY29yZXMgPSBbXQoKICAgIGZvciBlcG9jaCBpbiByYW5nZShlcG9jaHMpOgogICAgICAgIHNjb3JlcyA9IFtdCiAgICAgICAgZm9yIGdhbWUgaW4gcmFuZ2UoYmF0Y2hfc2l6ZSk6CiAgICAgICAgICAgIHN0YXRlcyA9IFtdCiAgICAgICAgICAgIGFjdGlvbnMgPSBbXQogICAgICAgICAgICBwcm9icyA9IFtdCgogICAgICAgICAgICB3aGlsZSBUcnVlOgogICAgICAgICAgICAgICAgYWN0aW9uLCBwcm9iID0gYWdlbnQuYWN0KGVudi5ib2FyZCkKICAgICAgICAgICAgICAgIHN0YXRlLCBfICwgZG9uZSA9IGVudi5zdGVwKGFjdGlvbikKICAgICAgICAgICAgICAgIHN0YXRlcy5hcHBlbmQoc3RhdGUpCiAgICAgICAgICAgICAgICBvbmVfaG90ID0gWzBdICogNAogICAgICAgICAgICAgICAgb25lX2hvdFthY3Rpb25dID0gMQogICAgICAgICAgICAgICAgYWN0aW9ucy5hcHBlbmQob25lX2hvdCkKICAgICAgICAgICAgICAgIHByb2JzLmFwcGVuZChwcm9iKQoKICAgICAgICAgICAgICAgIGlmIGRvbmU6CiAgICAgICAgICAgICAgICAgICAgc2NvcmVzLmFwcGVuZChlbnYuc2NvcmUpCiAgICAgICAgICAgICAgICAgICAgI3ByaW50KCJnYW1lOiIsIGdhbWUrMSwgInNjb3JlOiIsIGVudi5zY29yZSwgImhpZ2hlc3Q6IiwgbnAubWF4KGVudi5leHBvbmVudGlhdGUoKSkpCiAgICAgICAgICAgICAgICAgICAgYWdlbnQuYWRkX2VwaXNvZGUoc3RhdGVzLCBhY3Rpb25zLCBwcm9icywgbnAuc3VtKGVudi5leHBvbmVudGlhdGUoKSkpCiAgICAgICAgICAgICAgICAgICAgZW52LnJlc2V0KCkKICAgICAgICAgICAgICAgICAgICBicmVhawogICAgICAgICNwcmludChzdGF0ZXMsIGFjdGlvbnMpCiAgICAgICAgYWdlbnQudHJhaW4oKQogICAgICAgIGF2Z19zY29yZXMuYXBwZW5kKG5wLm1lYW4oc2NvcmVzKSkKICAgICAgICBwcmludCgiZXBvY2g6IiwgZXBvY2gsICJtZWFuX3Njb3JlOiIsIGF2Z19zY29yZXNbLTFdKQoKICAg
IHBsdC5wbG90KGF2Z19zY29yZXMpCiAgICBwbHQueGxhYmVsKCJCYXRjaCIpCiAgICBwbHQueWxhYmVsKCJBdmVyYWdlIHRlcm1pbmFsIHN1bSBvZiB0aWxlcyIpCiAgICBwbHQuc2hvdygpCiAgICBwbHQuc2F2ZWZpZygiYXZnX2JvYXJkLnBuZyIpCg==