diff --git a/lib/models/cell_searchs/search_model_gdas.py b/lib/models/cell_searchs/search_model_gdas.py index 84ddcce..1d2731d 100644 --- a/lib/models/cell_searchs/search_model_gdas.py +++ b/lib/models/cell_searchs/search_model_gdas.py @@ -88,7 +88,9 @@ class TinyNetworkGDAS(nn.Module): index = probs.max(-1, keepdim=True)[1] one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) hardwts = one_h - probs.detach() + probs - if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): continue + if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): + continue + else: break feature = self.stem(inputs) for i, cell in enumerate(self.cells): diff --git a/lib/models/l2s_cell_searchs/__init__.py b/lib/models/l2s_cell_searchs/__init__.py deleted file mode 100644 index 2133795..0000000 --- a/lib/models/l2s_cell_searchs/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## -from .search_model_darts_v1 import TinyNetworkDartsV1 -from .search_model_darts_v2 import TinyNetworkDartsV2 -from .search_model_gdas import TinyNetworkGDAS -from .search_model_setn import TinyNetworkSETN -from .search_model_enas import TinyNetworkENAS -from .search_model_random import TinyNetworkRANDOM -from .genotypes import Structure as CellStructure, architectures as CellArchitectures - -nas_super_nets = {'DARTS-V1': TinyNetworkDartsV1, - 'DARTS-V2': TinyNetworkDartsV2, - 'GDAS' : TinyNetworkGDAS, - 'SETN' : TinyNetworkSETN, - 'ENAS' : TinyNetworkENAS, - 'RANDOM' : TinyNetworkRANDOM} diff --git a/lib/models/l2s_cell_searchs/_test_module.py b/lib/models/l2s_cell_searchs/_test_module.py deleted file mode 100644 index c603ba6..0000000 --- a/lib/models/l2s_cell_searchs/_test_module.py +++ /dev/null @@ -1,12 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## -import torch -from search_model_enas_utils import Controller - -def main(): - controller = Controller(6, 4) - predictions = controller() - -if __name__ == '__main__': - main() diff --git a/lib/models/l2s_cell_searchs/genotypes.py b/lib/models/l2s_cell_searchs/genotypes.py deleted file mode 100644 index e0f2e2e..0000000 --- a/lib/models/l2s_cell_searchs/genotypes.py +++ /dev/null @@ -1,197 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## -from copy import deepcopy - - - -def get_combination(space, num): - combs = [] - for i in range(num): - if i == 0: - for func in space: - combs.append( [(func, i)] ) - else: - new_combs = [] - for string in combs: - for func in space: - xstring = string + [(func, i)] - new_combs.append( xstring ) - combs = new_combs - return combs - - - -class Structure: - - def __init__(self, genotype): - assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype)) - self.node_num = len(genotype) + 1 - self.nodes = [] - self.node_N = [] - for idx, node_info in enumerate(genotype): - assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info)) - assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info)) - for node_in in node_info: - assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in)) - assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in) - self.node_N.append( len(node_info) ) - self.nodes.append( tuple(deepcopy(node_info)) ) - - def tolist(self, remove_str): - # convert this class to the list, if remove_str is 'none', then remove the 'none' operation. - # note that we re-order the input node in this function - # return the-genotype-list and success [if unsuccess, it is not a connectivity] - genotypes = [] - for node_info in self.nodes: - node_info = list( node_info ) - node_info = sorted(node_info, key=lambda x: (x[1], x[0])) - node_info = tuple(filter(lambda x: x[0] != remove_str, node_info)) - if len(node_info) == 0: return None, False - genotypes.append( node_info ) - return genotypes, True - - def node(self, index): - assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self)) - return self.nodes[index] - - def tostr(self): - strings = [] - for node_info in self.nodes: - string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info]) - string = '|{:}|'.format(string) - strings.append( string ) - return '+'.join(strings) - - def check_valid(self): - nodes = {0: True} - for i, node_info in enumerate(self.nodes): - sums = [] - for op, xin in node_info: - if op == 'none' or nodes[xin] == False: x = False - else: x = True - sums.append( x ) - nodes[i+1] = sum(sums) > 0 - return nodes[len(self.nodes)] - - def to_unique_str(self, consider_zero=False): - # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation - # two operations are special, i.e., none and skip_connect - nodes = {0: '0'} - for i_node, node_info in enumerate(self.nodes): - cur_node = [] - for op, xin in node_info: - if consider_zero: - if op == 'none' or nodes[xin] == '#': x = '#' # zero - elif op == 'skip_connect': x = nodes[xin] - else: x = '('+nodes[xin]+')' + '@{:}'.format(op) - else: - if op == 'skip_connect': x = nodes[xin] - else: x = '('+nodes[xin]+')' + '@{:}'.format(op) - cur_node.append(x) - nodes[i_node+1] = '+'.join( sorted(cur_node) ) - return nodes[ len(self.nodes) ] - - def check_valid_op(self, op_names): - for node_info in self.nodes: - for inode_edge in node_info: - #assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0]) - if inode_edge[0] not in op_names: return False - return True - - def __repr__(self): - return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__)) - - def __len__(self): - return len(self.nodes) + 1 - - def __getitem__(self, index): - return self.nodes[index] - - @staticmethod - def str2structure(xstr): - assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) - nodestrs = xstr.split('+') - genotypes = [] - for i, node_str in enumerate(nodestrs): - inputs = list(filter(lambda x: x != '', node_str.split('|'))) - for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) - inputs = ( xi.split('~') for xi in inputs ) - input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs) - genotypes.append( input_infos ) - return Structure( genotypes ) - - @staticmethod - def str2fullstructure(xstr, default_name='none'): - assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) - nodestrs = xstr.split('+') - genotypes = [] - for i, node_str in enumerate(nodestrs): - inputs = list(filter(lambda x: x != '', node_str.split('|'))) - for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) - inputs = ( xi.split('~') for xi in inputs ) - input_infos = list( (op, int(IDX)) for (op, IDX) in inputs) - all_in_nodes= list(x[1] for x in input_infos) - for j in range(i): - if j not in all_in_nodes: input_infos.append((default_name, j)) - node_info = sorted(input_infos, key=lambda x: (x[1], x[0])) - genotypes.append( tuple(node_info) ) - return Structure( genotypes ) - - @staticmethod - def gen_all(search_space, num, return_ori): - assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space)) - assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num) - all_archs = get_combination(search_space, 1) - for i, arch in enumerate(all_archs): - all_archs[i] = [ tuple(arch) ] - - for inode in range(2, num): - cur_nodes = get_combination(search_space, inode) - new_all_archs = [] - for previous_arch in all_archs: - for cur_node in cur_nodes: - new_all_archs.append( previous_arch + [tuple(cur_node)] ) - all_archs = new_all_archs - if return_ori: - return all_archs - else: - return [Structure(x) for x in all_archs] - - - -ResNet_CODE = Structure( - [(('nor_conv_3x3', 0), ), # node-1 - (('nor_conv_3x3', 1), ), # node-2 - (('skip_connect', 0), ('skip_connect', 2))] # node-3 - ) - -AllConv3x3_CODE = Structure( - [(('nor_conv_3x3', 0), ), # node-1 - (('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2 - (('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3 - ) - -AllFull_CODE = Structure( - [(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1 - (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2 - (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3 - ) - -AllConv1x1_CODE = Structure( - [(('nor_conv_1x1', 0), ), # node-1 - (('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2 - (('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3 - ) - -AllIdentity_CODE = Structure( - [(('skip_connect', 0), ), # node-1 - (('skip_connect', 0), ('skip_connect', 1)), # node-2 - (('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3 - ) - -architectures = {'resnet' : ResNet_CODE, - 'all_c3x3': AllConv3x3_CODE, - 'all_c1x1': AllConv1x1_CODE, - 'all_idnt': AllIdentity_CODE, - 'all_full': AllFull_CODE} diff --git a/lib/models/l2s_cell_searchs/search_cells.py b/lib/models/l2s_cell_searchs/search_cells.py deleted file mode 100644 index fba750f..0000000 --- a/lib/models/l2s_cell_searchs/search_cells.py +++ /dev/null @@ -1,148 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## -import math, random, torch -import warnings -import torch.nn as nn -import torch.nn.functional as F -from copy import deepcopy -from ..cell_operations import OPS - - -class SearchCell(nn.Module): - - def __init__(self, C_in, C_out, stride, max_nodes, op_names, n_piece): - super(SearchCell, self).__init__() - - self.op_names = deepcopy(op_names) - self.max_nodes = max_nodes - self.in_dim = C_in - self.out_dim = C_out - self.n_piece = n_piece - self.multi_edges = nn.ModuleList() - for i_piece in range(n_piece): - edges = nn.ModuleDict() - for i in range(1, max_nodes): - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - if j == 0: xlists = [OPS[op_name](C_in , C_out, stride) for op_name in op_names] - else : xlists = [OPS[op_name](C_in , C_out, 1) for op_name in op_names] - edges[ node_str ] = nn.ModuleList( xlists ) - self.multi_edges.append( edges ) - - self.edge_keys = sorted(list(edges.keys())) - self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} - self.num_edges = len(edges) - - def extra_repr(self): - string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}, nP={n_piece}'.format(**self.__dict__) - return string - - def forward(self, inputs, weightss): - nodes = [inputs] - with torch.no_grad(): - xmod, xid, argmax = 1, 0, weightss.argmax(dim=1).cpu().tolist() - for i, x in enumerate(argmax): - xid += x * (xmod % self.n_piece) - xmod = (xmod * len(self.op_names)) % self.n_piece - xid = xid % self.n_piece - edges = self.multi_edges[xid] - for i in range(1, self.max_nodes): - inter_nodes = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - weights = weightss[ self.edge2index[node_str] ] - inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(edges[node_str], weights) ) ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] - - # GDAS - def forward_gdas(self, inputs, alphas, _tau): - avoid_zero = 0 - while True: - gumbels = -torch.empty_like(alphas).exponential_().log() - logits = (alphas.log_softmax(dim=1) + gumbels) / _tau - probs = nn.functional.softmax(logits, dim=1) - index = probs.max(-1, keepdim=True)[1] - one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) - hardwts = one_h - probs.detach() + probs - if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): - continue # avoid the numerical error - nodes = [inputs] - for i in range(1, self.max_nodes): - inter_nodes = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - weights = hardwts[ self.edge2index[node_str] ] - argmaxs = index[ self.edge2index[node_str] ].item() - weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) - inter_nodes.append( weigsum ) - nodes.append( sum(inter_nodes) ) - avoid_zero += 1 - if nodes[-1].sum().item() == 0: - if avoid_zero < 10: continue - else: - warnings.warn('get zero outputs with avoid_zero={:}'.format(avoid_zero)) - break - else: - break - return nodes[-1] - - # joint - def forward_joint(self, inputs, weightss): - nodes = [inputs] - for i in range(1, self.max_nodes): - inter_nodes = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - weights = weightss[ self.edge2index[node_str] ] - #aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() - aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) - inter_nodes.append( aggregation ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] - - # uniform random sampling per iteration - def forward_urs(self, inputs): - nodes = [inputs] - for i in range(1, self.max_nodes): - while True: # to avoid select zero for all ops - sops, has_non_zero = [], False - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - candidates = self.edges[node_str] - select_op = random.choice(candidates) - sops.append( select_op ) - if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True - if has_non_zero: break - inter_nodes = [] - for j, select_op in enumerate(sops): - inter_nodes.append( select_op(nodes[j]) ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] - - # select the argmax - def forward_select(self, inputs, weightss): - nodes = [inputs] - for i in range(1, self.max_nodes): - inter_nodes = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - weights = weightss[ self.edge2index[node_str] ] - inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) ) - #inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] - - # forward with a specific structure - def forward_dynamic(self, inputs, structure): - nodes = [inputs] - for i in range(1, self.max_nodes): - cur_op_node = structure.nodes[i-1] - inter_nodes = [] - for op_name, j in cur_op_node: - node_str = '{:}<-{:}'.format(i, j) - op_index = self.op_names.index( op_name ) - inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) ) - nodes.append( sum(inter_nodes) ) - return nodes[-1] diff --git a/lib/models/l2s_cell_searchs/search_model_darts_v1.py b/lib/models/l2s_cell_searchs/search_model_darts_v1.py deleted file mode 100644 index ffc381e..0000000 --- a/lib/models/l2s_cell_searchs/search_model_darts_v1.py +++ /dev/null @@ -1,93 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -######################################################## -# DARTS: Differentiable Architecture Search, ICLR 2019 # -######################################################## -import torch -import torch.nn as nn -from copy import deepcopy -from ..cell_operations import ResNetBasicblock -from .search_cells import SearchCell -from .genotypes import Structure - - -class TinyNetworkDartsV1(nn.Module): - - def __init__(self, C, N, max_nodes, num_classes, search_space, n_piece): - super(TinyNetworkDartsV1, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, n_piece) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist - - def get_alphas(self): - return [self.arch_parameters] - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.op_names[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) - - def forward(self, inputs): - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) - - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell(feature, alphas) - else: - feature = cell(feature) - - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - - return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_darts_v2.py b/lib/models/l2s_cell_searchs/search_model_darts_v2.py deleted file mode 100644 index cb996ff..0000000 --- a/lib/models/l2s_cell_searchs/search_model_darts_v2.py +++ /dev/null @@ -1,93 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -######################################################## -# DARTS: Differentiable Architecture Search, ICLR 2019 # -######################################################## -import torch -import torch.nn as nn -from copy import deepcopy -from ..cell_operations import ResNetBasicblock -from .search_cells import SearchCell -from .genotypes import Structure - - -class TinyNetworkDartsV2(nn.Module): - - def __init__(self, C, N, max_nodes, num_classes, search_space): - super(TinyNetworkDartsV2, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist - - def get_alphas(self): - return [self.arch_parameters] - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.op_names[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) - - def forward(self, inputs): - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) - - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell(feature, alphas) - else: - feature = cell(feature) - - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - - return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_enas.py b/lib/models/l2s_cell_searchs/search_model_enas.py deleted file mode 100644 index 2422b52..0000000 --- a/lib/models/l2s_cell_searchs/search_model_enas.py +++ /dev/null @@ -1,94 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -########################################################################## -# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # -########################################################################## -import torch -import torch.nn as nn -from copy import deepcopy -from ..cell_operations import ResNetBasicblock -from .search_cells import SearchCell -from .genotypes import Structure -from .search_model_enas_utils import Controller - - -class TinyNetworkENAS(nn.Module): - - def __init__(self, C, N, max_nodes, num_classes, search_space): - super(TinyNetworkENAS, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - # to maintain the sampled architecture - self.sampled_arch = None - - def update_arch(self, _arch): - if _arch is None: - self.sampled_arch = None - elif isinstance(_arch, Structure): - self.sampled_arch = _arch - elif isinstance(_arch, (list, tuple)): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_index = _arch[ self.edge2index[node_str] ] - op_name = self.op_names[ op_index ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - self.sampled_arch = Structure(genotypes) - else: - raise ValueError('invalid type of input architecture : {:}'.format(_arch)) - return self.sampled_arch - - def create_controller(self): - return Controller(len(self.edge2index), len(self.op_names)) - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def forward(self, inputs): - - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell.forward_dynamic(feature, self.sampled_arch) - else: feature = cell(feature) - - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - - return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_enas_utils.py b/lib/models/l2s_cell_searchs/search_model_enas_utils.py deleted file mode 100644 index e03f57b..0000000 --- a/lib/models/l2s_cell_searchs/search_model_enas_utils.py +++ /dev/null @@ -1,55 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -########################################################################## -# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # -########################################################################## -import torch -import torch.nn as nn -from torch.distributions.categorical import Categorical - -class Controller(nn.Module): - # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py - def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0): - super(Controller, self).__init__() - # assign the attributes - self.num_edge = num_edge - self.num_ops = num_ops - self.lstm_size = lstm_size - self.lstm_N = lstm_num_layers - self.tanh_constant = tanh_constant - self.temperature = temperature - # create parameters - self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size))) - self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N) - self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) - self.w_pred = nn.Linear(self.lstm_size, self.num_ops) - - nn.init.uniform_(self.input_vars , -0.1, 0.1) - nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) - nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) - nn.init.uniform_(self.w_embd.weight , -0.1, 0.1) - nn.init.uniform_(self.w_pred.weight , -0.1, 0.1) - - def forward(self): - - inputs, h0 = self.input_vars, None - log_probs, entropys, sampled_arch = [], [], [] - for iedge in range(self.num_edge): - outputs, h0 = self.w_lstm(inputs, h0) - - logits = self.w_pred(outputs) - logits = logits / self.temperature - logits = self.tanh_constant * torch.tanh(logits) - # distribution - op_distribution = Categorical(logits=logits) - op_index = op_distribution.sample() - sampled_arch.append( op_index.item() ) - - op_log_prob = op_distribution.log_prob(op_index) - log_probs.append( op_log_prob.view(-1) ) - op_entropy = op_distribution.entropy() - entropys.append( op_entropy.view(-1) ) - - # obtain the input embedding for the next step - inputs = self.w_embd(op_index) - return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch diff --git a/lib/models/l2s_cell_searchs/search_model_gdas.py b/lib/models/l2s_cell_searchs/search_model_gdas.py deleted file mode 100644 index 6a4dd4e..0000000 --- a/lib/models/l2s_cell_searchs/search_model_gdas.py +++ /dev/null @@ -1,96 +0,0 @@ -########################################################################### -# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # -########################################################################### -import torch -import torch.nn as nn -from copy import deepcopy -from ..cell_operations import ResNetBasicblock -from .search_cells import SearchCell -from .genotypes import Structure - - -class TinyNetworkGDAS(nn.Module): - - def __init__(self, C, N, max_nodes, num_classes, search_space): - super(TinyNetworkGDAS, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.tau = 10 - - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist - - def set_tau(self, tau): - self.tau = tau - - def get_tau(self): - return self.tau - - def get_alphas(self): - return [self.arch_parameters] - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.op_names[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) - - def forward(self, inputs): - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell.forward_gdas(feature, self.arch_parameters, self.tau) - else: - feature = cell(feature) - - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - - return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_random.py b/lib/models/l2s_cell_searchs/search_model_random.py deleted file mode 100644 index c2f83f9..0000000 --- a/lib/models/l2s_cell_searchs/search_model_random.py +++ /dev/null @@ -1,81 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -############################################################################## -# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 # -############################################################################## -import torch, random -import torch.nn as nn -from copy import deepcopy -from ..cell_operations import ResNetBasicblock -from .search_cells import SearchCell -from .genotypes import Structure - - -class TinyNetworkRANDOM(nn.Module): - - def __init__(self, C, N, max_nodes, num_classes, search_space): - super(TinyNetworkRANDOM, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_cache = None - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def random_genotype(self, set_cache): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - op_name = random.choice( self.op_names ) - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - arch = Structure( genotypes ) - if set_cache: self.arch_cache = arch - return arch - - def forward(self, inputs): - - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - feature = cell.forward_dynamic(feature, self.arch_cache) - else: feature = cell(feature) - - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_setn.py b/lib/models/l2s_cell_searchs/search_model_setn.py deleted file mode 100644 index 5864f32..0000000 --- a/lib/models/l2s_cell_searchs/search_model_setn.py +++ /dev/null @@ -1,152 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -###################################################################################### -# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # -###################################################################################### -import torch, random -import torch.nn as nn -from copy import deepcopy -from ..cell_operations import ResNetBasicblock -from .search_cells import SearchCell -from .genotypes import Structure - - -class TinyNetworkSETN(nn.Module): - - def __init__(self, C, N, max_nodes, num_classes, search_space): - super(TinyNetworkSETN, self).__init__() - self._C = C - self._layerN = N - self.max_nodes = max_nodes - self.stem = nn.Sequential( - nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(C)) - - layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N - layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N - - C_prev, num_edge, edge2index = C, None, None - self.cells = nn.ModuleList() - for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): - if reduction: - cell = ResNetBasicblock(C_prev, C_curr, 2) - else: - cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) - if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index - else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) - self.cells.append( cell ) - C_prev = cell.out_dim - self.op_names = deepcopy( search_space ) - self._Layer = len(self.cells) - self.edge2index = edge2index - self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) - self.global_pooling = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Linear(C_prev, num_classes) - self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) - self.mode = 'urs' - self.dynamic_cell = None - - def set_cal_mode(self, mode, dynamic_cell=None): - assert mode in ['urs', 'joint', 'select', 'dynamic'] - self.mode = mode - if mode == 'dynamic': self.dynamic_cell = deepcopy( dynamic_cell ) - else : self.dynamic_cell = None - - def get_cal_mode(self): - return self.mode - - def get_weights(self): - xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) - xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) - xlist+= list( self.classifier.parameters() ) - return xlist - - def get_alphas(self): - return [self.arch_parameters] - - def get_message(self): - string = self.extra_repr() - for i, cell in enumerate(self.cells): - string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) - return string - - def extra_repr(self): - return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) - - def genotype(self): - genotypes = [] - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - with torch.no_grad(): - weights = self.arch_parameters[ self.edge2index[node_str] ] - op_name = self.op_names[ weights.argmax().item() ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) - - def dync_genotype(self, use_random=False): - genotypes = [] - with torch.no_grad(): - alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) - for i in range(1, self.max_nodes): - xlist = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - if use_random: - op_name = random.choice(self.op_names) - else: - weights = alphas_cpu[ self.edge2index[node_str] ] - op_index = torch.multinomial(weights, 1).item() - op_name = self.op_names[ op_index ] - xlist.append((op_name, j)) - genotypes.append( tuple(xlist) ) - return Structure( genotypes ) - - def get_log_prob(self, arch): - with torch.no_grad(): - logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) - select_logits = [] - for i, node_info in enumerate(arch.nodes): - for op, xin in node_info: - node_str = '{:}<-{:}'.format(i+1, xin) - op_index = self.op_names.index(op) - select_logits.append( logits[self.edge2index[node_str], op_index] ) - return sum(select_logits).item() - - - def return_topK(self, K): - archs = Structure.gen_all(self.op_names, self.max_nodes, False) - pairs = [(self.get_log_prob(arch), arch) for arch in archs] - if K < 0 or K >= len(archs): K = len(archs) - sorted_pairs = sorted(pairs, key=lambda x: -x[0]) - return_pairs = [sorted_pairs[_][1] for _ in range(K)] - return return_pairs - - - def forward(self, inputs): - alphas = nn.functional.softmax(self.arch_parameters, dim=-1) - with torch.no_grad(): - alphas_cpu = alphas.detach().cpu() - - feature = self.stem(inputs) - for i, cell in enumerate(self.cells): - if isinstance(cell, SearchCell): - if self.mode == 'urs': - feature = cell.forward_urs(feature) - elif self.mode == 'select': - feature = cell.forward_select(feature, alphas_cpu) - elif self.mode == 'joint': - feature = cell.forward_joint(feature, alphas) - elif self.mode == 'dynamic': - feature = cell.forward_dynamic(feature, self.dynamic_cell) - else: raise ValueError('invalid mode={:}'.format(self.mode)) - else: feature = cell(feature) - - out = self.lastact(feature) - out = self.global_pooling( out ) - out = out.view(out.size(0), -1) - logits = self.classifier(out) - - return out, logits