import os, sys
import time
import math
import glob
import copy
print("\nUsing torch-0.2.0_3-cp36")
# Problem when using 0.3
''' Ignore Warnings '''
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.optim as O
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.autograd import Variable
from torchtext import data
from torchtext import datasets

import torch.optim as optim

import nltk
import re
from custom_snli_loader import CustomSNLI
from enc_dec import EncDec
from vnmt import VRAE_VNMT

import matplotlib
# matplotlib.use('qt5agg')
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pickle
from utils import get_args, makedirs, tokenize, load_dataset
from gte import create_example, reverse_input, show_plot, plot_losses


##################################
# Load the entailment-only SNLI
##################################
SOS_TOKEN = 2
EOS_TOKEN = 1
batch_size = 8  # 256 # 128
max_seq_len = 110  # 52 # 35
vocab_size = 60000
word_vectors = 'glove.6B.300d'
vector_cache = os.path.join(os.getcwd(), '.vector_cache/input_vectors.pt')
opt = get_args()

print("Batch Size : " + str(batch_size))
inputs, train_iter, val_iter, test_iter = load_dataset(batch_size, max_seq_len, vocab_size, word_vectors, vector_cache)
print("Dataset Loaded")
config = opt
d_embed = 300
n_hid = 256  # 512 # because we'll concat two hidden tensors later
n_layers = 1  # IMPORTANT
dropout = 0.2  # todo: changed from 0.5
model_name = 'vnmt'
# rnn_type = 'LSTM'
rnn_type = 'GRU'
# dec_type = 'attn'
dec_type = 'vanilla'
config.n_embed = len(inputs.vocab)
ntokens = len(inputs.vocab)
gpu = 0
gpu_ids = [2, 0, 3]
torch.cuda.set_device(gpu)
# torch.cuda.set_device(1)
finetune = False


##################################
# Load model
##################################
print("Loading model now")
model = VRAE_VNMT(rnn_type, d_embed, n_hid, config.n_embed, max_seq_len, n_layers=n_layers, dropout=dropout)  # , word_dropout=0.5)

# model.cuda()
model.encoder_prior.embeddings.weight.data = inputs.vocab.vectors
model.encoder_post.embeddings.weight.data = inputs.vocab.vectors
model.decoder.embeddings.weight.data = inputs.vocab.vectors
model.encoder_prior.embeddings.weight.requires_grad = False
model.encoder_post.embeddings.weight.requires_grad = False
model.decoder.embeddings.weight.requires_grad = False
print("Model loaded")

if finetune:
    # Initialize the encoders'/decoder's weights with the pretrained model
    loaded_model = torch.load('vnmt_pretrain_gru_gte_best.pkl', map_location=lambda storage, loc: storage.cuda(gpu))
    print(loaded_model.encoder.hidden_dim)

    # loaded_model = torch.load('vnmt_pretrain_3-gru_12162017/vnmt_pretrain_gru_gte_best.pkl', map_location=lambda storage, loc: storage.cuda(gpu))
    # model.encoder_prior = loaded_model.encoder
    model.decoder = loaded_model.decoder
    model.encoder_prior = copy.deepcopy(loaded_model.encoder)
    model.encoder_post = copy.deepcopy(loaded_model.encoder)
    print(model.encoder_prior.hidden_dim)  # 512
    print(model.encoder_post.hidden_dim)  # 512
    # loaded_reverse = torch.load('vnmt_pretrain_reverse_1-gru1e-3_12082017/vnmt_pretrain_reverse_gru_gte_e10.pkl', map_location=lambda storage, loc: storage.cuda(gpu))
    # model.encoder_post = loaded_reverse.encoder

    model.encoder_prior.cuda()
    model.encoder_post.cuda()
    model.decoder.cuda()
    model.encoder_prior.embeddings.weight.requires_grad = False
    model.encoder_post.embeddings.weight.requires_grad = False
    model.decoder.embeddings.weight.requires_grad = False


# setup optimizer
lr = 1e-4  # 5e-5
epochs = 50
clip = 5.0
log_interval = 50
save_interval = 1
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(model_parameters, lr=lr, betas=(0.9, 0.999))

########################
## Multi GPU Support ##
########################

model = torch.nn.DataParallel(model, device_ids=gpu_ids)  # added
model.cuda()
print("Model loaded onto CUDA")

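# Note: once wrapped in torch.nn.DataParallel, the original VRAE_VNMT instance
# lives at model.module, which is why the functions below reach the encoders
# and decoder through model.module. A minimal, optional sketch of a
# wrapper-agnostic accessor (this helper is an assumption, not part of the
# original script, and is never called below):
def unwrap(m):
    # Return the underlying module whether or not it is DataParallel-wrapped.
    return m.module if isinstance(m, torch.nn.DataParallel) else m
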
def evaluate(val_iter, model, n_tokens, eval_batch_size, kld_weight=1.0, wv=None):
    """
    Eval acc, bleu, etc.
    """

    # Turn on evaluation mode which disables dropout.
    model.eval().cuda()
    model.module.encoder_prior.eval()
    model.module.encoder_post.eval()
    model.module.decoder.eval()
    total_loss = 0
    loss = 0

    for batch_idx, batch in enumerate(val_iter):
        batch.premise.data = batch.premise.data.transpose(1, 0)
        batch.hypothesis.data = batch.hypothesis.data.transpose(1, 0)
        _loss, _kld = model.module.batchNLLLoss(batch.premise, batch.hypothesis, train=False)
        loss += _loss + kld_weight * _kld  # use the full KLD weight at evaluation time?

    return loss / float(len(val_iter))

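# Note: under torch 0.2, evaluate() returns a Variable holding the averaged
# loss, which is why callers below unpack the scalar with val_loss.data[0].
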

def kld_coef(i, batch_size):
    # return (math.tanh((i - 17500)/1000) + 1)/2  # 700 minibatches * 25 epochs = 17500
    return (math.tanh((i - int(3500 / (batch_size / float(32)))) / 1000) + 1) / 2  # bs: 256 vs 32. 256/32=8. 3500/8 = 437.5

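# A minimal sketch (not called anywhere in this script) for eyeballing the
# annealing schedule defined by kld_coef above: the weight stays near 0 for
# early iterations and ramps towards 1 around the shifted midpoint of the tanh.
def preview_kld_schedule(bs=batch_size):
    mid = int(3500 / (bs / float(32)))  # iteration at which the weight crosses 0.5
    for i in (0, mid // 2, mid, mid + 1000, mid + 3000):
        print('iter %6d -> kld_weight %.4f' % (i, kld_coef(i, bs)))
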

def plot_vae_loss(nlls, klds, kld_weights, filename):
    plt.clf()
    plt.figure()
    fig, ax = plt.subplots()
    plt.plot(nlls, label='NLL')
    plt.plot(klds, label='KLD')
    plt.plot(kld_weights, label='KLD weight')
    ax.legend()
    plt.xlabel('Number of iterations')
    plt.ylabel('Loss component value')
    plt.savefig(filename)

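# Hypothetical usage (plot_vae_loss is never called in this script): the lists
# collected inside train() could be passed in after training, e.g.
#   plot_vae_loss(nlls, kld_values, kld_weights, 'vae_loss_components.eps')
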
''' For testing purposes of the model '''

def test():
    loaded_model = torch.load('./vnmt_gru_gte_best.pkl',
                              map_location=lambda storage, loc: storage.cuda(gpu))
    # print(loaded_model.encoder.hidden_dim)
    model.module.decoder = loaded_model.decoder
    model.module.encoder_prior = copy.deepcopy(loaded_model.encoder_prior)
    model.module.encoder_post = copy.deepcopy(loaded_model.encoder_post)
    ntokens = len(inputs.vocab)
    best_val_loss = float('inf')

    sents = [
        'People are celebrating a victory on the square.',
        'Two women who just had lunch hugging and saying goodbye.',
    ]
    example0 = create_example(inputs, sents[0], max_seq_len)
    example1 = create_example(inputs, sents[1], max_seq_len)
    print(model.module.generate(inputs, ntokens, example0, max_seq_len))
    print(model.module.generate(inputs, ntokens, example1, max_seq_len))


def train(pretrain=False, kld_annealing=True):
    DEBUG = False
    print('gte_vae.train')
    print('lr=%f' % lr)

    # Turn on training mode which enables dropout.
    model.module.train()
    print("here")
    total_loss = 0
    total_acc = 0
    # for plotting
    train_losses = []
    val_losses = []
    kld_values = []  # unweighted values
    kld_weights = []
    nlls = []

    ntokens = len(inputs.vocab)
    best_val_loss = float('inf')

    sents = [
        'People are celebrating a victory on the square.',
        'Two women who just had lunch hugging and saying goodbye.',
    ]

    iteration = 0
    if kld_annealing:
        kld_weight = kld_coef(iteration, batch_size)
    else:
        kld_weight = 1.0
    val_loss = evaluate(val_iter, model, ntokens, opt.batch_size, kld_weight=kld_weight)
    val_loss = val_loss.data[0]

    print('kld_annealing:')
    print(kld_annealing)
    print('Evaluating...')
    print(val_loss)
    example0 = create_example(inputs, sents[0], max_seq_len)
    example1 = create_example(inputs, sents[1], max_seq_len)
    print(model.module.generate(inputs, ntokens, example0, max_seq_len))
    print(model.module.generate(inputs, ntokens, example1, max_seq_len))

    start_time = time.time()

    # plot / dump check before proceeding with training
    kld_stats = {'nll': nlls, 'kld_values': kld_values, 'kld_weights': kld_weights}
    with open('kld_stats.pkl', 'wb') as f:
        pickle.dump(kld_stats, f, protocol=pickle.HIGHEST_PROTOCOL)
    plot_losses([0, 1, 2, 3, 4], 'train', 'train_loss.eps')

    for epoch in range(epochs):
        train_iter.init_epoch()
        n_correct, n_total = 0, 0
        total_loss = 0
        train_loss = 0

        print("Epoch : " + str(epoch + 1))
        for batch_idx, batch in enumerate(train_iter):
            # Turn on training mode which enables dropout.
            model.module.train()
            model.module.encoder_prior.train()
            model.module.encoder_post.train()
            model.module.decoder.train()
            optimizer.zero_grad()

            # print(batch.text.data.shape)  # 35 x 64
            # batch.text.data = batch.text.data.view(-1, max_seq_len)  # -1 instead of opt.batch_size to avoid reshaping err at the end of the epoch
            batch.premise.data = batch.premise.data.transpose(1, 0)        # should be 64x35 [batch_size x seq_len]
            batch.hypothesis.data = batch.hypothesis.data.transpose(1, 0)  # should be 64x35 [batch_size x seq_len]
            # nll, kld = model.batchNLLLoss(batch.premise, batch.hypothesis)
            nll, kld = model.module.batchNLLLoss(batch.premise, batch.hypothesis, train=True)

            # KLD cost annealing
            # ref: https://arxiv.org/pdf/1511.06349.pdf
            iteration += 1
            if kld_annealing:
                kld_weight = kld_coef(iteration, batch_size)
            else:
                kld_weight = 1.0
            loss = nll + kld_weight * kld

            nlls.append(nll.data)
            kld_values.append(kld.data)
            kld_weights.append(kld_weight)

            loss.backward()
            torch.nn.utils.clip_grad_norm(model.module.encoder_prior.parameters(), clip)
            torch.nn.utils.clip_grad_norm(model.module.encoder_post.parameters(), clip)
            torch.nn.utils.clip_grad_norm(model.module.decoder.parameters(), clip)
            # torch.nn.utils.clip_grad_norm(model.parameters(), clip)
            optimizer.step()

            batch_loss = loss.data
            total_loss += batch_loss
            train_loss += batch_loss

            if batch_idx % log_interval == 0 and batch_idx > 0:
                print('iteration: %d' % iteration)
                print('kld_weight: %.16f' % kld_weight)
                print('nll: %.16f' % nll.data[0])
                print('kld_value: %.16f' % kld.data[0])
                cur_loss = total_loss[0] / log_interval
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                      'loss {:5.2f} | ppl {:8.2f}'.format(
                          epoch, batch_idx, len(train_iter) // max_seq_len, lr,
                          elapsed * 1000 / log_interval, cur_loss, 0))  # math.exp(cur_loss)
                total_loss = 0
                start_time = time.time()

                print('Evaluating...')
                val_loss = evaluate(val_iter, model, ntokens, opt.batch_size, kld_weight=kld_weight)
                print(val_loss.data[0])
                print(model.module.generate(inputs, ntokens, example0, max_seq_len))
                print(model.module.generate(inputs, ntokens, example1, max_seq_len))

        print(nlls[-1])
        print(kld_values[-1])
        print(kld_weights[-1])
        print('Epoch train loss:')
        print(train_loss[0])
        train_loss = train_loss / float(len(train_iter))
        print(train_loss[0])
        train_losses.append(train_loss[0])

        val_loss = evaluate(val_iter, model, ntokens, opt.batch_size, kld_weight=kld_weight)
        val_loss = val_loss.data[0]
        val_losses.append(val_loss)
        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val_loss:
            with open('%s_%s_gte_best.pkl' % (model_name, rnn_type.lower()), 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen on the validation set.
            # lr /= 4.0
            # print('lr annealed: %f' % lr)
            pass
        if epoch % save_interval == 0:
            with open('%s_%s_gte_e%d.pkl' % (model_name, rnn_type.lower(), epoch), 'wb') as f:
                torch.save(model, f)

        # save train/val loss lists
        with open('train_losses.pkl', 'wb') as f:
            pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open('val_losses.pkl', 'wb') as f:
            pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
        kld_stats = {'nll': nlls, 'kld_values': kld_values, 'kld_weights': kld_weights}
        with open('kld_stats.pkl', 'wb') as f:
            pickle.dump(kld_stats, f, protocol=pickle.HIGHEST_PROTOCOL)

        plot_losses(train_losses, 'train', 'train_loss.eps')
        plot_losses(val_losses, 'validation', 'val_loss.eps')
        show_plot(train_losses, val_losses, 'train-val_loss.eps')

    print(train_losses)
    print(val_losses)

    # save train/val loss lists
    with open('train_losses.pickle', 'wb') as f:
        pickle.dump(train_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
    with open('val_losses.pickle', 'wb') as f:
        pickle.dump(val_losses, f, protocol=pickle.HIGHEST_PROTOCOL)
    show_plot(train_losses, val_losses, 'train-val_loss.eps')


if __name__ == "__main__":
    print('Training VRAE...')
    train(kld_annealing=False)
    # train(kld_annealing=True)
    # train()