import torch
import torch.nn as nn


class ImageNetHEAD(nn.Sequential):
    """Two-convolution stem for 3-channel (RGB) inputs at ImageNet scale.

    Layers, in order: ``conv1`` (3 -> C // 2, 3x3, stride 2) -> ``bn1`` ->
    ``relu1`` -> ``conv2`` (C // 2 -> C, 3x3, stride ``stride``) -> ``bn2``.
    The output of ``bn2`` is returned as-is (no trailing ReLU).

    Every conv uses a 3x3 kernel with padding 1, so an input of height H
    produces height ``ceil(ceil(H / 2) / stride)`` (same for width).

    Args:
        C: number of output channels; the intermediate width is ``C // 2``.
        stride: stride of the second convolution (the first is always 2).
    """

    def __init__(self, C, stride=2):
        super().__init__()
        self.add_module(
            "conv1",
            nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
        )
        self.add_module("bn1", nn.BatchNorm2d(C // 2))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module(
            "conv2",
            nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False),
        )
        self.add_module("bn2", nn.BatchNorm2d(C))


class CifarHEAD(nn.Sequential):
    """Single-convolution stem for 3-channel inputs at CIFAR scale.

    Layers: ``conv`` (3 -> C, 3x3, stride 1, padding 1) -> ``bn``. Spatial
    size is preserved and no activation is applied after the batch norm.

    Args:
        C: number of output channels.
    """

    def __init__(self, C):
        super().__init__()
        self.add_module("conv", nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
        self.add_module("bn", nn.BatchNorm2d(C))


class AuxiliaryHeadCIFAR(nn.Module):
    """Auxiliary classifier for an 8x8 feature map with ``C`` channels.

    Pipeline: ReLU -> AvgPool2d(5, stride 3) (8x8 -> 2x2) -> 1x1 conv to 128
    -> BN -> ReLU -> 2x2 conv to 768 (2x2 -> 1x1) -> BN -> ReLU ->
    Linear(768, num_classes). Other input sizes do not reduce to 1x1 and the
    conv or linear layer will reject them.

    Note: the first ReLU is in-place, so ``forward`` overwrites the tensor it
    is given with its rectified values.

    Args:
        C: number of channels of the incoming feature map.
        num_classes: number of output logits.
    """

    def __init__(self, C, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(
                5, stride=3, padding=0, count_include_pad=False
            ),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),  # 2x2 -> 1x1
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for feature map ``x``."""
        x = self.features(x)
        # flatten(x, 1) keeps the batch dim even for an empty batch, where
        # view(x.size(0), -1) would fail because -1 is ambiguous.
        return self.classifier(torch.flatten(x, 1))


class AuxiliaryHeadImageNet(nn.Module):
    """Auxiliary classifier for a 7x7 feature map with ``C`` channels.

    Pipeline: ReLU -> AvgPool2d(5, stride 2) (7x7 -> 2x2) -> 1x1 conv to 128
    -> BN -> ReLU -> 2x2 conv to 768 (2x2 -> 1x1) -> BN -> ReLU ->
    Linear(768, num_classes).

    The layer sizes only line up (exactly 768 features reaching the linear
    layer) for 7x7 or 8x8 inputs. A 14x14 input, for example, is pooled to
    5x5, the 2x2 conv leaves 4x4, and the classifier raises a shape-mismatch
    error on the resulting 768 * 16 features.

    Note: the first ReLU is in-place, so ``forward`` overwrites the tensor it
    is given with its rectified values.

    Args:
        C: number of channels of the incoming feature map.
        num_classes: number of output logits.
    """

    def __init__(self, C, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),  # 7x7 -> 2x2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),  # 2x2 -> 1x1
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for feature map ``x``."""
        x = self.features(x)
        # flatten(x, 1) keeps the batch dim even for an empty batch, where
        # view(x.size(0), -1) would fail because -1 is ambiguous.
        return self.classifier(torch.flatten(x, 1))