This commit is contained in:
D-X-Y 2019-11-08 20:06:12 +11:00
parent 5e44189d7e
commit 1da5b49018
9 changed files with 256 additions and 2 deletions

View File

@ -6,6 +6,7 @@ This project contains the following neural architecture search algorithms, imple
- One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
- Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
- Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
- several typical classification models, e.g., ResNet and DenseNet (see BASELINE.md)
## Requirements and Preparation

124
exps/AA_functions.py Normal file
View File

@ -0,0 +1,124 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from config_utils import dict2config
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net
__all__ = ['evaluate_for_seed', 'pure_evaluate']
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
latencies = []
network.eval()
with torch.no_grad():
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
targets = targets.cuda(non_blocking=True)
inputs = inputs.cuda(non_blocking=True)
data_time.update(time.time() - end)
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
batch_time.update(time.time() - end)
if batch is None or batch == inputs.size(0):
batch = inputs.size(0)
latencies.append( batch_time.val - data_time.val )
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
end = time.time()
if len(latencies) > 2: latencies = latencies[1:]
return losses.avg, top1.avg, top5.avg, latencies
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
if mode == 'train' : network.train()
elif mode == 'valid': network.eval()
else: raise ValueError("The mode is not right : {:}".format(mode))
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
targets = targets.cuda(non_blocking=True)
if mode == 'train': optimizer.zero_grad()
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
# backward
if mode == 'train':
loss.backward()
optimizer.step()
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update (prec1.item(), inputs.size(0))
top5.update (prec5.item(), inputs.size(0))
return losses.avg, top1.avg, top5.avg
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, seed, logger):
prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
'C': arch_config['channel'], 'N': arch_config['num_cells'],
'genotype': arch, 'num_classes': config.class_num}
, None)
)
#net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
flop, param = get_model_infos(net, config.xshape)
logger.log('Network : {:}'.format(net.get_message()), False)
logger.log('Seed-------------------------- {:} --------------------------'.format(seed))
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
# train and valid
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
# start training
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
train_loss, train_acc1, train_acc5 = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
with torch.no_grad():
valid_loss, valid_acc1, valid_acc5 = procedure(valid_loader, network, criterion, None, None, 'valid')
train_losses[epoch] = train_loss
train_acc1es[epoch] = train_acc1
train_acc5es[epoch] = train_acc5
valid_losses[epoch] = valid_loss
valid_acc1es[epoch] = valid_acc1
valid_acc5es[epoch] = valid_acc5
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) )
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
info_seed = {'flop' : flop,
'param': param,
'channel' : arch_config['channel'],
'num_cells' : arch_config['num_cells'],
'config' : config._asdict(),
'total_epoch' : total_epoch ,
'train_losses': train_losses,
'train_acc1es': train_acc1es,
'train_acc5es': train_acc5es,
'valid_losses': valid_losses,
'valid_acc1es': valid_acc1es,
'valid_acc5es': valid_acc5es,
'net_state_dict': net.state_dict(),
'net_string' : '{:}'.format(net),
'finish-train': True
}
return info_seed

View File

@ -3,10 +3,16 @@
##################################################
import torch
from os import path as osp
__all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \
'CellStructure', 'CellArchitectures'
]
# useful modules
from config_utils import dict2config
from .SharedUtils import change_key
from .clone_weights import init_from_model
from .cell_searchs import CellStructure, CellArchitectures
# Cell-based NAS Models
def get_cell_based_tiny_net(config):
@ -22,9 +28,13 @@ def get_cell_based_tiny_net(config):
elif config.name == 'SETN':
from .cell_searchs import TinyNetworkSETN
return TinyNetworkSETN(config.C, config.N, config.max_nodes, config.num_classes, config.space)
elif config.name == 'infer.tiny':
from .cell_infers import TinyNetwork
return TinyNetwork(config.C, config.N, config.genotype, config.num_classes)
else:
raise ValueError('invalid network name : {:}'.format(config.name))
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name):
if xtype == 'cell':

View File

@ -0,0 +1 @@
from .tiny_network import TinyNetwork

View File

@ -0,0 +1,51 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
class InferCell(nn.Module):
def __init__(self, genotype, C_in, C_out, stride):
super(InferCell, self).__init__()
self.layers = nn.ModuleList()
self.node_IN = []
self.node_IX = []
self.genotype = deepcopy(genotype)
for i in range(1, len(genotype)):
node_info = genotype[i-1]
cur_index = []
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
layer = OPS[op_name](C_in , C_out, stride)
else:
layer = OPS[op_name](C_out, C_out, 1)
cur_index.append( len(self.layers) )
cur_innod.append( op_in )
self.layers.append( layer )
self.node_IX.append( cur_index )
self.node_IN.append( cur_innod )
self.nodes = len(genotype)
self.in_dim = C_in
self.out_dim = C_out
def extra_repr(self):
string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
laystr = []
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
x = '{:}<-({:})'.format(i+1, ','.join(y))
laystr.append( x )
return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())
def forward(self, inputs):
nodes = [inputs]
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
nodes.append( node_feature )
return nodes[-1]

View File

@ -0,0 +1,58 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell
class TinyNetwork(nn.Module):
def __init__(self, C, N, genotype, num_classes):
super(TinyNetwork, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = InferCell(genotype, C_prev, C_curr, 1)
self.cells.append( cell )
C_prev = cell.out_dim
self._Layer= len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@ -17,7 +17,8 @@ CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
AA_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'aa-nas' : AA_NAS_BENCHMARK}
'aa-nas' : AA_NAS_BENCHMARK,
'full' : sorted(list(OPS.keys()))}
class ReLUConvBN(nn.Module):

View File

@ -2,3 +2,4 @@ from .search_model_darts_v1 import TinyNetworkDartsV1
from .search_model_darts_v2 import TinyNetworkDartsV2
from .search_model_gdas import TinyNetworkGDAS
from .search_model_setn import TinyNetworkSETN
from .genotypes import Structure as CellStructure, architectures as CellArchitectures

View File

@ -60,6 +60,13 @@ class Structure:
strings.append( string )
return '+'.join(strings)
def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names: return False
return True
def __repr__(self):
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))