autodl-projects/lib/models/CifarDenseNet.py

##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
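
# This module implements the CIFAR variant of DenseNet (Huang et al., "Densely
# Connected Convolutional Networks", CVPR 2017): three dense blocks on 32x32
# inputs, joined by transition layers that compress channels and halve the
# spatial resolution.
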

class Bottleneck(nn.Module):
    # DenseNet-B layer: a 1x1 "bottleneck" convolution reduces the input to
    # 4 * growthRate channels before a 3x3 convolution produces growthRate new
    # feature maps, which are concatenated onto the input.
    def __init__(self, nChannels, growthRate):
        super(Bottleneck, self).__init__()
        interChannels = 4 * growthRate
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(interChannels)
        self.conv2 = nn.Conv2d(
            interChannels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat((x, out), 1)
        return out


class SingleLayer(nn.Module):
    # Basic DenseNet layer: a single 3x3 convolution produces growthRate new
    # feature maps, which are concatenated onto the input.
    def __init__(self, nChannels, growthRate):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = torch.cat((x, out), 1)
        return out


class Transition(nn.Module):
    # Transition between dense blocks: a 1x1 convolution compresses the channel
    # count and a 2x2 average pool halves the spatial resolution.
    def __init__(self, nChannels, nOutChannels):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super(DenseNet, self).__init__()
        # Each of the three dense blocks holds nDenseBlocks layers; a bottleneck
        # layer contains two convolutions and a basic layer one, hence the
        # divisors 6 and 3 when deriving nDenseBlocks from the total depth.
        if bottleneck:
            nDenseBlocks = int((depth - 4) / 6)
        else:
            nDenseBlocks = int((depth - 4) / 3)
        self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format(
            "bottleneck" if bottleneck else "basic",
            depth,
            reduction,
            growthRate,
            nClasses,
        )

        nChannels = 2 * growthRate
        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)

        # Three dense blocks; each adds nDenseBlocks * growthRate channels, and
        # each transition compresses the channel count by the reduction factor.
        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate

        # Final BN-ReLU and 8x8 average pooling (32x32 inputs are 8x8 after the
        # two transitions), followed by the linear classifier.
        self.act = nn.Sequential(
            nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8)
        )
        self.fc = nn.Linear(nChannels, nClasses)

        self.apply(initialize_resnet)

    def get_message(self):
        return self.message

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        # Stack nDenseBlocks dense layers; the channel count grows by
        # growthRate after every layer because of the concatenations.
        layers = []
        for i in range(int(nDenseBlocks)):
            if bottleneck:
                layers.append(Bottleneck(nChannels, growthRate))
            else:
                layers.append(SingleLayer(nChannels, growthRate))
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, inputs):
        out = self.conv1(inputs)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        features = self.act(out)
        features = features.view(features.size(0), -1)
        out = self.fc(features)
        # Return both the pooled feature vector and the classification logits.
        return features, out
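

# A minimal usage sketch (not part of the original file). The hyper-parameters
# below are the common DenseNet-BC-100 settings for CIFAR-10 (growth rate 12,
# reduction 0.5), chosen here only for illustration. Because of the relative
# import above, this runs only as part of the package, e.g.
# `python -m lib.models.CifarDenseNet` from the repository root.
if __name__ == "__main__":
    net = DenseNet(
        growthRate=12, depth=100, reduction=0.5, nClasses=10, bottleneck=True
    )
    print(net.get_message())
    dummy = torch.randn(2, 3, 32, 32)  # a fake CIFAR-sized batch
    features, logits = net(dummy)
    print(features.shape, logits.shape)  # expected: [2, C] features and [2, 10] logits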