def map_fn(index, flags, train_dataset, dev_dataset):
    """Train ``MAINModel`` on one XLA device, then report dev-set accuracy.

    Presumably launched once per TPU core (the ``(index, flags, ...)``
    signature matches ``xmp.spawn``'s convention -- confirm with the caller).
    Each process trains on its own shard of ``train_dataset`` and evaluates
    on its own shard of ``dev_dataset``; accuracy is reported per process,
    not aggregated across replicas.

    Args:
        index: Process index; used only in the log messages.
        flags: Dict of run settings. Keys read here: ``'seed'``,
            ``'batch_size'``, ``'num_workers'``, ``'total_epochs'``. The
            whole dict is also passed to ``MAINModel``.
        train_dataset: Training dataset. After collation, the last element
            of each batch is used as the target and the remaining elements
            are passed to the model as a list.
        dev_dataset: Evaluation dataset with the same batch layout.
    """
    ## Setup
    # Common seed so every replica initializes identical weights and traces
    # the same graph.
    torch.manual_seed(flags['seed'])

    # The XLA device (TPU core) assigned to this process.
    device = xm.xla_device()
    print('DEVICE: ', device)

    # The datasets are built by the caller and passed in, so the tutorial's
    # "master downloads once" rendezvous step is not needed here.

    ## Dataloader construction
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    # Distributed samplers give each process only its own shard of the data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True)
    dev_sampler = torch.utils.data.distributed.DistributedSampler(
        dev_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False)

    # drop_last=True keeps every batch the same shape; XLA recompiles the
    # graph for each new input shape.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=flags['batch_size'],
        sampler=train_sampler,
        num_workers=flags['num_workers'],
        drop_last=True)
    # The dev loader is sharded but not shuffled. Because of drop_last, up to
    # batch_size - 1 dev samples per replica are not evaluated.
    dev_loader = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=flags['batch_size'],
        sampler=dev_sampler,
        num_workers=flags['num_workers'],
        drop_last=True)

    ## Network, optimizers, and loss function
    model = MAINModel(flags).to(device)
    model.train()
    loss_fn = nn.CrossEntropyLoss()

    # One optimizer per sub-module.
    # TODO: the `...` arguments are placeholders for the hyperparameters that
    # were elided in the original; replace them with real values.
    bert_model_1_optimizer = AdamW(model.bert_model_1.parameters(), ...)
    bert_model_2_optimizer = AdamW(model.bert_model_2.parameters(), ...)
    linear_optimizer = torch.optim.SGD(model.linear.parameters(), ...)
    optimizers = [bert_model_1_optimizer, bert_model_2_optimizer, linear_optimizer]

    ## Training
    train_start = time.time()
    for epoch in range(flags['total_epochs']):
        # Without set_epoch, DistributedSampler reuses the same shuffle order
        # every epoch.
        train_sampler.set_epoch(epoch)
        para_train_loader = pl.ParallelLoader(
            train_loader, [device]).per_device_loader(device)
        for batch in para_train_loader:
            # The last element is the target; the rest are model inputs.
            data, targets = batch[:-1], batch[-1]

            output = model(data)
            loss = loss_fn(output, targets)

            for optimizer in optimizers:
                optimizer.zero_grad()
            loss.backward()
            # xm.optimizer_step all-reduces this optimizer's gradients across
            # replicas before calling optimizer.step().
            for optimizer in optimizers:
                xm.optimizer_step(optimizer)

    elapsed_train_time = time.time() - train_start
    print("Process", index, "finished training. Train time was:", elapsed_train_time)

    ## Evaluation
    model.eval()
    eval_start = time.time()
    with torch.no_grad():
        num_correct = 0
        total_guesses = 0
        para_dev_loader = pl.ParallelLoader(
            dev_loader, [device]).per_device_loader(device)
        for batch in para_dev_loader:
            data, targets = batch[:-1], batch[-1]

            output = model(data)
            best_guesses = torch.argmax(output, 1)

            # Accumulate on the device; calling .item() every batch would force
            # a device-to-host transfer per step.
            num_correct += torch.eq(targets, best_guesses).sum()
            total_guesses += targets.size(0)
        # Single host transfer (a plain int 0 if there were no batches).
        num_correct = int(num_correct)

    elapsed_eval_time = time.time() - eval_start
    print("Process", index, "finished evaluation. Evaluation time was:", elapsed_eval_time)
    if total_guesses:
        print("Process", index, "guessed", num_correct, "of", total_guesses,
              "correctly for", num_correct / total_guesses * 100, "% accuracy.")
    else:
        # Possible when the dev shard is smaller than one batch (drop_last).
        print("Process", index, "had no dev batches to evaluate.")
ZGVmIG1hcF9mbihpbmRleCwgZmxhZ3MsIHRyYWluX2RhdGFzZXQsIGRldl9kYXRhc2V0KToKICAgICMjIFNldHVwIAoKICAgICMgU2V0cyBhIGNvbW1vbiByYW5kb20gc2VlZCAtIGJvdGggZm9yIGluaXRpYWxpemF0aW9uIGFuZCBlbnN1cmluZyBncmFwaCBpcyB0aGUgc2FtZQogICAgdG9yY2gubWFudWFsX3NlZWQoZmxhZ3NbJ3NlZWQnXSkKCiAgICAjIEFjcXVpcmVzIHRoZSAodW5pcXVlKSBDbG91ZCBUUFUgY29yZSBjb3JyZXNwb25kaW5nIHRvIHRoaXMgcHJvY2VzcydzIGluZGV4CiAgICBkZXZpY2UgPSB4bS54bGFfZGV2aWNlKCkKICAgIHByaW50KCdERUlWQ0U6ICcsIGRldmljZSkgICMgPC0tLS0tIEkgQUREIFRISVMKCiAgICAKICAgICMjIyMjIyBJIFJFTU9WRSBET1dOTE9BRCBQQVJUICAjIyMjIwogICAgJycnCiAgICAjIERvd25sb2FkcyB0cmFpbiBhbmQgdGVzdCBkYXRhc2V0cwogICAgIyBOb3RlOiBtYXN0ZXIgZ29lcyBmaXJzdCBhbmQgZG93bmxvYWRzIHRoZSBkYXRhc2V0IG9ubHkgb25jZSAoeG0ucmVuZGV6dm91cykKICAgICMgICBhbGwgdGhlIG90aGVyIHdvcmtlcnMgd2FpdCBmb3IgdGhlIG1hc3RlciB0byBiZSBkb25lIGRvd25sb2FkaW5nLgoKICAgIGlmIG5vdCB4bS5pc19tYXN0ZXJfb3JkaW5hbCgpOgoJCQl4bS5yZW5kZXp2b3VzKCdkb3dubG9hZF9vbmx5X29uY2UnKQoKICAgIHRyYWluX2RhdGFzZXQgPSBkYXRhc2V0cy5GYXNoaW9uTU5JU1QoCiAgICAgICIvdG1wL2Zhc2hpb25tbmlzdCIsCiAgICAgIHRyYWluPVRydWUsCiAgICAgIGRvd25sb2FkPVRydWUsCiAgICAgIHRyYW5zZm9ybT1teV90cmFuc2Zvcm0pCgogICAgdGVzdF9kYXRhc2V0ID0gZGF0YXNldHMuRmFzaGlvbk1OSVNUKAogICAgICAiL3RtcC9mYXNoaW9ubW5pc3QiLAogICAgICB0cmFpbj1GYWxzZSwKICAgICAgZG93bmxvYWQ9VHJ1ZSwKICAgICAgdHJhbnNmb3JtPW15X3RyYW5zZm9ybSkKICAKICAgIGlmIHhtLmlzX21hc3Rlcl9vcmRpbmFsKCk6CiAgICAgIHhtLnJlbmRlenZvdXMoJ2Rvd25sb2FkX29ubHlfb25jZScpCiAgICAnJycKICAgIAogICAgCiAgICAjIyMgVEhJUyBQQVJUIEkgRE9OVCBDSEFOR0UgQU5ZVEhJTkcgZXhjZXB0IHRlc3RfLi4uIHRvIGRldl8uLi4KICAgICMjIERhdGFsb2FkZXIgY29uc3RydWN0aW9uCiAgCiAgICAjIENyZWF0ZXMgdGhlIChkaXN0cmlidXRlZCkgdHJhaW4gc2FtcGxlciwgd2hpY2ggbGV0IHRoaXMgcHJvY2VzcyBvbmx5IGFjY2VzcwogICAgIyBpdHMgcG9ydGlvbiBvZiB0aGUgdHJhaW5pbmcgZGF0YXNldC4KICAgIHRyYWluX3NhbXBsZXIgPSB0b3JjaC51dGlscy5kYXRhLmRpc3RyaWJ1dGVkLkRpc3RyaWJ1dGVkU2FtcGxlcigKICAgICAgICB0cmFpbl9kYXRhc2V0LAogICAgICAgIG51bV9yZXBsaWNhcz14bS54cnRfd29ybGRfc2l6ZSgpLAogICAgICAgIHJhbms9eG0uZ2V0X29yZGluYWwoKSwKICAgICAgICBzaHVmZmxlPVRydWUpCiAgCiAgICBkZXZfc2FtcGxlciA9IHRvcmNoLnV0
aWxzLmRhdGEuZGlzdHJpYnV0ZWQuRGlzdHJpYnV0ZWRTYW1wbGVyKAogICAgICAgIGRldl9kYXRhc2V0LAogICAgICAgIG51bV9yZXBsaWNhcz14bS54cnRfd29ybGRfc2l6ZSgpLAogICAgICAgIHJhbms9eG0uZ2V0X29yZGluYWwoKSwKICAgICAgICBzaHVmZmxlPUZhbHNlKQogIAogICMgQ3JlYXRlcyBkYXRhbG9hZGVycywgd2hpY2ggbG9hZCBkYXRhIGluIGJhdGNoZXMKICAjIE5vdGU6IGRldiBsb2FkZXIgaXMgbm90IHNodWZmbGVkIG9yIHNhbXBsZWQKICAgIHRyYWluX2xvYWRlciA9IHRvcmNoLnV0aWxzLmRhdGEuRGF0YUxvYWRlcigKICAgICAgICB0cmFpbl9kYXRhc2V0LAogICAgICAgIGJhdGNoX3NpemU9ZmxhZ3NbJ2JhdGNoX3NpemUnXSwKICAgICAgICBzYW1wbGVyPXRyYWluX3NhbXBsZXIsCiAgICAgICAgbnVtX3dvcmtlcnM9ZmxhZ3NbJ251bV93b3JrZXJzJ10sCiAgICAgICAgZHJvcF9sYXN0PVRydWUpCgogICAgZGV2X2xvYWRlciA9IHRvcmNoLnV0aWxzLmRhdGEuRGF0YUxvYWRlcigKICAgICAgICBkZXZfZGF0YXNldCwKICAgICAgICBiYXRjaF9zaXplPWZsYWdzWydiYXRjaF9zaXplJ10sCiAgICAgICAgc2FtcGxlcj1kZXZfc2FtcGxlciwKICAgICAgICBzaHVmZmxlPUZhbHNlLAogICAgICAgIG51bV93b3JrZXJzPWZsYWdzWydudW1fd29ya2VycyddLAogICAgICAgIGRyb3BfbGFzdD1UcnVlKQogIAoJCQogICAgIyMjIyMgSSBDSEFOR0VEIFRPIE1ZIENPREUKICAgICMjIE5ldHdvcmssIG9wdGltaXplciwgYW5kIGxvc3MgZnVuY3Rpb24gY3JlYXRpb24KCiAgICBtb2RlbCA9IE1BSU5Nb2RlbChmbGFncykudG8oZGV2aWNlKSAgIDwtLS0tLSBDSEFOR0VEIFRPIE1ZIE1PREVMCiAgICBtb2RlbC50cmFpbigpICA8LS0tLS0gQ0hBTkdFRCBUTyBNWSBNT0RFTAoKICAgIGxvc3NfZm4gPSBubi5Dcm9zc0VudHJvcHlMb3NzKCkKCiAgICBpZiBUcnVlOgogICAgICAgIGJlcnRfbW9kZWxfMV9wYXJhbXMgPSBtb2RlbC5iZXJ0X21vZGVsXzEucGFyYW1ldGVycygpCiAgICAgICAgYmVydF9tb2RlbF8xX29wdGltaXplciA9IEFkYW1XKGJlcnRfbW9kZWxfMV9wYXJhbXMsLi4uKQogICAgICAgIAogICAgaWYgVHJ1ZToKICAgICAgICBiZXJ0X21vZGVsXzJfcGFyYW1zID0gbW9kZWwuYmVydF9tb2RlbF8yLnBhcmFtZXRlcnMoKQogICAgICAgIGJlcnRfbW9kZWxfMl9vcHRpbWl6ZXIgPSBBZGFtVyhiZXJ0X21vZGVsXzJfcGFyYW1zLC4uLikKICAgIAogICAgIyMgbGluZWFyCiAgICBsaW5lYXJfcGFyYW1zID0gbW9kZWwubGluZWFyLnBhcmFtZXRlcnMoKQogICAgbGluZWFyX29wdGltaXplciA9IHRvcmNoLm9wdGltLlNHRChsaW5lYXJfcGFyYW1zLC4uLikKCiAgCgoKICAgICMjIFRyYWlucwogICAgdHJhaW5fc3RhcnQgPSB0aW1lLnRpbWUoKQogICAgZm9yIGVwb2NoIGluIHJhbmdlKGZsYWdzWyd0b3RhbF9lcG9jaHMnXSk6CiAgICAgICAgcGFyYV90cmFpbl9sb2FkZXIgPSBwbC5QYXJhbGxlbExvYWRlcih0cmFpbl9s
b2FkZXIsIFtkZXZpY2VdKS5wZXJfZGV2aWNlX2xvYWRlcihkZXZpY2UpCiAgICAgICAgZm9yIGJhdGNoX251bSwgYmF0Y2ggaW4gZW51bWVyYXRlKHBhcmFfdHJhaW5fbG9hZGVyKToKCiAgICAgICAgICAgIGRhdGEsIHRhcmdldHMgPSBiYXRjaFs6LTFdLCBiYXRjaFstMV0gICAgPC0tLS0tIE1ZIENPREUKCiAgICAgICAgICAgICMgQWNxdWlyZXMgdGhlIG5ldHdvcmsncyBiZXN0IGd1ZXNzZXMgYXQgZWFjaCBjbGFzcwogICAgICAgICAgICBvdXRwdXQgPSBtb2RlbChkYXRhKQoKICAgICAgICAgICAgIyBDb21wdXRlcyBsb3NzCiAgICAgICAgICAgIGxvc3MgPSBsb3NzX2ZuKG91dHB1dCwgdGFyZ2V0cykKCiAgICAgICAgICAgICMgVXBkYXRlcyBtb2RlbAogICAgICAgICAgICBpZiBUcnVlOgogICAgICAgICAgICAgICAgYmVydF9tb2RlbF8xLnplcm9fZ3JhZCgpCiAgICAgICAgCiAgICAgICAgICAgIGlmIFRydWU6CiAgICAgICAgICAgICAgICBiZXJ0X21vZGVsXzIuemVyb19ncmFkKCkKCiAgICAgICAgICAgIGxpbmVhcl9vcHRpbWl6ZXIuemVyb19ncmFkKCkKCiAgICAgICAgICAgIGxvc3MuYmFja3dhcmQoKQoKCiAgICAgICAgICAgICMgSSBIQVZFIFRIUkVFIE9QVElNSVpFUiBJTlNURUFEIE9GIDEKICAgICAgICAgICAgaWYgVHJ1ZToKICAgICAgICAgICAgICAgIHhtLm9wdGltaXplcl9zdGVwKGJlcnRfbW9kZWxfMV9vcHRpbWl6ZXIpCiAgICAgICAgICAgICAgICAgICAKICAgICAgICAgICAgaWYgVHJ1ZToKICAgICAgICAgICAgICAgIHhtLm9wdGltaXplcl9zdGVwKGJlcnRfbW9kZWxfMl9vcHRpbWl6ZXIpCiAgICAgICAgICAgIAogICAgICAgICAgICB4bS5vcHRpbWl6ZXJfc3RlcChsaW5lYXJfb3B0aW1pemVyKQogICAgICAgICAgICAKCQkKICAgICMgQmVsb3cgY29kZSBoYXMgbm8gaW1wb3J0IGNoYW5nZQoKICAgIGVsYXBzZWRfdHJhaW5fdGltZSA9IHRpbWUudGltZSgpIC0gdHJhaW5fc3RhcnQKICAgIHByaW50KCJQcm9jZXNzIiwgaW5kZXgsICJmaW5pc2hlZCB0cmFpbmluZy4gVHJhaW4gdGltZSB3YXM6IiwgZWxhcHNlZF90cmFpbl90aW1lKSAKCgkJCiAgICAjIyBFdmFsdWF0aW9uCiAgICAjIFNldHMgbmV0IHRvIGV2YWwgYW5kIG5vIGdyYWQgY29udGV4dCAKICAgIG1vZGVsLmV2YWwoKQogICAgZXZhbF9zdGFydCA9IHRpbWUudGltZSgpCiAgICB3aXRoIHRvcmNoLm5vX2dyYWQoKToKICAgICAgICBudW1fY29ycmVjdCA9IDAKICAgICAgICB0b3RhbF9ndWVzc2VzID0gMAoKICAgICAgICBwYXJhX3RyYWluX2xvYWRlciA9IHBsLlBhcmFsbGVsTG9hZGVyKGRldl9sb2FkZXIsIFtkZXZpY2VdKS5wZXJfZGV2aWNlX2xvYWRlcihkZXZpY2UpCiAgICAgICAgZm9yIGJhdGNoX251bSwgYmF0Y2ggaW4gZW51bWVyYXRlKHBhcmFfdHJhaW5fbG9hZGVyKToKICAgICAgICAgICAgZGF0YSwgdGFyZ2V0cyA9IGJhdGNoWzotMV0sIGJhdGNoWy0xXQoKICAgICAgICAgICAgIyBBY3F1aXJlcyB0aGUgbmV0d29yaydzIGJlc3Qg
Z3Vlc3NlcyBhdCBlYWNoIGNsYXNzCiAgICAgICAgICAgIG91dHB1dCA9IHJlY19tb2RlbChkYXRhKQogICAgICAgICAgICBiZXN0X2d1ZXNzZXMgPSB0b3JjaC5hcmdtYXgob3V0cHV0LCAxKQoKICAgICAgICAgICAgIyBVcGRhdGVzIHJ1bm5pbmcgc3RhdGlzdGljcwogICAgICAgICAgICBudW1fY29ycmVjdCArPSB0b3JjaC5lcSh0YXJnZXRzLCBiZXN0X2d1ZXNzZXMpLnN1bSgpLml0ZW0oKQogICAgICAgICAgICB0b3RhbF9ndWVzc2VzICs9IGZsYWdzWydiYXRjaF9zaXplJ10KICAKICAgIGVsYXBzZWRfZXZhbF90aW1lID0gdGltZS50aW1lKCkgLSBldmFsX3N0YXJ0CiAgICBwcmludCgiUHJvY2VzcyIsIGluZGV4LCAiZmluaXNoZWQgZXZhbHVhdGlvbi4gRXZhbHVhdGlvbiB0aW1lIHdhczoiLCBlbGFwc2VkX2V2YWxfdGltZSkKICAgIHByaW50KCJQcm9jZXNzIiwgaW5kZXgsICJndWVzc2VkIiwgbnVtX2NvcnJlY3QsICJvZiIsIHRvdGFsX2d1ZXNzZXMsICJjb3JyZWN0bHkgZm9yIiwgbnVtX2NvcnJlY3QvdG90YWxfZ3Vlc3NlcyAqIDEwMCwgIiUgYWNjdXJhY3kuIik=