class Block(nn.Module):
    """Residual block: 3x3 conv -> BN -> ReLU -> 1x1 conv -> BN, plus a shortcut.

    The shortcut is the block input, optionally passed through
    ``identity_downsample`` so that it matches the main branch before the
    element-wise addition. A final ReLU is applied to the sum.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the 3x3 convolution; the
            block output has ``out_channels * expansion`` channels.
        identity_downsample: Optional module applied to the shortcut. When it
            is ``None`` the raw input is added, so the input must already have
            the same shape as the main-branch output.
        stride: Stride of the 3x3 convolution.
    """

    # Channel multiplier of the block output. Declared on the class (not only
    # on instances) so that code holding just the class can read it;
    # ``self.expansion`` keeps working through normal attribute lookup.
    expansion = 1

    def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):
        super().__init__()
        # NOTE(review): conv1/bn1 are registered but never called in
        # ``forward``. They are kept so existing ``state_dict`` keys and the
        # seeded initialisation order do not change; their parameters never
        # receive gradients.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(
            out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0
        )
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        shortcut = x

        out = self.relu(self.bn2(self.conv2(x)))
        out = self.bn3(self.conv3(out))

        if self.identity_downsample is not None:
            shortcut = self.identity_downsample(shortcut)

        out += shortcut
        return self.relu(out)
class ResNet18(nn.Module):
    """ResNet-18-style classifier conditioned on an integer label.

    The label is embedded, broadcast over the spatial dimensions of the image
    and concatenated to it as extra input channels before the stem
    convolution. Four stages of two residual blocks each follow, then global
    average pooling and a linear classifier.

    Args:
        block: Callable building one residual block with the signature
            ``block(in_channels, out_channels, identity_downsample=None, stride=1)``.
            If it exposes a class-level ``expansion`` attribute, that value is
            used as the channel multiplier of its output; otherwise 1 is
            assumed.
        image_channels: Number of channels of the input images.
        num_classes: Size of the output logits vector.
        num_labels: Number of distinct conditioning labels.
        embedding_dim: Size of the label embedding, i.e. the number of extra
            input channels added to the image.
    """

    def __init__(self, block, image_channels=3, num_classes=1000,
                 num_labels=42, embedding_dim=42):
        super().__init__()

        # Output-channel multiplier of ``block``. Read from the class because
        # no block instance exists yet; defaults to 1 when it is not declared.
        self.expansion = getattr(block, "expansion", 1)

        self.label_embedding = nn.Embedding(num_labels, embedding_dim)

        # The stem sees the image channels plus one channel per embedding
        # dimension (the label map concatenated in ``forward``).
        self.in_channels = 64
        self.conv1 = nn.Conv2d(
            image_channels + embedding_dim, 64, kernel_size=7, stride=2, padding=3
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages: two blocks each, doubling the width from 64 to 512.
        self.layer1 = self.make_layers(block, 2, intermediate_channels=64, stride=1)
        self.layer2 = self.make_layers(block, 2, intermediate_channels=128, stride=2)
        self.layer3 = self.make_layers(block, 2, intermediate_channels=256, stride=2)
        self.layer4 = self.make_layers(block, 2, intermediate_channels=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * self.expansion, num_classes)

    def forward(self, x, labels):
        """Compute class logits for a batch of images and their labels.

        Args:
            x: Image batch of shape ``(B, image_channels, H, W)``.
            labels: Integer tensor holding one label index per image
                (``B`` elements in total, any shape).

        Returns:
            Tensor of shape ``(B, num_classes)``.
        """
        # (B, E) -> (B, E, 1, 1) -> (B, E, H, W) so the embedding can be
        # stacked onto the image along the channel dimension.
        c = self.label_embedding(labels.reshape(-1))
        c = c[:, :, None, None].expand(-1, -1, x.shape[2], x.shape[3])
        x = torch.cat([x, c], 1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x

    def make_layers(self, block, num_residual_blocks, intermediate_channels, stride):
        """Build one stage of ``num_residual_blocks`` residual blocks.

        The first block applies ``stride`` and receives a 1x1-conv + BatchNorm
        projection for its shortcut; the remaining blocks keep the resolution
        and use a plain shortcut. ``self.in_channels`` is updated to the
        stage's output width so that the next stage chains correctly.

        Returns:
            ``nn.Sequential`` containing the blocks of the stage.
        """
        stage_channels = intermediate_channels * self.expansion

        # Projection shortcut matching the first block's output channels and
        # spatial size.
        identity_downsample = nn.Sequential(
            nn.Conv2d(self.in_channels, stage_channels, kernel_size=1, stride=stride),
            nn.BatchNorm2d(stage_channels),
        )
        layers = [block(self.in_channels, intermediate_channels, identity_downsample, stride)]

        # Every following block consumes the stage's output width.
        self.in_channels = stage_channels
        layers.extend(
            block(self.in_channels, intermediate_channels)
            for _ in range(num_residual_blocks - 1)
        )
        return nn.Sequential(*layers)
Y2xhc3MgQmxvY2sobm4uTW9kdWxlKToKICAgIGRlZiBfX2luaXRfXyhzZWxmLCBpbl9jaGFubmVscywgb3V0X2NoYW5uZWxzLCBpZGVudGl0eV9kb3duc2FtcGxlPU5vbmUsIHN0cmlkZT0xKToKICAgICAgICBzdXBlcihCbG9jaywgc2VsZikuX19pbml0X18oKQogICAgICAgIHNlbGYuZXhwYW5zaW9uID0gMQogICAgICAgIHNlbGYuY29udjEgPSBubi5Db252MmQoaW5fY2hhbm5lbHMsIG91dF9jaGFubmVscywga2VybmVsX3NpemU9MSwgc3RyaWRlPTEsIHBhZGRpbmc9MCkKICAgICAgICBzZWxmLmJuMSA9IG5uLkJhdGNoTm9ybTJkKG91dF9jaGFubmVscykKICAgICAgICBzZWxmLmNvbnYyID0gbm4uQ29udjJkKGluX2NoYW5uZWxzLCBvdXRfY2hhbm5lbHMsIGtlcm5lbF9zaXplPTMsIHN0cmlkZT1zdHJpZGUsIHBhZGRpbmc9MSkKICAgICAgICBzZWxmLmJuMiA9IG5uLkJhdGNoTm9ybTJkKG91dF9jaGFubmVscykKICAgICAgICBzZWxmLmNvbnYzID0gbm4uQ29udjJkKG91dF9jaGFubmVscywgb3V0X2NoYW5uZWxzICogc2VsZi5leHBhbnNpb24sIGtlcm5lbF9zaXplPTEsIHN0cmlkZT0xLCBwYWRkaW5nPTApCiAgICAgICAgc2VsZi5ibjMgPSBubi5CYXRjaE5vcm0yZChvdXRfY2hhbm5lbHMgKiBzZWxmLmV4cGFuc2lvbikKICAgICAgICBzZWxmLnJlbHUgPSBubi5SZUxVKCkKICAgICAgICBzZWxmLmlkZW50aXR5X2Rvd25zYW1wbGUgPSBpZGVudGl0eV9kb3duc2FtcGxlCgogICAgZGVmIGZvcndhcmQoc2VsZiwgeCk6CiAgICAgICAgaWRlbnRpdHkgPSB4CgogICAgICAgIHggPSBzZWxmLmNvbnYyKHgpCiAgICAgICAgeCA9IHNlbGYuYm4yKHgpCiAgICAgICAgeCA9IHNlbGYucmVsdSh4KQogICAgICAgIHggPSBzZWxmLmNvbnYzKHgpCiAgICAgICAgeCA9IHNlbGYuYm4zKHgpCgogICAgICAgIGlmIHNlbGYuaWRlbnRpdHlfZG93bnNhbXBsZSBpcyBub3QgTm9uZToKICAgICAgICAgICAgaWRlbnRpdHkgPSBzZWxmLmlkZW50aXR5X2Rvd25zYW1wbGUoaWRlbnRpdHkpCgogICAgICAgIHggKz0gaWRlbnRpdHkKICAgICAgICB4ID0gc2VsZi5yZWx1KHgpCiAgICAgICAgcmV0dXJuIHgKICAgICAgICAKICAgICAgICAKY2xhc3MgUmVzTmV0MTgobm4uTW9kdWxlKToKICAgIGRlZiBfX2luaXRfXyhzZWxmLCBibG9jayk6CgogICAgICAgIHN1cGVyKFJlc05ldDE4LCBzZWxmKS5fX2luaXRfXygpCiAgICAgICAgCiAgICAgICAgI2lucHV0X2RpbSA9IDc4NCg/Pz8/Pz8pICsgNDIgCiAgICAgICAgCiAgICAgICAgaW1hZ2VfY2hhbm5lbHM9MyAKICAgICAgICBudW1fY2xhc3Nlcz0xMDAwCiAgICAgICAgCiAgICAgICAgc2VsZi5sYWJlbF9lbWJlZGRpbmcgPSBubi5FbWJlZGRpbmcoNDIsIDQyKQogICAgIAogICAgICAgIHNlbGYuaW5fY2hhbm5lbHMgPSA2NAogICAgICAgIHNlbGYuY29udjEgPSBubi5Db252MmQoaW1hZ2VfY2hhbm5lbHMsIDY0LCBrZXJuZWxfc2l6ZT03LCBzdHJpZGU9MiwgcGFkZGluZz0zKQogICAgICAgIHNlbGYuYm4x
ID0gbm4uQmF0Y2hOb3JtMmQoNjQpCiAgICAgICAgc2VsZi5yZWx1ID0gbm4uUmVMVSgpCiAgICAgICAgc2VsZi5tYXhwb29sID0gbm4uTWF4UG9vbDJkKGtlcm5lbF9zaXplPTMsIHN0cmlkZT0yLCBwYWRkaW5nPTEpCgogICAgICAgICMgUmVzTmV0TGF5ZXJzCiAgICAgICAgc2VsZi5sYXllcjEgPSBzZWxmLm1ha2VfbGF5ZXJzKG51bV9sYXllcnMsIGJsb2NrLCAyLCBpbnRlcm1lZGlhdGVfY2hhbm5lbHM9NjQsIHN0cmlkZT0xKQogICAgICAgIHNlbGYubGF5ZXIyID0gc2VsZi5tYWtlX2xheWVycyhudW1fbGF5ZXJzLCBibG9jaywgMiwgaW50ZXJtZWRpYXRlX2NoYW5uZWxzPTEyOCwgc3RyaWRlPTIpCiAgICAgICAgc2VsZi5sYXllcjMgPSBzZWxmLm1ha2VfbGF5ZXJzKG51bV9sYXllcnMsIGJsb2NrLCAyLCBpbnRlcm1lZGlhdGVfY2hhbm5lbHM9MjU2LCBzdHJpZGU9MikKICAgICAgICBzZWxmLmxheWVyNCA9IHNlbGYubWFrZV9sYXllcnMobnVtX2xheWVycywgYmxvY2ssIDIsIGludGVybWVkaWF0ZV9jaGFubmVscz01MTIsIHN0cmlkZT0yKQoKICAgICAgICBzZWxmLmF2Z3Bvb2wgPSBubi5BZGFwdGl2ZUF2Z1Bvb2wyZCgoMSwgMSkpCiAgICAgICAgc2VsZi5mYyA9IG5uLkxpbmVhcig1MTIgKiBzZWxmLmV4cGFuc2lvbiwgbnVtX2NsYXNzZXMpCgogICAgZGVmIGZvcndhcmQoc2VsZiwgeCwgbGFiZWxzKToKICAgICAgICAKICAgICAgICBjID0gc2VsZi5sYWJlbF9lbWJlZGRpbmcobGFiZWxzKQogICAgICAgIHggPSB0b3JjaC5jYXQoW3gsIGNdLCAxKQogICAgICAgIAogICAgICAgIHggPSBzZWxmLmNvbnYxKHgpCiAgICAgICAgeCA9IHNlbGYuYm4xKHgpCiAgICAgICAgeCA9IHNlbGYucmVsdSh4KQogICAgICAgIHggPSBzZWxmLm1heHBvb2woeCkKCiAgICAgICAgeCA9IHNlbGYubGF5ZXIxKHgpCiAgICAgICAgeCA9IHNlbGYubGF5ZXIyKHgpCiAgICAgICAgeCA9IHNlbGYubGF5ZXIzKHgpCiAgICAgICAgeCA9IHNlbGYubGF5ZXI0KHgpCgogICAgICAgIHggPSBzZWxmLmF2Z3Bvb2woeCkKICAgICAgICB4ID0geC5yZXNoYXBlKHguc2hhcGVbMF0sIC0xKQogICAgICAgIHggPSBzZWxmLmZjKHgpCiAgICAgICAgcmV0dXJuIHgKCiAgICBkZWYgbWFrZV9sYXllcnMoc2VsZiwgYmxvY2ssIG51bV9yZXNpZHVhbF9ibG9ja3MsIGludGVybWVkaWF0ZV9jaGFubmVscywgc3RyaWRlKToKICAgICAgICBsYXllcnMgPSBbXQoKICAgICAgICBpZGVudGl0eV9kb3duc2FtcGxlID0gbm4uU2VxdWVudGlhbChubi5Db252MmQoc2VsZi5pbl9jaGFubmVscywgaW50ZXJtZWRpYXRlX2NoYW5uZWxzKnNlbGYuZXhwYW5zaW9uLCBrZXJuZWxfc2l6ZT0xLCBzdHJpZGU9c3RyaWRlKSwKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICBubi5CYXRjaE5vcm0yZChpbnRlcm1lZGlhdGVfY2hhbm5lbHMqc2VsZi5leHBhbnNpb24pKQogICAgICAgIGxheWVycy5hcHBlbmQoYmxvY2soc2VsZi5pbl9jaGFubmVscywgaW50ZXJt
ZWRpYXRlX2NoYW5uZWxzLCBpZGVudGl0eV9kb3duc2FtcGxlLCBzdHJpZGUpKQogICAgICAgIHNlbGYuaW5fY2hhbm5lbHMgPSBpbnRlcm1lZGlhdGVfY2hhbm5lbHMgKiBzZWxmLmV4cGFuc2lvbiAjIDI1NgogICAgICAgIGZvciBpIGluIHJhbmdlKG51bV9yZXNpZHVhbF9ibG9ja3MgLSAxKToKICAgICAgICAgICAgbGF5ZXJzLmFwcGVuZChibG9jayhzZWxmLmluX2NoYW5uZWxzLCBpbnRlcm1lZGlhdGVfY2hhbm5lbHMpKSAjIDI1NiAtPiA2NCwgNjQqNCAoMjU2KSBhZ2FpbgogICAgICAgIHJldHVybiBubi5TZXF1ZW50aWFsKCpsYXllcnMp