import os, math, random, time

random.seed(42)  # fixed seed: shuffle order, weight init and sampling are reproducible

# Data: one document per non-blank line of input.txt, downloaded on first run.
if not os.path.exists("input.txt"):
  import urllib.request
  # NOTE(review): the host of this URL arrived garbled
  # ("r...content-available-to-author-only...t.com") and could not resolve.
  # It has been restored to the raw GitHub host matching the visible path; confirm.
  urllib.request.urlretrieve("https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt", "input.txt")
# Read through a context manager so the file handle is closed deterministically.
with open("input.txt", encoding="utf-8") as input_file:
  docs = [line.strip() for line in input_file.read().strip().split("\n") if line.strip()]
random.shuffle(docs)
uchars = sorted(set("".join(docs)))  # every distinct character, in sorted order
BOS = len(uchars)  # token id one past the last character id; used as the sequence delimiter
vocab_size = len(uchars) + 1  # all characters plus BOS

# Vector helpers
def vadd(a, b):
  """Return the element-wise sum of a and b (truncated to the shorter input)."""
  return [x + y for x, y in zip(a, b)]


def vdot(a, b):
  """Return the dot product of a and b (truncated to the shorter input)."""
  return sum(x * y for x, y in zip(a, b))


def vscale(a, s):
  """Return a new list with every element of a multiplied by the scalar s."""
  return [x * s for x in a]
def vaccum(dst, src):
  """Add src into dst in place, element by element.

  Only the first len(src) entries of dst are touched. Raises IndexError if
  dst is shorter than src. Returns None, like other in-place mutators.
  """
  # A plain loop instead of a comprehension: the work here is the side effect,
  # so there is no reason to build a throwaway list of None values.
  for j, increment in enumerate(src):
    dst[j] = dst[j] + increment

# Config & params
n_embd = 16
n_head = 4
n_layer = 1
block_size = 8
head_dim = n_embd // n_head  # width of each attention head's slice of the embedding


def mat(r, c, s=0.02):
  """Return an r-by-c list of lists filled row by row with random.gauss(0, s) draws."""
  rows = []
  for _ in range(r):
    rows.append([random.gauss(0, s) for _ in range(c)])
  return rows


def dlike(d):
  """Return a dict with the same keys as d, each mapped to a zero matrix of the same shape."""
  zeros = {}
  for name, matrix in d.items():
    zeros[name] = [[0.0] * len(row) for row in matrix]
  return zeros
# Token embedding, position embedding and output projection.
state_dict = {
  "wte": mat(vocab_size, n_embd),
  "wpe": mat(block_size, n_embd),
  "lm_head": mat(vocab_size, n_embd),
}
# Per-layer weights. attn_wo and mlp_fc2 are created with std 0, i.e. all zeros.
for i in range(n_layer):
  state_dict[f"layer{i}.attn_wq"] = mat(n_embd, n_embd)
  state_dict[f"layer{i}.attn_wk"] = mat(n_embd, n_embd)
  state_dict[f"layer{i}.attn_wv"] = mat(n_embd, n_embd)
  state_dict[f"layer{i}.attn_wo"] = mat(n_embd, n_embd, 0)
  state_dict[f"layer{i}.mlp_fc1"] = mat(4 * n_embd, n_embd)
  state_dict[f"layer{i}.mlp_fc2"] = mat(n_embd, 4 * n_embd, 0)

# Ops
def linear(x, w):
  """Return the matrix-vector product w @ x: one dot product per row of w."""
  out = []
  for row in w:
    out.append(vdot(row, x))
  return out

def softmax(z):
  """Return the softmax of z as a list of probabilities."""
  # Subtract the largest entry first so no exponent is positive.
  peak = max(z)
  exps = [math.exp(v - peak) for v in z]
  total = sum(exps)
  return vscale(exps, 1 / total)

def rmsnorm(x):
  """Scale x by the inverse of its root-mean-square.

  Returns (normalized vector, inverse RMS factor); the factor is returned so
  the backward pass can reuse it.
  """
  mean_square = sum(v*v for v in x) / len(x)
  inv = (mean_square + 1e-5) ** -0.5  # 1e-5 keeps the factor finite for an all-zero x
  return vscale(x, inv), inv

def rmsnorm_b(dy, y, inv):
  """Backward of rmsnorm: gradient w.r.t. the input, given the gradient dy w.r.t. the output.

  y is the normalized output and inv the inverse-RMS factor that rmsnorm returned.
  """
  width = len(y)
  # Projection of the upstream gradient onto the normalized output.
  proj = sum(dy[j] * yj for j, yj in enumerate(y))
  return [inv * (dy[j] - yj * proj / width) for j, yj in enumerate(y)]

def linear_b(dy, w, x, dw):
  """Backward of linear(x, w).

  Accumulates the weight gradient (outer product of dy and x) into dw in place
  and returns the gradient w.r.t. x.
  """
  n_in = len(x)
  n_out = len(dy)
  for j, upstream in enumerate(dy):
    grad_row = dw[j]
    for k in range(n_in):
      grad_row[k] += upstream * x[k]
  return [sum(dy[j] * w[j][k] for j in range(n_out)) for k in range(n_in)]

def softmax_b(dp, p):
  """Backward of softmax: gradient w.r.t. the logits, given probabilities p and upstream gradient dp."""
  weighted = sum(d * q for d, q in zip(dp, p))
  return [q * (dp[i] - weighted) for i, q in enumerate(p)]

# Forward pass (single position, shared by training and inference)
def forward_pos(tok, pos, keys, values, save=False):
  """Run the model on one token at one position and return its next-token logits.

  Appends this position's key and value vectors to keys[li] / values[li] for
  every layer, so the caller must pass the same lists for consecutive positions
  of one sequence. With save=True the return value is (logits, state), where
  state = (tok, x0, inv0, layers, x) holds the intermediates consumed by the
  training loop's backward pass; otherwise only the logits list is returned.
  """
  x0, inv0 = rmsnorm(vadd(state_dict['wte'][tok], state_dict['wpe'][pos]))
  x = x0
  layers = []
  attn_scale = head_dim**0.5
  for li in range(n_layer):
    # Attention: normalize, project, and extend this layer's key/value cache.
    xn, ainv = rmsnorm(x)
    q = linear(xn, state_dict[f'layer{li}.attn_wq'])
    keys[li].append(linear(xn, state_dict[f'layer{li}.attn_wk']))
    values[li].append(linear(xn, state_dict[f'layer{li}.attn_wv']))
    n_ctx = len(keys[li])  # cached positions, including the current one
    x_attn = [0.0] * n_embd
    attn_weights = []
    for h in range(n_head):
      lo = h * head_dim
      hi = lo + head_dim
      q_h = q[lo:hi]
      k_h = [keys[li][t][lo:hi] for t in range(n_ctx)]
      v_h = [values[li][t][lo:hi] for t in range(n_ctx)]
      scores = [sum(qj * kj for qj, kj in zip(q_h, k_t)) / attn_scale for k_t in k_h]
      attn_w = softmax(scores)
      attn_weights.append(attn_w)
      # Each output component of this head is the attention-weighted sum of cached values.
      for j in range(head_dim):
        x_attn[lo + j] = sum(w_t * v_t[j] for w_t, v_t in zip(attn_w, v_h))
    x = vadd(linear(x_attn, state_dict[f'layer{li}.attn_wo']), x)  # residual connection
    # MLP: normalize, expand, squared-ReLU, project back, residual.
    x_norm, norm_inv = rmsnorm(x)
    mlp_h1 = linear(x_norm, state_dict[f'layer{li}.mlp_fc1'])
    mlp_h1_act = [max(0.0, v) ** 2 for v in mlp_h1]
    x = vadd(linear(mlp_h1_act, state_dict[f'layer{li}.mlp_fc2']), x)
    if save:
      layers.append((xn, ainv, q, x_attn, attn_weights, x_norm, norm_inv, mlp_h1, mlp_h1_act))
  logits = linear(x, state_dict['lm_head'])
  if save:
    return logits, (tok, x0, inv0, layers, x)
  return logits

# Train (Adam)
# Adam hyper-parameters.
learning_rate = 1e-2
beta1 = 0.9
beta2 = 0.95
eps_adam = 1e-8
num_steps = 5000
# First- and second-moment accumulators, zero-initialized with the parameters' shapes.
m_state = dlike(state_dict)
v_state = dlike(state_dict)

# Wall-clock start of training; the elapsed time is printed after the loop.
t0 = time.time()
for step in range(num_steps):
  # One document per step, cycling through docs when num_steps exceeds len(docs).
  doc = docs[step % len(docs)]
  # Character ids wrapped in BOS on both sides; the trailing BOS is the final prediction target.
  # NOTE(review): uchars.index is a linear scan per character; a char->id dict would avoid it.
  tokens = [BOS] + [uchars.index(c) for c in doc] + [BOS]
  # Number of (input token, next token) pairs used this step, capped at block_size.
  n = min(block_size, len(tokens) - 1)

  # Forward: run positions left to right, filling the per-layer key/value caches
  # and keeping every intermediate needed by the backward pass in `saved`.
  keys, values, saved = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)], []
  loss = 0.0
  for pos in range(n):
    logits, state = forward_pos(tokens[pos], pos, keys, values, save=True)
    probs = softmax(logits)
    # Mean negative log-probability of the true next token over the n positions.
    loss -= math.log(probs[tokens[pos + 1]]) / n
    saved.append((*state, probs))

  # Backward: fresh zero gradients for every parameter, plus gradient buffers for
  # the cached keys and values (one n_embd-wide row per layer and position).
  dstate = dlike(state_dict)
  dkeys = [[[0.0] * n_embd for _ in range(n)] for _ in range(n_layer)]
  dvalues = [[[0.0] * n_embd for _ in range(n)] for _ in range(n_layer)]
  # Positions are visited last to first: position pos adds into dkeys/dvalues for
  # every t <= pos, so when position t is reached its rows already hold the
  # contributions of all later positions that attended to it.
  for pos in range(n - 1, -1, -1):
    tok, x0, inv0, layers, xf, probs = saved[pos]
    target_id = tokens[pos + 1]
    # Gradient of the mean cross-entropy w.r.t. the logits is (probs - one_hot) / n;
    # push it back through lm_head to get the gradient w.r.t. the final hidden state.
    dx = linear_b([(probs[j] - (j == target_id)) / n for j in range(vocab_size)], state_dict['lm_head'], xf, dstate['lm_head'])
    for li in range(n_layer - 1, -1, -1):
      xn, ainv, q, x_attn, attn_weights, x_norm, norm_inv, mlp_h1, mlp_h1_act = layers[li]
      # MLP branch: fc2, then the squared-ReLU (derivative 2*h for h > 0, else 0), then fc1.
      dmlp_h1_act = linear_b(dx, state_dict[f'layer{li}.mlp_fc2'], mlp_h1_act, dstate[f'layer{li}.mlp_fc2'])
      dxn = linear_b([dmlp_h1_act[j] * 2 * mlp_h1[j] if mlp_h1[j] > 0 else 0 for j in range(4 * n_embd)],
                     state_dict[f'layer{li}.mlp_fc1'], x_norm, dstate[f'layer{li}.mlp_fc1'])
      # Add the MLP-branch gradient (through its norm) to the residual-path gradient.
      dx = vadd(rmsnorm_b(dxn, x_norm, norm_inv), dx)
      # Attention branch: back through the output projection first.
      dx_attn = linear_b(dx, state_dict[f'layer{li}.attn_wo'], x_attn, dstate[f'layer{li}.attn_wo'])
      dq = [0.0] * n_embd
      for h in range(n_head):
        hs = h * head_dim
        q_h, attn_w = q[hs:hs + head_dim], attn_weights[h]
        k_h = [keys[li][t][hs:hs + head_dim] for t in range(pos + 1)]
        v_h = [values[li][t][hs:hs + head_dim] for t in range(pos + 1)]
        # Gradient w.r.t. each attention weight is dx_attn . v_t; softmax_b turns
        # that into the gradient w.r.t. the (already scaled) attention logits.
        dattn = softmax_b([sum(dx_attn[hs + j] * v_h[t][j] for j in range(head_dim)) for t in range(pos + 1)], attn_w)
        for t in range(pos + 1):
          # Undo the 1/sqrt(head_dim) scaling applied to the logits in the forward pass.
          c = dattn[t] / head_dim**0.5
          for j in range(head_dim):
            dvalues[li][t][hs + j] += dx_attn[hs + j] * attn_w[t]
            dq[hs + j] += c * k_h[t][j]
            dkeys[li][t][hs + j] += c * q_h[j]
      # Query, key and value were all projected from the same normalized input xn,
      # so their three input gradients are summed.
      dxn = linear_b(dq, state_dict[f'layer{li}.attn_wq'], xn, dstate[f'layer{li}.attn_wq'])
      dxn = vadd(dxn, linear_b(dkeys[li][pos], state_dict[f'layer{li}.attn_wk'], xn, dstate[f'layer{li}.attn_wk']))
      dxn = vadd(dxn, linear_b(dvalues[li][pos], state_dict[f'layer{li}.attn_wv'], xn, dstate[f'layer{li}.attn_wv']))
      # Add the attention-branch gradient (through its norm) to the residual-path gradient.
      dx = vadd(rmsnorm_b(dxn, xn, ainv), dx)
    # Back through the initial norm; the same gradient goes to both embedding rows
    # because the forward pass adds the token and position embeddings.
    demb = rmsnorm_b(dx, x0, inv0)
    vaccum(dstate['wte'][tok], demb)
    vaccum(dstate['wpe'][pos], demb)

  # Adam update with a cosine-decayed learning rate (learning_rate at step 0, falling toward 0).
  lr_t = learning_rate * 0.5 * (1 + math.cos(math.pi * step / num_steps))
  # Bias-correction denominators for the first and second moment estimates.
  m_hat_corr, v_hat_corr = 1 - beta1**(step+1), 1 - beta2**(step+1)
  for k in state_dict:
    for i, row in enumerate(state_dict[k]):
      for j in range(len(row)):
        g = dstate[k][i][j]
        m_state[k][i][j] = beta1 * m_state[k][i][j] + (1 - beta1) * g
        v_state[k][i][j] = beta2 * v_state[k][i][j] + (1 - beta2) * g**2
        state_dict[k][i][j] -= lr_t * (m_state[k][i][j] / m_hat_corr) / ((v_state[k][i][j] / v_hat_corr)**0.5 + eps_adam)
  print(f"step {step+1}/{num_steps}  loss: {loss:.4f}")

print(f"\nTotal training time: {time.time() - t0:.2f}s")

# Inference
# Sampling temperature: logits are divided by it before the softmax, so values
# below 1 sharpen the distribution and values above 1 flatten it.
# The previous code set 2.0 and *multiplied* the logits by it, which is an
# inverse temperature; 0.5 with division gives the same scale factor (exactly 2.0),
# so the generated samples are unchanged.
temperature = 0.5
for sample_idx in range(20):
  # Fresh key/value caches for every sample.
  keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
  token_id, sample = BOS, []
  for pos in range(block_size):
    logits = forward_pos(token_id, pos, keys, values)
    probs = softmax(vscale(logits, 1 / temperature))
    token_id = random.choices(range(vocab_size), weights=probs)[0]
    # BOS doubles as the end-of-sequence marker.
    if token_id == BOS:
      break
    sample.append(uchars[token_id])
  print(f"{sample_idx+1}: {''.join(sample)}")
# your code goes here