##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
|
|
import math, torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from .initialization import initialize_resnet
|
|
|
|
|
|
class Bottleneck(nn.Module):
    """DenseNet-BC composite layer: BN-ReLU-Conv1x1 then BN-ReLU-Conv3x3.

    The 1x1 convolution widens to ``4 * growthRate`` channels. The 3x3
    convolution (padding 1, so spatial size is preserved) produces
    ``growthRate`` new feature maps. Those maps are concatenated onto the
    input along dim 1, so the output has ``nChannels + growthRate`` channels.
    """

    def __init__(self, nChannels, growthRate):
        super().__init__()
        # Width of the intermediate 1x1 "bottleneck" output.
        width = growthRate * 4
        # Creation order (bn1, conv1, bn2, conv2) is significant: it fixes
        # parameter registration order and the sequence of random draws made
        # by each layer's default initialization.
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, width, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, growthRate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Return ``x`` with ``growthRate`` newly computed channels appended."""
        hidden = self.conv1(F.relu(self.bn1(x)))
        new_maps = self.conv2(F.relu(self.bn2(hidden)))
        return torch.cat([x, new_maps], dim=1)
|
|
|
|
|
|
class SingleLayer(nn.Module):
    """Basic (non-bottleneck) dense layer: BN-ReLU-Conv3x3, then concatenate.

    The 3x3 convolution (padding 1, so spatial size is preserved) emits
    ``growthRate`` channels. They are appended to the input along dim 1,
    giving ``nChannels + growthRate`` output channels.
    """

    def __init__(self, nChannels, growthRate):
        super().__init__()
        # bn1 is created before conv1; keep this order so parameter
        # registration and default-init random draws stay the same.
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Return ``x`` with ``growthRate`` newly computed channels appended."""
        fresh = self.conv1(F.relu(self.bn1(x)))
        return torch.cat([x, fresh], dim=1)
|
|
|
|
|
|
class Transition(nn.Module):
    """Transition between dense stages: BN-ReLU-Conv1x1, then 2x2 average pooling.

    The 1x1 convolution maps ``nChannels`` to ``nOutChannels`` channels.
    The pooling (kernel 2, default stride = kernel) halves height and width,
    flooring odd sizes.
    """

    def __init__(self, nChannels, nOutChannels):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)

    def forward(self, x):
        """Compress channels with ``conv1`` and downsample spatially by 2."""
        compressed = self.conv1(F.relu(self.bn1(x)))
        return F.avg_pool2d(compressed, kernel_size=2)
|
|
|
|
|
|
class DenseNet(nn.Module):
    """DenseNet for small images (CIFAR-style), optionally with DenseNet-BC bottlenecks.

    Layout: 3x3 stem conv -> dense stage 1 -> Transition -> dense stage 2 ->
    Transition -> dense stage 3 -> BN -> ReLU -> global average pool -> Linear.

    Args:
        growthRate: channels added by every dense layer. The stem conv emits
            ``2 * growthRate`` channels.
        depth: nominal network depth. Each of the three stages gets
            ``int((depth - 4) / 6)`` Bottleneck layers when ``bottleneck`` is
            true, otherwise ``int((depth - 4) / 3)`` SingleLayer layers.
            Depths that do not divide evenly are silently truncated.
        reduction: transition compression factor. A transition maps ``C``
            channels to ``floor(C * reduction)`` channels.
        nClasses: number of output logits produced by ``fc``.
        bottleneck: use Bottleneck layers (True) or SingleLayer layers (False).
    """

    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super(DenseNet, self).__init__()

        # A Bottleneck layer holds two convs and a SingleLayer holds one, so
        # the bottleneck variant gets half as many layers per stage.
        if bottleneck:
            nDenseBlocks = int((depth - 4) / 6)
        else:
            nDenseBlocks = int((depth - 4) / 3)

        self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format(
            "bottleneck" if bottleneck else "basic",
            depth,
            reduction,
            growthRate,
            nClasses,
        )

        # Submodules are built in a fixed order (conv1, dense1, trans1, dense2,
        # trans2, dense3, act, fc). That order determines parameter
        # registration and the random draws of default initialization.
        nChannels = 2 * growthRate
        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate

        # Global average pooling. For 32x32 inputs the final map is 8x8 (two
        # stride-2 transitions), and this equals a fixed 8x8 average pool.
        # Pooling adaptively to 1x1 also keeps the flattened feature size
        # equal to nChannels for other input resolutions, so `fc` still
        # matches. It has no parameters and keeps Sequential index 2, so
        # state_dict keys are unchanged.
        self.act = nn.Sequential(
            nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1)
        )
        self.fc = nn.Linear(nChannels, nClasses)

        self.apply(initialize_resnet)

    def get_message(self):
        """Return the human-readable configuration summary built in ``__init__``."""
        return self.message

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        """Build one dense stage of ``nDenseBlocks`` layers as an ``nn.Sequential``.

        Each layer concatenates ``growthRate`` new channels onto its input, so
        the input width of layer ``k`` is ``nChannels + k * growthRate``.
        """
        layers = []
        for _ in range(int(nDenseBlocks)):
            if bottleneck:
                layers.append(Bottleneck(nChannels, growthRate))
            else:
                layers.append(SingleLayer(nChannels, growthRate))
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Run the network on a batch of 3-channel images.

        Returns:
            ``(features, logits)``. ``features`` is the pooled and flattened
            ``(batch, C)`` tensor fed to ``fc``; ``logits`` is
            ``fc(features)`` with shape ``(batch, nClasses)``.
        """
        out = self.conv1(inputs)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        features = self.act(out)
        features = features.view(features.size(0), -1)
        out = self.fc(features)
        return features, out
|