fork download
  1.  
  2. torch.manual_seed(131)
  3. cuda = torch.device('cuda')
  4.  
  5. class MODEL(nn.Module):
  6.  
  7. def __init__(self):
  8. super(MODEL, self).__init__()
  9. self.conv1 = nn.Conv2d(1, 8, 3)
  10. self.conv2 = nn.Conv2d(8, 16, 3)
  11. self.conv3 = nn.Conv2d(16, 32, 3)
  12. self.conv4 = nn.Conv2d(32, 16, 3)
  13. self.conv5 = nn.Conv2d(16, 1, 3)
  14.  
  15.  
  16. self.activ = FF.relu
  17. self.padding = nn.ConstantPad2d(1, 255.0)
  18. self.pool = nn.AvgPool2d(3, stride=1)
  19.  
  20.  
  21. def forward(self, x):
  22. x = self.activ( self.conv1(self.padding(x) ) )
  23. x = self.activ( ( self.conv2( self.padding(x) ) ) )
  24. x = self.pool(self.padding(x))
  25. x = self.activ( ( self.conv3( self.padding(x) ) ) )
  26. x = self.activ( ( self.conv4( self.padding(x) ) ) )
  27. x = self.pool( self.padding(x) )
  28. x = self.activ( ( self.conv5( self.padding(x) ) ) )
  29. x = torch.tanh(x)
  30.  
  31. return x
  32.  
  33.  
  34. model_a = MODEL()
  35. model_b = MODEL()
  36. model_a.to(cuda)
  37. model_b.to(cuda)
  38.  
  39. learning_rate = 0.01
  40.  
  41. loss_algo = nn.MSELoss()
  42. optimizer_A = optim.RMSprop(model_a.parameters(), lr = learning_rate, momentum=0.9)
  43. optimizer_B = optim.RMSprop(model_b.parameters(), lr = learning_rate, momentum=0.9)
  44.  
  45. for epoch in range(1):
  46. i = 0
  47. for data in os.listdir(root_dir):
  48. image = io.imread(os.path.join(root_dir, data))
  49. orig = image.copy()
  50. image = skimage.color.rgb2lab(image)
  51. xg = image[:,:,0]
  52. xa = image[:,:,1]
  53. xb = image[:,:,2]
  54. # image.permute(2, 0, 1)
  55.  
  56. xg_tensor = torch.tensor(xg, device=cuda)
  57. xa_tensor = torch.tensor(xa, device=cuda)
  58. xb_tensor = torch.tensor(xb, device=cuda)
  59.  
  60. xg_tensor = xg_tensor.view(1,1,xg_tensor.size()[0], xg_tensor.size()[1])
  61. xa_tensor = xa_tensor.view(1,1,xa_tensor.size()[0],xa_tensor.size()[1])
  62. xb_tensor = xb_tensor.view(1,1,xb_tensor.size()[0],xb_tensor.size()[1])
  63.  
  64. xg_tensor = xg_tensor.float()
  65. xa_tensor = xa_tensor.float()
  66. xb_tensor = xb_tensor.float()
  67.  
  68. optimizer_A.zero_grad()
  69. output_1 = model_a(xg_tensor)
  70. loss_a = loss_algo(torch.mul(output_1,128.0), xa_tensor)
  71. loss_a.backward()
  72. optimizer_A.step()
  73.  
  74. optimizer_B.zero_grad()
  75. output_2 = model_b(xb_tensor)
  76. loss_b = loss_algo(torch.mul(output_2,128.0), xb_tensor)
  77. loss_b.backward()
  78. optimizer_B.step()
  79.  
  80. # print("current loss = " + str(loss_a))
  81. print("current loss = " + str(loss_a) + " " + str(loss_b))
  82. i += 1
  83. # if(i == 2 or i == 1000 or True):
  84. # # output1 = output_1.view(256,256)
  85. # # output2 = output_2.view(256,256)
  86. # # a = output1.detach().cpu().numpy()
  87. # # b = output2.detach().cpu().numpy()
  88. # # out_img = image.copy()
  89. # # out_img[:,:,1] = a.copy() * 128.0
  90. # # out_img[:,:,2] = b.copy() * 128.0
  91. # # print(output1)
  92. # # print(skimage.color.rgb2lab(orig)[:, :, 1] )
  93. # # plot.imshow(orig)
  94. # # plot.show()
  95. # # plot.imshow(skimage.color.lab2rgb(out_img))
  96. # # plot.show()
  97. # if i == 1000:
  98. # break
  99. if i == 600:
  100. break
Runtime error #stdin #stdout #stderr 0.02s 9280KB
stdin
Standard input is empty
stdout
Standard output is empty
stderr
Traceback (most recent call last):
  File "./prog.py", line 2, in <module>
NameError: name 'torch' is not defined