# Seed torch's global RNG so weight initialisation of the models created
# later in this script is reproducible from run to run.
torch.manual_seed(131)

# Device that the models and per-image tensors are placed on.
cuda = torch.device("cuda")
class MODEL(nn.Module):
    """Small fully-convolutional net mapping a 1-channel map to a 1-channel map.

    Every convolution and pooling step is preceded by a 1-pixel constant pad,
    and every conv/pool uses a 3x3 window with stride 1, so the output has the
    same spatial size as the input. The last step is ``tanh``, so outputs lie
    in the open interval (-1, 1).
    """

    def __init__(self, pad_value=255.0):
        """Create the layers.

        Args:
            pad_value: constant written into the 1-pixel border added before
                every conv/pool layer. Defaults to 255.0, the original value.
        """
        super().__init__()
        # Layers are created in this fixed order so seeded initialisation
        # stays reproducible.
        self.conv1 = nn.Conv2d(1, 8, 3)
        self.conv2 = nn.Conv2d(8, 16, 3)
        self.conv3 = nn.Conv2d(16, 32, 3)
        self.conv4 = nn.Conv2d(32, 16, 3)
        self.conv5 = nn.Conv2d(16, 1, 3)

        self.activ = FF.relu
        # NOTE(review): 255.0 looks like an 8-bit "white" value, but the
        # script appears to feed the Lab L channel (roughly 0..100); confirm
        # this border value is intended.
        self.padding = nn.ConstantPad2d(1, pad_value)
        # 3x3 mean filter with stride 1: smooths without downsampling (the
        # preceding 1-pixel pad keeps H and W unchanged).
        self.pool = nn.AvgPool2d(3, stride=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: 1-channel input, e.g. shape (N, 1, H, W); conv1 takes exactly
               one input channel.

        Returns:
            Tensor with the same spatial size and one channel, values in (-1, 1).
        """
        x = self.activ(self.conv1(self.padding(x)))
        x = self.activ(self.conv2(self.padding(x)))
        x = self.pool(self.padding(x))
        x = self.activ(self.conv3(self.padding(x)))
        x = self.activ(self.conv4(self.padding(x)))
        x = self.pool(self.padding(x))
        # No ReLU before tanh: relu would clamp to >= 0, making tanh's output
        # lie in [0, 1) and negative predictions impossible.
        return torch.tanh(self.conv5(self.padding(x)))
# Hyper-parameters shared by both networks.
learning_rate = 0.01
loss_algo = nn.MSELoss()

# Two independent networks; nn.Module.to() moves parameters in place and
# returns the module itself, so construction and placement can be chained.
# model_a is built before model_b, so seeded initialisation order is unchanged.
model_a = MODEL().to(cuda)
model_b = MODEL().to(cuda)

# One RMSprop optimizer per network, identical settings.
optimizer_A, optimizer_B = (
    optim.RMSprop(net.parameters(), lr=learning_rate, momentum=0.9)
    for net in (model_a, model_b)
)
# Number of images to train on per epoch before stopping early.
MAX_IMAGES = 600


def _channel_to_tensor(channel):
    """Return a 2-D numpy array as a (1, 1, H, W) float32 tensor on ``cuda``."""
    return torch.tensor(channel, dtype=torch.float32, device=cuda).view(
        1, 1, *channel.shape
    )


# Train two networks on the images in ``root_dir``. Each image is converted to
# Lab; both networks receive only the L channel and each learns one chroma
# channel (model_a -> a, model_b -> b). Network outputs are scaled by 128
# before the MSE loss so they are compared with the raw Lab chroma values.
for epoch in range(1):
    i = 0
    # os.listdir order is arbitrary; sorting makes runs reproducible
    # together with the fixed torch seed.
    for data in sorted(os.listdir(root_dir)):
        image = io.imread(os.path.join(root_dir, data))
        if image.ndim == 3 and image.shape[2] == 4:
            # rgb2lab needs exactly 3 channels; composite RGBA onto white.
            image = skimage.color.rgba2rgb(image)
        if image.ndim != 3 or image.shape[2] != 3:
            # Grayscale / non-RGB images have no colour target to learn.
            continue
        image = skimage.color.rgb2lab(image)
        xg = image[:, :, 0]  # L channel: input to both networks
        xa = image[:, :, 1]  # a channel: target of model_a
        xb = image[:, :, 2]  # b channel: target of model_b

        xg_tensor = _channel_to_tensor(xg)
        xa_tensor = _channel_to_tensor(xa)
        xb_tensor = _channel_to_tensor(xb)

        optimizer_A.zero_grad()
        output_1 = model_a(xg_tensor)
        loss_a = loss_algo(torch.mul(output_1, 128.0), xa_tensor)
        loss_a.backward()
        optimizer_A.step()

        optimizer_B.zero_grad()
        # model_b is conditioned on L like model_a; feeding it xb_tensor would
        # let it simply copy its own target.
        output_2 = model_b(xg_tensor)
        loss_b = loss_algo(torch.mul(output_2, 128.0), xb_tensor)
        loss_b.backward()
        optimizer_B.step()

        # .item() prints the scalar loss instead of the full tensor repr.
        print("current loss = " + str(loss_a.item()) + " " + str(loss_b.item()))
        i += 1
        if i == MAX_IMAGES:
            break
CnRvcmNoLm1hbnVhbF9zZWVkKDEzMSkKY3VkYSA9IHRvcmNoLmRldmljZSgnY3VkYScpCgpjbGFzcyBNT0RFTChubi5Nb2R1bGUpOgoKCWRlZiBfX2luaXRfXyhzZWxmKToKCQlzdXBlcihNT0RFTCwgc2VsZikuX19pbml0X18oKQoJCXNlbGYuY29udjEgPSBubi5Db252MmQoMSwgOCwgMykKCQlzZWxmLmNvbnYyID0gbm4uQ29udjJkKDgsIDE2LCAzKQoJCXNlbGYuY29udjMgPSBubi5Db252MmQoMTYsIDMyLCAzKQoJCXNlbGYuY29udjQgPSBubi5Db252MmQoMzIsIDE2LCAzKQoJCXNlbGYuY29udjUgPSBubi5Db252MmQoMTYsIDEsIDMpCgoJCQoJCXNlbGYuYWN0aXYgPSBGRi5yZWx1CgkJc2VsZi5wYWRkaW5nID0gbm4uQ29uc3RhbnRQYWQyZCgxLCAyNTUuMCkKCQlzZWxmLnBvb2wgPSBubi5BdmdQb29sMmQoMywgc3RyaWRlPTEpCgoKCWRlZiBmb3J3YXJkKHNlbGYsIHgpOgoJCXggPSBzZWxmLmFjdGl2KCBzZWxmLmNvbnYxKHNlbGYucGFkZGluZyh4KSApICkKCQl4ID0gc2VsZi5hY3RpdiggKCBzZWxmLmNvbnYyKCBzZWxmLnBhZGRpbmcoeCkgKSApICkKCQl4ID0gc2VsZi5wb29sKHNlbGYucGFkZGluZyh4KSkKCQl4ID0gc2VsZi5hY3RpdiggKCBzZWxmLmNvbnYzKCBzZWxmLnBhZGRpbmcoeCkgKSApICkKCQl4ID0gc2VsZi5hY3RpdiggKCBzZWxmLmNvbnY0KCBzZWxmLnBhZGRpbmcoeCkgKSApICkKCQl4ID0gc2VsZi5wb29sKCBzZWxmLnBhZGRpbmcoeCkgKQoJCXggPSBzZWxmLmFjdGl2KCAoIHNlbGYuY29udjUoIHNlbGYucGFkZGluZyh4KSApICkgKQoJCXggPSB0b3JjaC50YW5oKHgpCgoJCXJldHVybiB4CgoKbW9kZWxfYSA9IE1PREVMKCkKbW9kZWxfYiA9IE1PREVMKCkKbW9kZWxfYS50byhjdWRhKQptb2RlbF9iLnRvKGN1ZGEpCgpsZWFybmluZ19yYXRlID0gMC4wMQoKbG9zc19hbGdvID0gbm4uTVNFTG9zcygpCm9wdGltaXplcl9BID0gb3B0aW0uUk1TcHJvcChtb2RlbF9hLnBhcmFtZXRlcnMoKSwgbHIgPSBsZWFybmluZ19yYXRlLCBtb21lbnR1bT0wLjkpCm9wdGltaXplcl9CID0gb3B0aW0uUk1TcHJvcChtb2RlbF9iLnBhcmFtZXRlcnMoKSwgbHIgPSBsZWFybmluZ19yYXRlLCBtb21lbnR1bT0wLjkpCgpmb3IgZXBvY2ggaW4gcmFuZ2UoMSk6CglpID0gMAoJZm9yIGRhdGEgaW4gb3MubGlzdGRpcihyb290X2Rpcik6CgkJaW1hZ2UgPSBpby5pbXJlYWQob3MucGF0aC5qb2luKHJvb3RfZGlyLCBkYXRhKSkKCQlvcmlnID0gaW1hZ2UuY29weSgpCgkJaW1hZ2UgPSBza2ltYWdlLmNvbG9yLnJnYjJsYWIoaW1hZ2UpCgkJeGcgPSBpbWFnZVs6LDosMF0KCQl4YSA9IGltYWdlWzosOiwxXQoJCXhiID0gaW1hZ2VbOiw6LDJdCgkJIyBpbWFnZS5wZXJtdXRlKDIsIDAsIDEpCgoJCXhnX3RlbnNvciA9IHRvcmNoLnRlbnNvcih4ZywgZGV2aWNlPWN1ZGEpCgkJeGFfdGVuc29yID0gdG9yY2gudGVuc29yKHhhLCBkZXZpY2U9Y3VkYSkKCQl4Yl90ZW5zb3IgPSB0b3JjaC50ZW5zb3IoeGIsIGRldmljZT1jdWRhKQoKCQl4Z190
ZW5zb3IgPSB4Z190ZW5zb3IudmlldygxLDEseGdfdGVuc29yLnNpemUoKVswXSwgeGdfdGVuc29yLnNpemUoKVsxXSkKCQl4YV90ZW5zb3IgPSB4YV90ZW5zb3IudmlldygxLDEseGFfdGVuc29yLnNpemUoKVswXSx4YV90ZW5zb3Iuc2l6ZSgpWzFdKQoJCXhiX3RlbnNvciA9IHhiX3RlbnNvci52aWV3KDEsMSx4Yl90ZW5zb3Iuc2l6ZSgpWzBdLHhiX3RlbnNvci5zaXplKClbMV0pCgoJCXhnX3RlbnNvciA9IHhnX3RlbnNvci5mbG9hdCgpCgkJeGFfdGVuc29yID0geGFfdGVuc29yLmZsb2F0KCkKCQl4Yl90ZW5zb3IgPSB4Yl90ZW5zb3IuZmxvYXQoKQoKCQlvcHRpbWl6ZXJfQS56ZXJvX2dyYWQoKQoJCW91dHB1dF8xID0gbW9kZWxfYSh4Z190ZW5zb3IpCgkJbG9zc19hID0gbG9zc19hbGdvKHRvcmNoLm11bChvdXRwdXRfMSwxMjguMCksIHhhX3RlbnNvcikKCQlsb3NzX2EuYmFja3dhcmQoKQoJCW9wdGltaXplcl9BLnN0ZXAoKQoKCQlvcHRpbWl6ZXJfQi56ZXJvX2dyYWQoKQoJCW91dHB1dF8yID0gbW9kZWxfYih4Yl90ZW5zb3IpCgkJbG9zc19iID0gbG9zc19hbGdvKHRvcmNoLm11bChvdXRwdXRfMiwxMjguMCksIHhiX3RlbnNvcikKCQlsb3NzX2IuYmFja3dhcmQoKQoJCW9wdGltaXplcl9CLnN0ZXAoKQoKCQkjIHByaW50KCJjdXJyZW50IGxvc3MgPSAiICsgc3RyKGxvc3NfYSkpCgkJcHJpbnQoImN1cnJlbnQgbG9zcyA9ICIgKyBzdHIobG9zc19hKSArICIgICIgKyBzdHIobG9zc19iKSkKCQlpICs9IDEKCQkjIGlmKGkgPT0gMiBvciBpID09IDEwMDAgb3IgVHJ1ZSk6CgkJIyAJIyBvdXRwdXQxID0gb3V0cHV0XzEudmlldygyNTYsMjU2KQoJCSMgCSMgb3V0cHV0MiA9IG91dHB1dF8yLnZpZXcoMjU2LDI1NikKCQkjIAkjIGEgPSBvdXRwdXQxLmRldGFjaCgpLmNwdSgpLm51bXB5KCkKCQkjIAkjIGIgPSBvdXRwdXQyLmRldGFjaCgpLmNwdSgpLm51bXB5KCkKCQkjIAkjIG91dF9pbWcgPSBpbWFnZS5jb3B5KCkKCQkjIAkjIG91dF9pbWdbOiw6LDFdID0gYS5jb3B5KCkgKiAxMjguMAoJCSMgCSMgb3V0X2ltZ1s6LDosMl0gPSBiLmNvcHkoKSAqIDEyOC4wCgkJIyAJIyBwcmludChvdXRwdXQxKQoJCSMgCSMgcHJpbnQoc2tpbWFnZS5jb2xvci5yZ2IybGFiKG9yaWcpWzosIDosIDFdICkKCQkjIAkjIHBsb3QuaW1zaG93KG9yaWcpCgkJIyAJIyBwbG90LnNob3coKQoJCSMgCSMgcGxvdC5pbXNob3coc2tpbWFnZS5jb2xvci5sYWIycmdiKG91dF9pbWcpKQoJCSMgCSMgcGxvdC5zaG93KCkKCQkjIAlpZiBpID09IDEwMDA6CgkJIyAJCWJyZWFrCgkJaWYgaSA9PSA2MDA6CgkJCWJyZWFr