update GDAS
This commit is contained in:
parent
09d68c6375
commit
b6c0828382
@ -88,7 +88,9 @@ class TinyNetworkGDAS(nn.Module):
|
|||||||
index = probs.max(-1, keepdim=True)[1]
|
index = probs.max(-1, keepdim=True)[1]
|
||||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||||
hardwts = one_h - probs.detach() + probs
|
hardwts = one_h - probs.detach() + probs
|
||||||
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): continue
|
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
|
||||||
|
continue
|
||||||
|
else: break
|
||||||
|
|
||||||
feature = self.stem(inputs)
|
feature = self.stem(inputs)
|
||||||
for i, cell in enumerate(self.cells):
|
for i, cell in enumerate(self.cells):
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
##################################################
|
|
||||||
from .search_model_darts_v1 import TinyNetworkDartsV1
|
|
||||||
from .search_model_darts_v2 import TinyNetworkDartsV2
|
|
||||||
from .search_model_gdas import TinyNetworkGDAS
|
|
||||||
from .search_model_setn import TinyNetworkSETN
|
|
||||||
from .search_model_enas import TinyNetworkENAS
|
|
||||||
from .search_model_random import TinyNetworkRANDOM
|
|
||||||
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
|
|
||||||
|
|
||||||
nas_super_nets = {'DARTS-V1': TinyNetworkDartsV1,
|
|
||||||
'DARTS-V2': TinyNetworkDartsV2,
|
|
||||||
'GDAS' : TinyNetworkGDAS,
|
|
||||||
'SETN' : TinyNetworkSETN,
|
|
||||||
'ENAS' : TinyNetworkENAS,
|
|
||||||
'RANDOM' : TinyNetworkRANDOM}
|
|
@ -1,12 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
##################################################
|
|
||||||
import torch
|
|
||||||
from search_model_enas_utils import Controller
|
|
||||||
|
|
||||||
def main():
|
|
||||||
controller = Controller(6, 4)
|
|
||||||
predictions = controller()
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,197 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
##################################################
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_combination(space, num):
|
|
||||||
combs = []
|
|
||||||
for i in range(num):
|
|
||||||
if i == 0:
|
|
||||||
for func in space:
|
|
||||||
combs.append( [(func, i)] )
|
|
||||||
else:
|
|
||||||
new_combs = []
|
|
||||||
for string in combs:
|
|
||||||
for func in space:
|
|
||||||
xstring = string + [(func, i)]
|
|
||||||
new_combs.append( xstring )
|
|
||||||
combs = new_combs
|
|
||||||
return combs
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Structure:
|
|
||||||
|
|
||||||
def __init__(self, genotype):
|
|
||||||
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
|
|
||||||
self.node_num = len(genotype) + 1
|
|
||||||
self.nodes = []
|
|
||||||
self.node_N = []
|
|
||||||
for idx, node_info in enumerate(genotype):
|
|
||||||
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
|
|
||||||
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
|
|
||||||
for node_in in node_info:
|
|
||||||
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
|
|
||||||
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
|
|
||||||
self.node_N.append( len(node_info) )
|
|
||||||
self.nodes.append( tuple(deepcopy(node_info)) )
|
|
||||||
|
|
||||||
def tolist(self, remove_str):
|
|
||||||
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
|
|
||||||
# note that we re-order the input node in this function
|
|
||||||
# return the-genotype-list and success [if unsuccess, it is not a connectivity]
|
|
||||||
genotypes = []
|
|
||||||
for node_info in self.nodes:
|
|
||||||
node_info = list( node_info )
|
|
||||||
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
|
|
||||||
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
|
|
||||||
if len(node_info) == 0: return None, False
|
|
||||||
genotypes.append( node_info )
|
|
||||||
return genotypes, True
|
|
||||||
|
|
||||||
def node(self, index):
|
|
||||||
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
|
|
||||||
return self.nodes[index]
|
|
||||||
|
|
||||||
def tostr(self):
|
|
||||||
strings = []
|
|
||||||
for node_info in self.nodes:
|
|
||||||
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
|
|
||||||
string = '|{:}|'.format(string)
|
|
||||||
strings.append( string )
|
|
||||||
return '+'.join(strings)
|
|
||||||
|
|
||||||
def check_valid(self):
|
|
||||||
nodes = {0: True}
|
|
||||||
for i, node_info in enumerate(self.nodes):
|
|
||||||
sums = []
|
|
||||||
for op, xin in node_info:
|
|
||||||
if op == 'none' or nodes[xin] == False: x = False
|
|
||||||
else: x = True
|
|
||||||
sums.append( x )
|
|
||||||
nodes[i+1] = sum(sums) > 0
|
|
||||||
return nodes[len(self.nodes)]
|
|
||||||
|
|
||||||
def to_unique_str(self, consider_zero=False):
|
|
||||||
# this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation
|
|
||||||
# two operations are special, i.e., none and skip_connect
|
|
||||||
nodes = {0: '0'}
|
|
||||||
for i_node, node_info in enumerate(self.nodes):
|
|
||||||
cur_node = []
|
|
||||||
for op, xin in node_info:
|
|
||||||
if consider_zero:
|
|
||||||
if op == 'none' or nodes[xin] == '#': x = '#' # zero
|
|
||||||
elif op == 'skip_connect': x = nodes[xin]
|
|
||||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
|
||||||
else:
|
|
||||||
if op == 'skip_connect': x = nodes[xin]
|
|
||||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
|
||||||
cur_node.append(x)
|
|
||||||
nodes[i_node+1] = '+'.join( sorted(cur_node) )
|
|
||||||
return nodes[ len(self.nodes) ]
|
|
||||||
|
|
||||||
def check_valid_op(self, op_names):
|
|
||||||
for node_info in self.nodes:
|
|
||||||
for inode_edge in node_info:
|
|
||||||
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
|
|
||||||
if inode_edge[0] not in op_names: return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.nodes) + 1
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
return self.nodes[index]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def str2structure(xstr):
|
|
||||||
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
|
||||||
nodestrs = xstr.split('+')
|
|
||||||
genotypes = []
|
|
||||||
for i, node_str in enumerate(nodestrs):
|
|
||||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
|
||||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
|
||||||
inputs = ( xi.split('~') for xi in inputs )
|
|
||||||
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
|
|
||||||
genotypes.append( input_infos )
|
|
||||||
return Structure( genotypes )
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def str2fullstructure(xstr, default_name='none'):
|
|
||||||
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
|
||||||
nodestrs = xstr.split('+')
|
|
||||||
genotypes = []
|
|
||||||
for i, node_str in enumerate(nodestrs):
|
|
||||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
|
||||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
|
||||||
inputs = ( xi.split('~') for xi in inputs )
|
|
||||||
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
|
|
||||||
all_in_nodes= list(x[1] for x in input_infos)
|
|
||||||
for j in range(i):
|
|
||||||
if j not in all_in_nodes: input_infos.append((default_name, j))
|
|
||||||
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
|
|
||||||
genotypes.append( tuple(node_info) )
|
|
||||||
return Structure( genotypes )
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def gen_all(search_space, num, return_ori):
|
|
||||||
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
|
|
||||||
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
|
|
||||||
all_archs = get_combination(search_space, 1)
|
|
||||||
for i, arch in enumerate(all_archs):
|
|
||||||
all_archs[i] = [ tuple(arch) ]
|
|
||||||
|
|
||||||
for inode in range(2, num):
|
|
||||||
cur_nodes = get_combination(search_space, inode)
|
|
||||||
new_all_archs = []
|
|
||||||
for previous_arch in all_archs:
|
|
||||||
for cur_node in cur_nodes:
|
|
||||||
new_all_archs.append( previous_arch + [tuple(cur_node)] )
|
|
||||||
all_archs = new_all_archs
|
|
||||||
if return_ori:
|
|
||||||
return all_archs
|
|
||||||
else:
|
|
||||||
return [Structure(x) for x in all_archs]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ResNet_CODE = Structure(
|
|
||||||
[(('nor_conv_3x3', 0), ), # node-1
|
|
||||||
(('nor_conv_3x3', 1), ), # node-2
|
|
||||||
(('skip_connect', 0), ('skip_connect', 2))] # node-3
|
|
||||||
)
|
|
||||||
|
|
||||||
AllConv3x3_CODE = Structure(
|
|
||||||
[(('nor_conv_3x3', 0), ), # node-1
|
|
||||||
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
|
|
||||||
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
|
|
||||||
)
|
|
||||||
|
|
||||||
AllFull_CODE = Structure(
|
|
||||||
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
|
|
||||||
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
|
|
||||||
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
|
|
||||||
)
|
|
||||||
|
|
||||||
AllConv1x1_CODE = Structure(
|
|
||||||
[(('nor_conv_1x1', 0), ), # node-1
|
|
||||||
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
|
|
||||||
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
|
|
||||||
)
|
|
||||||
|
|
||||||
AllIdentity_CODE = Structure(
|
|
||||||
[(('skip_connect', 0), ), # node-1
|
|
||||||
(('skip_connect', 0), ('skip_connect', 1)), # node-2
|
|
||||||
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
|
|
||||||
)
|
|
||||||
|
|
||||||
architectures = {'resnet' : ResNet_CODE,
|
|
||||||
'all_c3x3': AllConv3x3_CODE,
|
|
||||||
'all_c1x1': AllConv1x1_CODE,
|
|
||||||
'all_idnt': AllIdentity_CODE,
|
|
||||||
'all_full': AllFull_CODE}
|
|
@ -1,148 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
##################################################
|
|
||||||
import math, random, torch
|
|
||||||
import warnings
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from copy import deepcopy
|
|
||||||
from ..cell_operations import OPS
|
|
||||||
|
|
||||||
|
|
||||||
class SearchCell(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names, n_piece):
|
|
||||||
super(SearchCell, self).__init__()
|
|
||||||
|
|
||||||
self.op_names = deepcopy(op_names)
|
|
||||||
self.max_nodes = max_nodes
|
|
||||||
self.in_dim = C_in
|
|
||||||
self.out_dim = C_out
|
|
||||||
self.n_piece = n_piece
|
|
||||||
self.multi_edges = nn.ModuleList()
|
|
||||||
for i_piece in range(n_piece):
|
|
||||||
edges = nn.ModuleDict()
|
|
||||||
for i in range(1, max_nodes):
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
if j == 0: xlists = [OPS[op_name](C_in , C_out, stride) for op_name in op_names]
|
|
||||||
else : xlists = [OPS[op_name](C_in , C_out, 1) for op_name in op_names]
|
|
||||||
edges[ node_str ] = nn.ModuleList( xlists )
|
|
||||||
self.multi_edges.append( edges )
|
|
||||||
|
|
||||||
self.edge_keys = sorted(list(edges.keys()))
|
|
||||||
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
|
||||||
self.num_edges = len(edges)
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}, nP={n_piece}'.format(**self.__dict__)
|
|
||||||
return string
|
|
||||||
|
|
||||||
def forward(self, inputs, weightss):
|
|
||||||
nodes = [inputs]
|
|
||||||
with torch.no_grad():
|
|
||||||
xmod, xid, argmax = 1, 0, weightss.argmax(dim=1).cpu().tolist()
|
|
||||||
for i, x in enumerate(argmax):
|
|
||||||
xid += x * (xmod % self.n_piece)
|
|
||||||
xmod = (xmod * len(self.op_names)) % self.n_piece
|
|
||||||
xid = xid % self.n_piece
|
|
||||||
edges = self.multi_edges[xid]
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
inter_nodes = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
weights = weightss[ self.edge2index[node_str] ]
|
|
||||||
inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(edges[node_str], weights) ) )
|
|
||||||
nodes.append( sum(inter_nodes) )
|
|
||||||
return nodes[-1]
|
|
||||||
|
|
||||||
# GDAS
|
|
||||||
def forward_gdas(self, inputs, alphas, _tau):
|
|
||||||
avoid_zero = 0
|
|
||||||
while True:
|
|
||||||
gumbels = -torch.empty_like(alphas).exponential_().log()
|
|
||||||
logits = (alphas.log_softmax(dim=1) + gumbels) / _tau
|
|
||||||
probs = nn.functional.softmax(logits, dim=1)
|
|
||||||
index = probs.max(-1, keepdim=True)[1]
|
|
||||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
|
||||||
hardwts = one_h - probs.detach() + probs
|
|
||||||
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
|
|
||||||
continue # avoid the numerical error
|
|
||||||
nodes = [inputs]
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
inter_nodes = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
weights = hardwts[ self.edge2index[node_str] ]
|
|
||||||
argmaxs = index[ self.edge2index[node_str] ].item()
|
|
||||||
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
|
|
||||||
inter_nodes.append( weigsum )
|
|
||||||
nodes.append( sum(inter_nodes) )
|
|
||||||
avoid_zero += 1
|
|
||||||
if nodes[-1].sum().item() == 0:
|
|
||||||
if avoid_zero < 10: continue
|
|
||||||
else:
|
|
||||||
warnings.warn('get zero outputs with avoid_zero={:}'.format(avoid_zero))
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
return nodes[-1]
|
|
||||||
|
|
||||||
# joint
|
|
||||||
def forward_joint(self, inputs, weightss):
|
|
||||||
nodes = [inputs]
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
inter_nodes = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
weights = weightss[ self.edge2index[node_str] ]
|
|
||||||
#aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
|
|
||||||
aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) )
|
|
||||||
inter_nodes.append( aggregation )
|
|
||||||
nodes.append( sum(inter_nodes) )
|
|
||||||
return nodes[-1]
|
|
||||||
|
|
||||||
# uniform random sampling per iteration
|
|
||||||
def forward_urs(self, inputs):
|
|
||||||
nodes = [inputs]
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
while True: # to avoid select zero for all ops
|
|
||||||
sops, has_non_zero = [], False
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
candidates = self.edges[node_str]
|
|
||||||
select_op = random.choice(candidates)
|
|
||||||
sops.append( select_op )
|
|
||||||
if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True
|
|
||||||
if has_non_zero: break
|
|
||||||
inter_nodes = []
|
|
||||||
for j, select_op in enumerate(sops):
|
|
||||||
inter_nodes.append( select_op(nodes[j]) )
|
|
||||||
nodes.append( sum(inter_nodes) )
|
|
||||||
return nodes[-1]
|
|
||||||
|
|
||||||
# select the argmax
|
|
||||||
def forward_select(self, inputs, weightss):
|
|
||||||
nodes = [inputs]
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
inter_nodes = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
weights = weightss[ self.edge2index[node_str] ]
|
|
||||||
inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) )
|
|
||||||
#inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
|
|
||||||
nodes.append( sum(inter_nodes) )
|
|
||||||
return nodes[-1]
|
|
||||||
|
|
||||||
# forward with a specific structure
|
|
||||||
def forward_dynamic(self, inputs, structure):
|
|
||||||
nodes = [inputs]
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
cur_op_node = structure.nodes[i-1]
|
|
||||||
inter_nodes = []
|
|
||||||
for op_name, j in cur_op_node:
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
op_index = self.op_names.index( op_name )
|
|
||||||
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
|
|
||||||
nodes.append( sum(inter_nodes) )
|
|
||||||
return nodes[-1]
|
|
@ -1,93 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
########################################################
|
|
||||||
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
|
||||||
########################################################
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from copy import deepcopy
|
|
||||||
from ..cell_operations import ResNetBasicblock
|
|
||||||
from .search_cells import SearchCell
|
|
||||||
from .genotypes import Structure
|
|
||||||
|
|
||||||
|
|
||||||
class TinyNetworkDartsV1(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, C, N, max_nodes, num_classes, search_space, n_piece):
|
|
||||||
super(TinyNetworkDartsV1, self).__init__()
|
|
||||||
self._C = C
|
|
||||||
self._layerN = N
|
|
||||||
self.max_nodes = max_nodes
|
|
||||||
self.stem = nn.Sequential(
|
|
||||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
|
||||||
nn.BatchNorm2d(C))
|
|
||||||
|
|
||||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
|
||||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
|
||||||
|
|
||||||
C_prev, num_edge, edge2index = C, None, None
|
|
||||||
self.cells = nn.ModuleList()
|
|
||||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
|
||||||
if reduction:
|
|
||||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
|
||||||
else:
|
|
||||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, n_piece)
|
|
||||||
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
|
||||||
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
|
||||||
self.cells.append( cell )
|
|
||||||
C_prev = cell.out_dim
|
|
||||||
self.op_names = deepcopy( search_space )
|
|
||||||
self._Layer = len(self.cells)
|
|
||||||
self.edge2index = edge2index
|
|
||||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
|
||||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.classifier = nn.Linear(C_prev, num_classes)
|
|
||||||
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
|
||||||
|
|
||||||
def get_weights(self):
|
|
||||||
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
|
||||||
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
|
|
||||||
xlist+= list( self.classifier.parameters() )
|
|
||||||
return xlist
|
|
||||||
|
|
||||||
def get_alphas(self):
|
|
||||||
return [self.arch_parameters]
|
|
||||||
|
|
||||||
def get_message(self):
|
|
||||||
string = self.extra_repr()
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
|
||||||
return string
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
|
||||||
|
|
||||||
def genotype(self):
|
|
||||||
genotypes = []
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
xlist = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
with torch.no_grad():
|
|
||||||
weights = self.arch_parameters[ self.edge2index[node_str] ]
|
|
||||||
op_name = self.op_names[ weights.argmax().item() ]
|
|
||||||
xlist.append((op_name, j))
|
|
||||||
genotypes.append( tuple(xlist) )
|
|
||||||
return Structure( genotypes )
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
|
||||||
|
|
||||||
feature = self.stem(inputs)
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
if isinstance(cell, SearchCell):
|
|
||||||
feature = cell(feature, alphas)
|
|
||||||
else:
|
|
||||||
feature = cell(feature)
|
|
||||||
|
|
||||||
out = self.lastact(feature)
|
|
||||||
out = self.global_pooling( out )
|
|
||||||
out = out.view(out.size(0), -1)
|
|
||||||
logits = self.classifier(out)
|
|
||||||
|
|
||||||
return out, logits
|
|
@ -1,93 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
########################################################
|
|
||||||
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
|
||||||
########################################################
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from copy import deepcopy
|
|
||||||
from ..cell_operations import ResNetBasicblock
|
|
||||||
from .search_cells import SearchCell
|
|
||||||
from .genotypes import Structure
|
|
||||||
|
|
||||||
|
|
||||||
class TinyNetworkDartsV2(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
|
||||||
super(TinyNetworkDartsV2, self).__init__()
|
|
||||||
self._C = C
|
|
||||||
self._layerN = N
|
|
||||||
self.max_nodes = max_nodes
|
|
||||||
self.stem = nn.Sequential(
|
|
||||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
|
||||||
nn.BatchNorm2d(C))
|
|
||||||
|
|
||||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
|
||||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
|
||||||
|
|
||||||
C_prev, num_edge, edge2index = C, None, None
|
|
||||||
self.cells = nn.ModuleList()
|
|
||||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
|
||||||
if reduction:
|
|
||||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
|
||||||
else:
|
|
||||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
|
||||||
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
|
||||||
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
|
||||||
self.cells.append( cell )
|
|
||||||
C_prev = cell.out_dim
|
|
||||||
self.op_names = deepcopy( search_space )
|
|
||||||
self._Layer = len(self.cells)
|
|
||||||
self.edge2index = edge2index
|
|
||||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
|
||||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.classifier = nn.Linear(C_prev, num_classes)
|
|
||||||
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
|
||||||
|
|
||||||
def get_weights(self):
|
|
||||||
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
|
||||||
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
|
|
||||||
xlist+= list( self.classifier.parameters() )
|
|
||||||
return xlist
|
|
||||||
|
|
||||||
def get_alphas(self):
|
|
||||||
return [self.arch_parameters]
|
|
||||||
|
|
||||||
def get_message(self):
|
|
||||||
string = self.extra_repr()
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
|
||||||
return string
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
|
||||||
|
|
||||||
def genotype(self):
|
|
||||||
genotypes = []
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
xlist = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
with torch.no_grad():
|
|
||||||
weights = self.arch_parameters[ self.edge2index[node_str] ]
|
|
||||||
op_name = self.op_names[ weights.argmax().item() ]
|
|
||||||
xlist.append((op_name, j))
|
|
||||||
genotypes.append( tuple(xlist) )
|
|
||||||
return Structure( genotypes )
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
|
||||||
|
|
||||||
feature = self.stem(inputs)
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
if isinstance(cell, SearchCell):
|
|
||||||
feature = cell(feature, alphas)
|
|
||||||
else:
|
|
||||||
feature = cell(feature)
|
|
||||||
|
|
||||||
out = self.lastact(feature)
|
|
||||||
out = self.global_pooling( out )
|
|
||||||
out = out.view(out.size(0), -1)
|
|
||||||
logits = self.classifier(out)
|
|
||||||
|
|
||||||
return out, logits
|
|
@ -1,94 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
##########################################################################
|
|
||||||
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
|
|
||||||
##########################################################################
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from copy import deepcopy
|
|
||||||
from ..cell_operations import ResNetBasicblock
|
|
||||||
from .search_cells import SearchCell
|
|
||||||
from .genotypes import Structure
|
|
||||||
from .search_model_enas_utils import Controller
|
|
||||||
|
|
||||||
|
|
||||||
class TinyNetworkENAS(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
|
||||||
super(TinyNetworkENAS, self).__init__()
|
|
||||||
self._C = C
|
|
||||||
self._layerN = N
|
|
||||||
self.max_nodes = max_nodes
|
|
||||||
self.stem = nn.Sequential(
|
|
||||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
|
||||||
nn.BatchNorm2d(C))
|
|
||||||
|
|
||||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
|
||||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
|
||||||
|
|
||||||
C_prev, num_edge, edge2index = C, None, None
|
|
||||||
self.cells = nn.ModuleList()
|
|
||||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
|
||||||
if reduction:
|
|
||||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
|
||||||
else:
|
|
||||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
|
||||||
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
|
||||||
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
|
||||||
self.cells.append( cell )
|
|
||||||
C_prev = cell.out_dim
|
|
||||||
self.op_names = deepcopy( search_space )
|
|
||||||
self._Layer = len(self.cells)
|
|
||||||
self.edge2index = edge2index
|
|
||||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
|
||||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.classifier = nn.Linear(C_prev, num_classes)
|
|
||||||
# to maintain the sampled architecture
|
|
||||||
self.sampled_arch = None
|
|
||||||
|
|
||||||
def update_arch(self, _arch):
|
|
||||||
if _arch is None:
|
|
||||||
self.sampled_arch = None
|
|
||||||
elif isinstance(_arch, Structure):
|
|
||||||
self.sampled_arch = _arch
|
|
||||||
elif isinstance(_arch, (list, tuple)):
|
|
||||||
genotypes = []
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
xlist = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
op_index = _arch[ self.edge2index[node_str] ]
|
|
||||||
op_name = self.op_names[ op_index ]
|
|
||||||
xlist.append((op_name, j))
|
|
||||||
genotypes.append( tuple(xlist) )
|
|
||||||
self.sampled_arch = Structure(genotypes)
|
|
||||||
else:
|
|
||||||
raise ValueError('invalid type of input architecture : {:}'.format(_arch))
|
|
||||||
return self.sampled_arch
|
|
||||||
|
|
||||||
def create_controller(self):
|
|
||||||
return Controller(len(self.edge2index), len(self.op_names))
|
|
||||||
|
|
||||||
def get_message(self):
|
|
||||||
string = self.extra_repr()
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
|
||||||
return string
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
|
|
||||||
feature = self.stem(inputs)
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
if isinstance(cell, SearchCell):
|
|
||||||
feature = cell.forward_dynamic(feature, self.sampled_arch)
|
|
||||||
else: feature = cell(feature)
|
|
||||||
|
|
||||||
out = self.lastact(feature)
|
|
||||||
out = self.global_pooling( out )
|
|
||||||
out = out.view(out.size(0), -1)
|
|
||||||
logits = self.classifier(out)
|
|
||||||
|
|
||||||
return out, logits
|
|
@ -1,55 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
##########################################################################
|
|
||||||
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
|
|
||||||
##########################################################################
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.distributions.categorical import Categorical
|
|
||||||
|
|
||||||
class Controller(nn.Module):
|
|
||||||
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
|
|
||||||
def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
|
|
||||||
super(Controller, self).__init__()
|
|
||||||
# assign the attributes
|
|
||||||
self.num_edge = num_edge
|
|
||||||
self.num_ops = num_ops
|
|
||||||
self.lstm_size = lstm_size
|
|
||||||
self.lstm_N = lstm_num_layers
|
|
||||||
self.tanh_constant = tanh_constant
|
|
||||||
self.temperature = temperature
|
|
||||||
# create parameters
|
|
||||||
self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
|
|
||||||
self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
|
|
||||||
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
|
|
||||||
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
|
|
||||||
|
|
||||||
nn.init.uniform_(self.input_vars , -0.1, 0.1)
|
|
||||||
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
|
|
||||||
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
|
|
||||||
nn.init.uniform_(self.w_embd.weight , -0.1, 0.1)
|
|
||||||
nn.init.uniform_(self.w_pred.weight , -0.1, 0.1)
|
|
||||||
|
|
||||||
def forward(self):
|
|
||||||
|
|
||||||
inputs, h0 = self.input_vars, None
|
|
||||||
log_probs, entropys, sampled_arch = [], [], []
|
|
||||||
for iedge in range(self.num_edge):
|
|
||||||
outputs, h0 = self.w_lstm(inputs, h0)
|
|
||||||
|
|
||||||
logits = self.w_pred(outputs)
|
|
||||||
logits = logits / self.temperature
|
|
||||||
logits = self.tanh_constant * torch.tanh(logits)
|
|
||||||
# distribution
|
|
||||||
op_distribution = Categorical(logits=logits)
|
|
||||||
op_index = op_distribution.sample()
|
|
||||||
sampled_arch.append( op_index.item() )
|
|
||||||
|
|
||||||
op_log_prob = op_distribution.log_prob(op_index)
|
|
||||||
log_probs.append( op_log_prob.view(-1) )
|
|
||||||
op_entropy = op_distribution.entropy()
|
|
||||||
entropys.append( op_entropy.view(-1) )
|
|
||||||
|
|
||||||
# obtain the input embedding for the next step
|
|
||||||
inputs = self.w_embd(op_index)
|
|
||||||
return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch
|
|
@ -1,96 +0,0 @@
|
|||||||
###########################################################################
|
|
||||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
|
||||||
###########################################################################
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from copy import deepcopy
|
|
||||||
from ..cell_operations import ResNetBasicblock
|
|
||||||
from .search_cells import SearchCell
|
|
||||||
from .genotypes import Structure
|
|
||||||
|
|
||||||
|
|
||||||
class TinyNetworkGDAS(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
|
||||||
super(TinyNetworkGDAS, self).__init__()
|
|
||||||
self._C = C
|
|
||||||
self._layerN = N
|
|
||||||
self.max_nodes = max_nodes
|
|
||||||
self.stem = nn.Sequential(
|
|
||||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
|
||||||
nn.BatchNorm2d(C))
|
|
||||||
|
|
||||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
|
||||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
|
||||||
|
|
||||||
C_prev, num_edge, edge2index = C, None, None
|
|
||||||
self.cells = nn.ModuleList()
|
|
||||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
|
||||||
if reduction:
|
|
||||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
|
||||||
else:
|
|
||||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
|
||||||
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
|
||||||
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
|
||||||
self.cells.append( cell )
|
|
||||||
C_prev = cell.out_dim
|
|
||||||
self.op_names = deepcopy( search_space )
|
|
||||||
self._Layer = len(self.cells)
|
|
||||||
self.edge2index = edge2index
|
|
||||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
|
||||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.classifier = nn.Linear(C_prev, num_classes)
|
|
||||||
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
|
||||||
self.tau = 10
|
|
||||||
|
|
||||||
def get_weights(self):
|
|
||||||
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
|
||||||
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
|
|
||||||
xlist+= list( self.classifier.parameters() )
|
|
||||||
return xlist
|
|
||||||
|
|
||||||
def set_tau(self, tau):
|
|
||||||
self.tau = tau
|
|
||||||
|
|
||||||
def get_tau(self):
|
|
||||||
return self.tau
|
|
||||||
|
|
||||||
def get_alphas(self):
|
|
||||||
return [self.arch_parameters]
|
|
||||||
|
|
||||||
def get_message(self):
|
|
||||||
string = self.extra_repr()
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
|
||||||
return string
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
|
||||||
|
|
||||||
def genotype(self):
|
|
||||||
genotypes = []
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
xlist = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
with torch.no_grad():
|
|
||||||
weights = self.arch_parameters[ self.edge2index[node_str] ]
|
|
||||||
op_name = self.op_names[ weights.argmax().item() ]
|
|
||||||
xlist.append((op_name, j))
|
|
||||||
genotypes.append( tuple(xlist) )
|
|
||||||
return Structure( genotypes )
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
feature = self.stem(inputs)
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
if isinstance(cell, SearchCell):
|
|
||||||
feature = cell.forward_gdas(feature, self.arch_parameters, self.tau)
|
|
||||||
else:
|
|
||||||
feature = cell(feature)
|
|
||||||
|
|
||||||
out = self.lastact(feature)
|
|
||||||
out = self.global_pooling( out )
|
|
||||||
out = out.view(out.size(0), -1)
|
|
||||||
logits = self.classifier(out)
|
|
||||||
|
|
||||||
return out, logits
|
|
@ -1,81 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
##############################################################################
|
|
||||||
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
|
|
||||||
##############################################################################
|
|
||||||
import torch, random
|
|
||||||
import torch.nn as nn
|
|
||||||
from copy import deepcopy
|
|
||||||
from ..cell_operations import ResNetBasicblock
|
|
||||||
from .search_cells import SearchCell
|
|
||||||
from .genotypes import Structure
|
|
||||||
|
|
||||||
|
|
||||||
class TinyNetworkRANDOM(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
|
||||||
super(TinyNetworkRANDOM, self).__init__()
|
|
||||||
self._C = C
|
|
||||||
self._layerN = N
|
|
||||||
self.max_nodes = max_nodes
|
|
||||||
self.stem = nn.Sequential(
|
|
||||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
|
||||||
nn.BatchNorm2d(C))
|
|
||||||
|
|
||||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
|
||||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
|
||||||
|
|
||||||
C_prev, num_edge, edge2index = C, None, None
|
|
||||||
self.cells = nn.ModuleList()
|
|
||||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
|
||||||
if reduction:
|
|
||||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
|
||||||
else:
|
|
||||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
|
||||||
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
|
||||||
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
|
||||||
self.cells.append( cell )
|
|
||||||
C_prev = cell.out_dim
|
|
||||||
self.op_names = deepcopy( search_space )
|
|
||||||
self._Layer = len(self.cells)
|
|
||||||
self.edge2index = edge2index
|
|
||||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
|
||||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.classifier = nn.Linear(C_prev, num_classes)
|
|
||||||
self.arch_cache = None
|
|
||||||
|
|
||||||
def get_message(self):
|
|
||||||
string = self.extra_repr()
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
|
||||||
return string
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
|
||||||
|
|
||||||
def random_genotype(self, set_cache):
|
|
||||||
genotypes = []
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
xlist = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
op_name = random.choice( self.op_names )
|
|
||||||
xlist.append((op_name, j))
|
|
||||||
genotypes.append( tuple(xlist) )
|
|
||||||
arch = Structure( genotypes )
|
|
||||||
if set_cache: self.arch_cache = arch
|
|
||||||
return arch
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
|
|
||||||
feature = self.stem(inputs)
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
if isinstance(cell, SearchCell):
|
|
||||||
feature = cell.forward_dynamic(feature, self.arch_cache)
|
|
||||||
else: feature = cell(feature)
|
|
||||||
|
|
||||||
out = self.lastact(feature)
|
|
||||||
out = self.global_pooling( out )
|
|
||||||
out = out.view(out.size(0), -1)
|
|
||||||
logits = self.classifier(out)
|
|
||||||
return out, logits
|
|
@ -1,152 +0,0 @@
|
|||||||
##################################################
|
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
|
||||||
######################################################################################
|
|
||||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
|
||||||
######################################################################################
|
|
||||||
import torch, random
|
|
||||||
import torch.nn as nn
|
|
||||||
from copy import deepcopy
|
|
||||||
from ..cell_operations import ResNetBasicblock
|
|
||||||
from .search_cells import SearchCell
|
|
||||||
from .genotypes import Structure
|
|
||||||
|
|
||||||
|
|
||||||
class TinyNetworkSETN(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
|
||||||
super(TinyNetworkSETN, self).__init__()
|
|
||||||
self._C = C
|
|
||||||
self._layerN = N
|
|
||||||
self.max_nodes = max_nodes
|
|
||||||
self.stem = nn.Sequential(
|
|
||||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
|
||||||
nn.BatchNorm2d(C))
|
|
||||||
|
|
||||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
|
||||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
|
||||||
|
|
||||||
C_prev, num_edge, edge2index = C, None, None
|
|
||||||
self.cells = nn.ModuleList()
|
|
||||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
|
||||||
if reduction:
|
|
||||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
|
||||||
else:
|
|
||||||
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
|
||||||
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
|
||||||
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
|
||||||
self.cells.append( cell )
|
|
||||||
C_prev = cell.out_dim
|
|
||||||
self.op_names = deepcopy( search_space )
|
|
||||||
self._Layer = len(self.cells)
|
|
||||||
self.edge2index = edge2index
|
|
||||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
|
||||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.classifier = nn.Linear(C_prev, num_classes)
|
|
||||||
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
|
||||||
self.mode = 'urs'
|
|
||||||
self.dynamic_cell = None
|
|
||||||
|
|
||||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
|
||||||
assert mode in ['urs', 'joint', 'select', 'dynamic']
|
|
||||||
self.mode = mode
|
|
||||||
if mode == 'dynamic': self.dynamic_cell = deepcopy( dynamic_cell )
|
|
||||||
else : self.dynamic_cell = None
|
|
||||||
|
|
||||||
def get_cal_mode(self):
|
|
||||||
return self.mode
|
|
||||||
|
|
||||||
def get_weights(self):
|
|
||||||
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
|
||||||
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
|
|
||||||
xlist+= list( self.classifier.parameters() )
|
|
||||||
return xlist
|
|
||||||
|
|
||||||
def get_alphas(self):
|
|
||||||
return [self.arch_parameters]
|
|
||||||
|
|
||||||
def get_message(self):
|
|
||||||
string = self.extra_repr()
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
|
||||||
return string
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
|
||||||
|
|
||||||
def genotype(self):
|
|
||||||
genotypes = []
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
xlist = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
with torch.no_grad():
|
|
||||||
weights = self.arch_parameters[ self.edge2index[node_str] ]
|
|
||||||
op_name = self.op_names[ weights.argmax().item() ]
|
|
||||||
xlist.append((op_name, j))
|
|
||||||
genotypes.append( tuple(xlist) )
|
|
||||||
return Structure( genotypes )
|
|
||||||
|
|
||||||
def dync_genotype(self, use_random=False):
|
|
||||||
genotypes = []
|
|
||||||
with torch.no_grad():
|
|
||||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
|
||||||
for i in range(1, self.max_nodes):
|
|
||||||
xlist = []
|
|
||||||
for j in range(i):
|
|
||||||
node_str = '{:}<-{:}'.format(i, j)
|
|
||||||
if use_random:
|
|
||||||
op_name = random.choice(self.op_names)
|
|
||||||
else:
|
|
||||||
weights = alphas_cpu[ self.edge2index[node_str] ]
|
|
||||||
op_index = torch.multinomial(weights, 1).item()
|
|
||||||
op_name = self.op_names[ op_index ]
|
|
||||||
xlist.append((op_name, j))
|
|
||||||
genotypes.append( tuple(xlist) )
|
|
||||||
return Structure( genotypes )
|
|
||||||
|
|
||||||
def get_log_prob(self, arch):
|
|
||||||
with torch.no_grad():
|
|
||||||
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
|
|
||||||
select_logits = []
|
|
||||||
for i, node_info in enumerate(arch.nodes):
|
|
||||||
for op, xin in node_info:
|
|
||||||
node_str = '{:}<-{:}'.format(i+1, xin)
|
|
||||||
op_index = self.op_names.index(op)
|
|
||||||
select_logits.append( logits[self.edge2index[node_str], op_index] )
|
|
||||||
return sum(select_logits).item()
|
|
||||||
|
|
||||||
|
|
||||||
def return_topK(self, K):
|
|
||||||
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
|
|
||||||
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
|
|
||||||
if K < 0 or K >= len(archs): K = len(archs)
|
|
||||||
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
|
|
||||||
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
|
|
||||||
return return_pairs
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
|
||||||
with torch.no_grad():
|
|
||||||
alphas_cpu = alphas.detach().cpu()
|
|
||||||
|
|
||||||
feature = self.stem(inputs)
|
|
||||||
for i, cell in enumerate(self.cells):
|
|
||||||
if isinstance(cell, SearchCell):
|
|
||||||
if self.mode == 'urs':
|
|
||||||
feature = cell.forward_urs(feature)
|
|
||||||
elif self.mode == 'select':
|
|
||||||
feature = cell.forward_select(feature, alphas_cpu)
|
|
||||||
elif self.mode == 'joint':
|
|
||||||
feature = cell.forward_joint(feature, alphas)
|
|
||||||
elif self.mode == 'dynamic':
|
|
||||||
feature = cell.forward_dynamic(feature, self.dynamic_cell)
|
|
||||||
else: raise ValueError('invalid mode={:}'.format(self.mode))
|
|
||||||
else: feature = cell(feature)
|
|
||||||
|
|
||||||
out = self.lastact(feature)
|
|
||||||
out = self.global_pooling( out )
|
|
||||||
out = out.view(out.size(0), -1)
|
|
||||||
logits = self.classifier(out)
|
|
||||||
|
|
||||||
return out, logits
|
|
Loading…
Reference in New Issue
Block a user