import time

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from transformers import AdamW  # assumption: Hugging Face AdamW; torch.optim.AdamW would also work

def map_fn(index, flags, train_dataset, dev_dataset):
  ## Setup

  # Sets a common random seed - both for initialization and ensuring graph is the same
  torch.manual_seed(flags['seed'])

  # Acquires the (unique) Cloud TPU core corresponding to this process's index
  device = xm.xla_device()
  print('DEVICE: ', device)  # <----- I ADD THIS

  ##### I REMOVED THE DOWNLOAD PART #####
  '''
  # Downloads train and test datasets
  # Note: master goes first and downloads the dataset only once (xm.rendezvous);
  # all the other workers wait for the master to be done downloading.

  if not xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

  train_dataset = datasets.FashionMNIST(
    "/tmp/fashionmnist",
    train=True,
    download=True,
    transform=my_transform)

  test_dataset = datasets.FashionMNIST(
    "/tmp/fashionmnist",
    train=False,
    download=True,
    transform=my_transform)

  if xm.is_master_ordinal():
    xm.rendezvous('download_only_once')
  '''
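  # (The pattern above works because the non-master processes block at
  # xm.rendezvous('download_only_once') until the master, having finished
  # downloading, reaches the same rendezvous point.)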

  ### THIS PART I DIDN'T CHANGE, except renaming test_... to dev_...
  ## Dataloader construction

  # Creates the (distributed) train sampler, which lets this process access only
  # its portion of the training dataset.
  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  dev_sampler = torch.utils.data.distributed.DistributedSampler(
    dev_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=False)

  # Creates dataloaders, which load data in batches
  # Note: the dev loader uses its own sampler but is not shuffled
  train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=flags['batch_size'],
    sampler=train_sampler,
    num_workers=flags['num_workers'],
    drop_last=True)

  dev_loader = torch.utils.data.DataLoader(
    dev_dataset,
    batch_size=flags['batch_size'],
    sampler=dev_sampler,
    num_workers=flags['num_workers'],
    drop_last=True)
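  # (Each process therefore iterates over roughly 1/xm.xrt_world_size() of
  # each dataset per epoch, e.g. about 1/8th on an 8-core TPU.)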

  ##### I CHANGED THIS PART TO MY CODE
  ## Network, optimizer, and loss function creation

  model = MAINModel(flags).to(device)  # <----- CHANGED TO MY MODEL
  model.train()  # <----- CHANGED TO MY MODEL

  loss_fn = nn.CrossEntropyLoss()

  if True:
    bert_model_1_params = model.bert_model_1.parameters()
    bert_model_1_optimizer = AdamW(bert_model_1_params, ...)

  if True:
    bert_model_2_params = model.bert_model_2.parameters()
    bert_model_2_optimizer = AdamW(bert_model_2_params, ...)

  ## linear
  linear_params = model.linear.parameters()
  linear_optimizer = torch.optim.SGD(linear_params, ...)
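  ## Possible simplification (sketch): group the three optimizers so the
  ## zero_grad()/optimizer_step() calls in the training loop become loops:
  # optimizers = [bert_model_1_optimizer, bert_model_2_optimizer, linear_optimizer]
  # for opt in optimizers:
  #   opt.zero_grad()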

  ## Trains
  train_start = time.time()
  for epoch in range(flags['total_epochs']):
    para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    for batch_num, batch in enumerate(para_train_loader):

      data, targets = batch[:-1], batch[-1]  # <----- MY CODE

      # Acquires the network's best guesses at each class
      output = model(data)

      # Computes loss
      loss = loss_fn(output, targets)

      # Updates model
      if True:
        bert_model_1_optimizer.zero_grad()

      if True:
        bert_model_2_optimizer.zero_grad()

      linear_optimizer.zero_grad()

      loss.backward()

      # I HAVE THREE OPTIMIZERS INSTEAD OF 1
      if True:
        xm.optimizer_step(bert_model_1_optimizer)

      if True:
        xm.optimizer_step(bert_model_2_optimizer)

      xm.optimizer_step(linear_optimizer)
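      # Note: each xm.optimizer_step() call all-reduces the gradients of that
      # optimizer's parameters across TPU cores before calling its step(), so
      # the three calls above perform three separate reductions.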

  # The code below has no important changes

  elapsed_train_time = time.time() - train_start
  print("Process", index, "finished training. Train time was:", elapsed_train_time)

  ## Evaluation
  # Sets net to eval and no grad context
  model.eval()
  eval_start = time.time()
  with torch.no_grad():
    num_correct = 0
    total_guesses = 0

    para_dev_loader = pl.ParallelLoader(dev_loader, [device]).per_device_loader(device)
    for batch_num, batch in enumerate(para_dev_loader):
      data, targets = batch[:-1], batch[-1]

      # Acquires the network's best guesses at each class
      output = model(data)
      best_guesses = torch.argmax(output, 1)

      # Updates running statistics
      num_correct += torch.eq(targets, best_guesses).sum().item()
      total_guesses += flags['batch_size']
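      # (Adding flags['batch_size'] per batch is exact here because
      # drop_last=True guarantees full batches.)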

  elapsed_eval_time = time.time() - eval_start
  print("Process", index, "finished evaluation. Evaluation time was:", elapsed_eval_time)
  print("Process", index, "guessed", num_correct, "of", total_guesses,
        "correctly for", num_correct / total_guesses * 100, "% accuracy.")
Runtime error (0.12s, 23564KB)

stderr:
  File "./prog.py", line 76
    model = MAINModel(flags).to(device)   <----- CHANGED TO MY MODEL
                                                          ^
SyntaxError: invalid syntax
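
The SyntaxError comes from the bare "<----- CHANGED TO MY MODEL" arrows: Python only accepts such annotations inside comments, so each one needs a leading '#', as in the listing above:

model = MAINModel(flags).to(device) <----- CHANGED TO MY MODEL    # SyntaxError
model = MAINModel(flags).to(device)  # <----- CHANGED TO MY MODEL  (parses fine)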