From d28826793d2f27f5a5ae17d6e70c5a51d79ac34f Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Wed, 16 Oct 2019 16:29:57 +1100
Subject: [PATCH] update GDAS (TO-FINISH)

---
 README.md                                    |   4 +-
 lib/models/cell_searchs/__init__.py          |   0
 lib/models/cell_searchs/cells.py             | 196 +++++++++++++++++++
 lib/models/cell_searchs/operations.py        | 113 +++++++++++
 lib/models/cell_searchs/search_model_gdas.py | 117 +++++++++++
 lib/models/searchs/SoftSelect.py             |   3 +-
 6 files changed, 429 insertions(+), 4 deletions(-)
 create mode 100644 lib/models/cell_searchs/__init__.py
 create mode 100644 lib/models/cell_searchs/cells.py
 create mode 100644 lib/models/cell_searchs/operations.py
 create mode 100644 lib/models/cell_searchs/search_model_gdas.py

diff --git a/README.md b/README.md
index b093065..b6e1606 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-cifar.sh cifar10 ResNet56
 args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed.

-## One-Shot Neural Architecture Search via Self-Evaluated Template Network
+## [One-Shot Neural Architecture Search via Self-Evaluated Template Network](https://arxiv.org/abs/1910.05733)

@@ -67,7 +67,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN
 Searching codes come soon!

-## [Searching for A Robust Neural Architecture in Four GPU Hours](http://openaccess.thecvf.com/content_CVPR_2019/papers/Dong_Searching_for_a_Robust_Neural_Architecture_in_Four_GPU_Hours_CVPR_2019_paper.pdf)
+## [Searching for A Robust Neural Architecture in Four GPU Hours](https://arxiv.org/abs/1910.04465)
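
The new SearchCell added below keeps one list of candidate ops per DAG edge, keyed as 'i<-j', and numbers the edges in sorted-key order; the hard-coded edge indices in search_model_gdas.py depend on this ordering. A minimal sketch (not part of the patch) of that indexing, for the max_nodes=4 case that the sampling checks in search_model_gdas.py assume:

max_nodes = 4  # the value the sampling checks in this patch assume
edge_keys = sorted('{:}<-{:}'.format(i, j) for i in range(1, max_nodes) for j in range(i))
edge2index = {key: idx for idx, key in enumerate(edge_keys)}
print(edge_keys)   # ['1<-0', '2<-0', '2<-1', '3<-0', '3<-1', '3<-2']
print(edge2index)  # {'1<-0': 0, '2<-0': 1, '2<-1': 2, '3<-0': 3, '3<-1': 4, '3<-2': 5}
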
diff --git a/lib/models/cell_searchs/__init__.py b/lib/models/cell_searchs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lib/models/cell_searchs/cells.py b/lib/models/cell_searchs/cells.py
new file mode 100644
index 0000000..f569cfd
--- /dev/null
+++ b/lib/models/cell_searchs/cells.py
@@ -0,0 +1,196 @@
import random, torch
import torch.nn as nn
from copy import deepcopy
from .operations import OPS, ReLUConvBN


class SearchCell(nn.Module):

  def __init__(self, C_in, C_out, stride, max_nodes, op_names):
    super(SearchCell, self).__init__()

    self.op_names  = deepcopy(op_names)
    self.edges     = nn.ModuleDict()
    self.max_nodes = max_nodes
    self.in_dim    = C_in
    self.out_dim   = C_out
    for i in range(1, max_nodes):
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        if j == 0:
          xlists = [OPS[op_name](C_in, C_out, stride) for op_name in op_names]
        else:
          xlists = [OPS[op_name](C_in, C_out, 1) for op_name in op_names]
        self.edges[ node_str ] = nn.ModuleList( xlists )
    self.edge_keys  = sorted(list(self.edges.keys()))
    self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
    self.num_edges  = len(self.edges)

  def extra_repr(self):
    string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
    return string

  def forward(self, inputs, weightss):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      inter_nodes = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = weightss[ self.edge2index[node_str] ]
        inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]

  # GDAS: on each edge, run only the sampled op and scale it by the
  # straight-through weight, so the hard sample stays differentiable
  def forward_gdas(self, inputs, weightss, indexess):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      inter_nodes = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = weightss[ self.edge2index[node_str] ]
        indexes  = indexess[ self.edge2index[node_str] ].item()
        inter_nodes.append( self.edges[node_str][indexes](nodes[j]) * weights[indexes] )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]

  # joint
  def forward_joint(self, inputs, weightss):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      inter_nodes = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = weightss[ self.edge2index[node_str] ]
        aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
        inter_nodes.append( aggregation )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]

  # uniform random sampling per iteration
  def forward_urs(self, inputs):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      while True: # resample until at least one selected op is non-zero
        sops, has_non_zero = [], False
        for j in range(i):
          node_str   = '{:}<-{:}'.format(i, j)
          candidates = self.edges[node_str]
          select_op  = random.choice(candidates)
          sops.append( select_op )
          if not hasattr(select_op, 'is_zero') or not select_op.is_zero: has_non_zero = True
        if has_non_zero: break
      inter_nodes = []
      for j, select_op in enumerate(sops):
        inter_nodes.append( select_op(nodes[j]) )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]

  # select the argmax op on each edge
  def forward_select(self, inputs, weightss):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      inter_nodes = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = weightss[ self.edge2index[node_str] ]
        inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]

  # forward with a fixed, discrete architecture
  def forward_dynamic(self, inputs, structure):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      cur_op_node = structure.nodes[i-1]
      inter_nodes = []
      for op_name, j in cur_op_node:
        node_str = '{:}<-{:}'.format(i, j)
        op_index = self.op_names.index( op_name )
        inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]


class InferCell(nn.Module):

  def __init__(self, genotype, C_in, C_out, stride):
    super(InferCell, self).__init__()

    self.layers   = nn.ModuleList()
    self.node_IN  = []
    self.node_IX  = []
    self.genotype = deepcopy(genotype)
    for i in range(1, len(genotype)):
      node_info = genotype[i-1]
      cur_index = []
      cur_innod = []
      for (op_name, op_in) in node_info:
        if op_in == 0:
          layer = OPS[op_name](C_in, C_out, stride)
        else:
          layer = OPS[op_name](C_out, C_out, 1)
        cur_index.append( len(self.layers) )
        cur_innod.append( op_in )
        self.layers.append( layer )
      self.node_IX.append( cur_index )
      self.node_IN.append( cur_innod )
    self.nodes   = len(genotype)
    self.in_dim  = C_in
    self.out_dim = C_out

  def extra_repr(self):
    string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
    laystr = []
    for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
      y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
      x = '{:}<-({:})'.format(i+1, ','.join(y))
      laystr.append( x )
    return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())

  def forward(self, inputs):
    nodes = [inputs]
    for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
      node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
      nodes.append( node_feature )
    return nodes[-1]


class ResNetBasicblock(nn.Module):

  def __init__(self, inplanes, planes, stride):
    super(ResNetBasicblock, self).__init__()
    assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
    self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
    self.conv_b = ReLUConvBN(  planes, planes, 3,      1, 1, 1)
    if stride == 2:
      self.downsample = nn.Sequential(
                           nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                           nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
    elif inplanes != planes:
      self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
    else:
      self.downsample = None
    self.in_dim   = inplanes
    self.out_dim  = planes
    self.stride   = stride
    self.num_conv = 2

  def extra_repr(self):
    string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
    return string

  def forward(self, inputs):
    basicblock = self.conv_a(inputs)
    basicblock = self.conv_b(basicblock)
    if self.downsample is not None:
      residual = self.downsample(inputs)
    else:
      residual = inputs
    return residual + basicblock
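
A hedged usage sketch for the cell above, not part of the patch: it assumes the patch is applied and that the repository's lib directory is on PYTHONPATH. It drives SearchCell.forward with uniform weights over the three connect-nas candidate ops.

import torch
from lib.models.cell_searchs.cells import SearchCell

cell = SearchCell(C_in=16, C_out=16, stride=1, max_nodes=4,
                  op_names=['none', 'skip_connect', 'nor_conv_3x3'])
x = torch.randn(2, 16, 32, 32)
weights = torch.full((cell.num_edges, 3), 1.0 / 3)  # one row per edge, uniform over ops
y = cell(x, weights)                                # mixes all candidate ops on every edge
print(y.shape)                                      # torch.Size([2, 16, 32, 32])
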
diff --git a/lib/models/cell_searchs/operations.py b/lib/models/cell_searchs/operations.py
new file mode 100644
index 0000000..85c5253
--- /dev/null
+++ b/lib/models/cell_searchs/operations.py
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn

__all__ = ['OPS', 'ReLUConvBN', 'SearchSpaceNames']

OPS = {
  'none'         : lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
  'avg_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'avg'),
  'max_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'max'),
  'nor_conv_7x7' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1)),
  'nor_conv_3x3' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1)),
  'nor_conv_1x1' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1)),
  'skip_connect' : lambda C_in, C_out, stride: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride),
}

CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']

SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK}


class POOLING(nn.Module):

  def __init__(self, C_in, C_out, stride, mode):
    super(POOLING, self).__init__()
    if C_in == C_out:
      self.preprocess = None
    else:
      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
    if mode == 'avg'  : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
    elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
    else              : raise ValueError('Invalid mode={:} in POOLING'.format(mode))

  def forward(self, inputs):
    if self.preprocess is not None: x = self.preprocess(inputs)
    else                          : x = inputs
    return self.op(x)


class ReLUConvBN(nn.Module):

  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
    super(ReLUConvBN, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
      nn.BatchNorm2d(C_out)
    )

  def forward(self, x):
    return self.op(x)


class Identity(nn.Module):

  def __init__(self):
    super(Identity, self).__init__()

  def forward(self, x):
    return x


class Zero(nn.Module):

  def __init__(self, C_in, C_out, stride):
    super(Zero, self).__init__()
    self.C_in    = C_in
    self.C_out   = C_out
    self.stride  = stride
    self.is_zero = True

  def forward(self, x):
    if self.C_in == self.C_out:
      if self.stride == 1: return x.mul(0.)
      else               : return x[:,:,::self.stride,::self.stride].mul(0.)
    else:
      shape    = list(x.shape)
      shape[1] = self.C_out
      # match the strided spatial size (ceil division, same as the slicing above)
      shape[2] = -(-shape[2] // self.stride)
      shape[3] = -(-shape[3] // self.stride)
      return x.new_zeros(shape)

  def extra_repr(self):
    return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)


class FactorizedReduce(nn.Module):

  def __init__(self, C_in, C_out, stride):
    super(FactorizedReduce, self).__init__()
    self.stride = stride
    self.C_in   = C_in
    self.C_out  = C_out
    self.relu   = nn.ReLU(inplace=False)
    if stride == 2:
      # split the output channels over two strided 1x1 convs, one of which
      # sees the input shifted by one pixel, so no position is discarded
      C_outs = [C_out // 2, C_out - C_out // 2]
      self.convs = nn.ModuleList()
      for i in range(2):
        self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) )
      self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
    elif stride == 1:
      # skip_connect dispatches here when stride == 1 but C_in != C_out,
      # so a plain 1x1 conv handles the channel change
      self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
    else:
      raise ValueError('Invalid stride : {:}'.format(stride))

    self.bn = nn.BatchNorm2d(C_out)

  def forward(self, x):
    x = self.relu(x)
    if self.stride == 2:
      y = self.pad(x)
      out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
    else:
      out = self.conv(x)
    out = self.bn(out)
    return out

  def extra_repr(self):
    return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
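
A small shape-check sketch for the operations above (again not part of the patch, and assuming operations.py is importable): every op in the connect-nas space preserves spatial size at stride 1 and halves it at stride 2, which is what lets SearchCell sum their outputs edge by edge.

import torch
from lib.models.cell_searchs.operations import OPS, SearchSpaceNames

x = torch.randn(2, 8, 16, 16)
for name in SearchSpaceNames['connect-nas']:
  for stride in (1, 2):
    op = OPS[name](8, 8, stride)
    print('{:12s} stride={:} -> {:}'.format(name, stride, tuple(op(x).shape)))
# stride 1 keeps 16x16; stride 2 yields 8x8 for 'none', 'skip_connect' and 'nor_conv_3x3' alike
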
diff --git a/lib/models/cell_searchs/search_model_gdas.py b/lib/models/cell_searchs/search_model_gdas.py
new file mode 100644
index 0000000..6455d11
--- /dev/null
+++ b/lib/models/cell_searchs/search_model_gdas.py
@@ -0,0 +1,117 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .cells import ResNetBasicblock, SearchCell
from .genotypes import Structure


class TinyNetworkGDAS(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space):
    super(TinyNetworkGDAS, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                   nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                   nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
    self.tau       = 10
    self.nan_count = 0

  def get_weights(self):
    xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
    xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
    xlist+= list( self.classifier.parameters() )
    return xlist

  def set_tau(self, tau, _nan_count=0):
    self.tau       = tau
    self.nan_count = _nan_count

  def get_tau(self):
    return self.tau

  def get_alphas(self):
    return [self.arch_parameters]

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def genotype(self):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        with torch.no_grad():
          weights = self.arch_parameters[ self.edge2index[node_str] ]
          op_name = self.op_names[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return Structure( genotypes )

  def forward(self, inputs):
    def gumbel_softmax(_logits, _tau):
      none_index = self.op_names.index('none')
      while True: # resample on numerical issues or degenerate architectures
        gumbels    = -torch.empty_like(_logits).exponential_().log()
        new_logits = (_logits.log_softmax(dim=1) + gumbels) / _tau
        probs      = nn.functional.softmax(new_logits, dim=1)
        index      = probs.max(-1, keepdim=True)[1]
        # reject samples whose output node is disconnected from the input;
        # the indices are hard-coded for max_nodes=4, where the six edges are
        # ordered as 1<-0, 2<-0, 2<-1, 3<-0, 3<-1, 3<-2
        if index[0].item() == none_index and index[3].item() == none_index and index[5].item() == none_index: continue
        if index[1].item() == none_index and index[2].item() == none_index and index[3].item() == none_index and index[4].item() == none_index: continue
        if index[3].item() == none_index and index[4].item() == none_index and index[5].item() == none_index: continue
        if index[3].item() == none_index and index[0].item() == none_index and index[1].item() == none_index: continue
        one_h = torch.zeros_like(_logits).scatter_(-1, index, 1.0)
        # straight-through estimator: hard one-hot forward, soft gradient backward
        xres  = one_h - probs.detach() + probs
        if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break
        self.nan_count += 1
      return xres, index

    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        alphas, IDX = gumbel_softmax(self.arch_parameters, self.tau)
        feature = cell.forward_gdas(feature, alphas, IDX.cpu())
      else:
        feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
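
Before the SoftSelect.py hunk below, it is worth isolating the trick that the model's gumbel_softmax and the fixed select2withP share. The following standalone sketch (illustrative, not the patch's exact code) shows straight-through Gumbel-softmax sampling: the forward value is a hard one-hot, while the gradient flows through the soft probabilities.

import torch
import torch.nn as nn

def gumbel_softmax_st(logits, tau):
  # Gumbel(0,1) noise via -log(Exponential(1))
  gumbels = -torch.empty_like(logits).exponential_().log()
  # perturbing log_softmax(logits) instead of raw logits keeps the values
  # bounded -- the numerical fix the SoftSelect.py hunk below applies
  new_logits = (logits.log_softmax(dim=-1) + gumbels) / tau
  probs = nn.functional.softmax(new_logits, dim=-1)
  index = probs.max(-1, keepdim=True)[1]
  one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
  # straight-through: one-hot in the forward pass, probs' gradient in backward
  return one_h - probs.detach() + probs, index

logits = nn.Parameter(1e-3 * torch.randn(6, 3))  # six edges, three candidate ops
hardwts, index = gumbel_softmax_st(logits, tau=10.0)
hardwts.sum().backward()         # gradients reach logits despite the hard argmax
print(index.view(-1).tolist())   # the sampled op per edge
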
diff --git a/lib/models/searchs/SoftSelect.py b/lib/models/searchs/SoftSelect.py
index 84f2ad8..b120c8b 100644
--- a/lib/models/searchs/SoftSelect.py
+++ b/lib/models/searchs/SoftSelect.py
@@ -9,8 +9,7 @@ def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
   else           :
     while True: # a trick to avoid the gumbels bug
       gumbels = -torch.empty_like(logits).exponential_().log()
-      new_logits = (logits + gumbels) / tau
-      #new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
+      new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
       probs = nn.functional.softmax(new_logits, dim=1)
       if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break
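
To close, a hedged end-to-end sketch of how the new model would be driven. It assumes the patch applies cleanly and that lib/models/cell_searchs/genotypes.py, which provides the Structure class imported above but is not added by this patch, exists:

import torch
from lib.models.cell_searchs.search_model_gdas import TinyNetworkGDAS

net = TinyNetworkGDAS(C=16, N=5, max_nodes=4, num_classes=10,
                      search_space=['none', 'skip_connect', 'nor_conv_3x3'])
net.set_tau(10)
x = torch.randn(2, 3, 32, 32)  # a CIFAR-10-sized batch
features, logits = net(x)      # each SearchCell runs one sampled op per edge
print(logits.shape)            # torch.Size([2, 10])
print(net.genotype())          # the current argmax architecture as a Structure
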