update README

This commit is contained in:
D-X-Y 2019-11-15 17:40:15 +11:00
parent 0630867505
commit c3672648d7
3 changed files with 116 additions and 3 deletions

View File

@ -5,7 +5,7 @@ This project contains the following neural architecture search algorithms, imple
- Network Pruning via Transformable Architecture Search, NeurIPS 2019
- One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
- Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
- several typical classification models, e.g., ResNet and DenseNet (see BASELINE.md)
- several typical classification models, e.g., ResNet and DenseNet (see [BASELINE.md](https://github.com/D-X-Y/NAS-Projects/blob/master/BASELINE.md))
## Requirements and Preparation

105
lib/models/CifarDenseNet.py Normal file
View File

@ -0,0 +1,105 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
class Bottleneck(nn.Module):
    """Dense layer with a 1x1 channel-mixing convolution ahead of the 3x3 one.

    The layer applies BN -> ReLU -> 1x1 conv (to ``4 * growthRate`` channels)
    and then BN -> ReLU -> 3x3 conv (to ``growthRate`` channels). The result is
    concatenated to the input along the channel dimension, so the output has
    ``nChannels + growthRate`` channels and the same spatial size as the input.

    Args:
        nChannels: number of channels of the incoming feature map.
        growthRate: number of new channels this layer appends.
    """

    def __init__(self, nChannels, growthRate):
        super(Bottleneck, self).__init__()
        # The intermediate width is fixed at four times the growth rate.
        mid_channels = 4 * growthRate
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, mid_channels, kernel_size=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(
            mid_channels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        """Return ``x`` with the newly computed channels appended."""
        squeezed = self.conv1(F.relu(self.bn1(x)))
        grown = self.conv2(F.relu(self.bn2(squeezed)))
        # Input channels come first, new channels last.
        return torch.cat((x, grown), 1)
class SingleLayer(nn.Module):
    """Basic dense layer: BN -> ReLU -> 3x3 conv, concatenated with its input.

    The output has ``nChannels + growthRate`` channels and keeps the spatial
    size of the input (the 3x3 convolution uses padding 1).

    Args:
        nChannels: number of channels of the incoming feature map.
        growthRate: number of new channels this layer appends.
    """

    def __init__(self, nChannels, growthRate):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        """Return ``x`` with the newly computed channels appended."""
        activated = F.relu(self.bn1(x))
        grown = self.conv1(activated)
        # Input channels come first, new channels last.
        return torch.cat((x, grown), 1)
class Transition(nn.Module):
    """Stage boundary: BN -> ReLU -> 1x1 conv, then 2x2 average pooling.

    The 1x1 convolution maps ``nChannels`` to ``nOutChannels`` and the pooling
    step halves the spatial resolution.

    Args:
        nChannels: number of channels of the incoming feature map.
        nOutChannels: number of channels produced by the 1x1 convolution.
    """

    def __init__(self, nChannels, nOutChannels):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, nOutChannels, kernel_size=1, bias=False
        )

    def forward(self, x):
        """Return the channel-projected, spatially down-sampled feature map."""
        projected = self.conv1(F.relu(self.bn1(x)))
        return F.avg_pool2d(projected, 2)
class DenseNet(nn.Module):
def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
super(DenseNet, self).__init__()
if bottleneck: nDenseBlocks = int( (depth-4) / 6 )
else : nDenseBlocks = int( (depth-4) / 3 )
self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses)
nChannels = 2*growthRate
self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)
self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.trans1 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
self.trans2 = Transition(nChannels, nOutChannels)
nChannels = nOutChannels
self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
self.act = nn.Sequential(
nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True),
nn.AvgPool2d(8))
self.fc = nn.Linear(nChannels, nClasses)
self.apply(initialize_resnet)
def get_message(self):
return self.message
def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels, growthRate))
else:
layers.append(SingleLayer(nChannels, growthRate))
nChannels += growthRate
return nn.Sequential(*layers)
def forward(self, inputs):
out = self.conv1( inputs )
out = self.trans1(self.dense1(out))
out = self.trans2(self.dense2(out))
out = self.dense3(out)
features = self.act(out)
features = features.view(features.size(0), -1)
out = self.fc(features)
return features, out

View File

@ -38,12 +38,15 @@ def get_search_spaces(xtype, name):
def get_cifar_models(config):
from .CifarResNet import CifarResNet
from .CifarDenseNet import DenseNet
from .CifarWideResNet import CifarWideResNet
super_type = getattr(config, 'super_type', 'basic')
if super_type == 'basic':
if config.arch == 'resnet':
return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual)
elif config.arch == 'densenet':
return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck)
elif config.arch == 'wideresnet':
return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout)
else:
@ -68,8 +71,13 @@ def get_cifar_models(config):
def get_imagenet_models(config):
super_type = getattr(config, 'super_type', 'basic')
# NAS searched architecture
if super_type.startswith('infer'):
if super_type == 'basic':
from .ImagenetResNet import ResNet
if config.arch == 'resnet':
return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group)
else:
raise ValueError('invalid arch : {:}'.format( config.arch ))
elif super_type.startswith('infer'): # NAS searched architecture
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
infer_mode = super_type.split('-')[1]
if infer_mode == 'shape':