update GDAS
@@ -88,7 +88,9 @@ class TinyNetworkGDAS(nn.Module):
       index   = probs.max(-1, keepdim=True)[1]
       one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0)
       hardwts = one_h - probs.detach() + probs
-      if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): continue
+      if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
+        continue
+      else: break

     feature = self.stem(inputs)
     for i, cell in enumerate(self.cells):
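
This hunk turns the one-line guard into an explicit retry/exit: the old `if ...: continue` never left the `while` loop, while the added `else: break` exits once a numerically clean sample is drawn. The lines above it implement the straight-through Gumbel-softmax trick: `hardwts` equals the one-hot `one_h` in the forward pass but carries the gradients of `probs`. A minimal, self-contained sketch of that trick (not part of the diff; `tau` and the tensor shape are illustrative):

import torch
import torch.nn as nn

def gumbel_softmax_hard(alphas, tau=10.0):
  # resample until the draw is numerically sound, as the patched loop does
  while True:
    gumbels = -torch.empty_like(alphas).exponential_().log()
    logits  = (alphas.log_softmax(dim=1) + gumbels) / tau
    probs   = nn.functional.softmax(logits, dim=1)
    if torch.isinf(gumbels).any() or torch.isinf(probs).any() or torch.isnan(probs).any():
      continue
    else:
      break
  index = probs.max(-1, keepdim=True)[1]
  one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
  # forward: exact one-hot; backward: gradients of the soft probabilities
  return one_h - probs.detach() + probs

alphas  = nn.Parameter(1e-3 * torch.randn(4, 5))
hardwts = gumbel_softmax_hard(alphas)
hardwts.sum().backward()   # alphas.grad is populated despite the hard argmax
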
@@ -1,17 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .search_model_darts_v1 import TinyNetworkDartsV1
from .search_model_darts_v2 import TinyNetworkDartsV2
from .search_model_gdas     import TinyNetworkGDAS
from .search_model_setn     import TinyNetworkSETN
from .search_model_enas     import TinyNetworkENAS
from .search_model_random   import TinyNetworkRANDOM
from .genotypes             import Structure as CellStructure, architectures as CellArchitectures

nas_super_nets = {'DARTS-V1': TinyNetworkDartsV1,
                  'DARTS-V2': TinyNetworkDartsV2,
                  'GDAS'    : TinyNetworkGDAS,
                  'SETN'    : TinyNetworkSETN,
                  'ENAS'    : TinyNetworkENAS,
                  'RANDOM'  : TinyNetworkRANDOM}
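
A hedged sketch of how a search script would pick a super-net from this registry; the import path and the five-op search space are assumptions, while the constructor signature matches the model files in this commit:

from lib.models.cell_searchs import nas_super_nets  # assumed package path

search_space = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
network = nas_super_nets['GDAS'](C=16, N=5, max_nodes=4, num_classes=10,
                                 search_space=search_space)
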
@@ -1,12 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from search_model_enas_utils import Controller

def main():
  controller = Controller(6, 4)
  predictions = controller()

if __name__ == '__main__':
  main()
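
For reference, Controller(6, 4) builds a sampler for a cell with six edges and four candidate operations; per the controller code later in this commit, calling it returns a log-probability sum, an entropy sum, and the sampled architecture, so the result in main() unpacks as:

log_prob, entropy, sampled_arch = controller()   # sampled_arch: a list of 6 op indices
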
@@ -1,197 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from copy import deepcopy


def get_combination(space, num):
  combs = []
  for i in range(num):
    if i == 0:
      for func in space:
        combs.append( [(func, i)] )
    else:
      new_combs = []
      for string in combs:
        for func in space:
          xstring = string + [(func, i)]
          new_combs.append( xstring )
      combs = new_combs
  return combs

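get_combination(space, num) enumerates, for a cell node with num incoming edges, every assignment of one operation to each edge; applied node by node it spans the whole cell space. A small worked example with a two-op space (not from the file):

ops = ['skip_connect', 'nor_conv_3x3']
print(get_combination(ops, 2))
# [[('skip_connect', 0), ('skip_connect', 1)],
#  [('skip_connect', 0), ('nor_conv_3x3', 1)],
#  [('nor_conv_3x3', 0), ('skip_connect', 1)],
#  [('nor_conv_3x3', 0), ('nor_conv_3x3', 1)]]
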
class Structure:

  def __init__(self, genotype):
    assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
    self.node_num = len(genotype) + 1
    self.nodes    = []
    self.node_N   = []
    for idx, node_info in enumerate(genotype):
      assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
      assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
      for node_in in node_info:
        assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
        assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
      self.node_N.append( len(node_info) )
      self.nodes.append( tuple(deepcopy(node_info)) )

  def tolist(self, remove_str):
    # Convert this class to a list; if remove_str is 'none', the 'none' operations are removed.
    # Note that the input nodes are re-ordered in this function.
    # Returns the genotype list and a success flag (if unsuccessful, the genotype is not a valid connectivity).
    genotypes = []
    for node_info in self.nodes:
      node_info = list( node_info )
      node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
      node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
      if len(node_info) == 0: return None, False
      genotypes.append( node_info )
    return genotypes, True

  def node(self, index):
    assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
    return self.nodes[index]

  def tostr(self):
    strings = []
    for node_info in self.nodes:
      string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
      string = '|{:}|'.format(string)
      strings.append( string )
    return '+'.join(strings)

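tostr serializes each node as '|op~in|...' and joins the nodes with '+', which str2structure below parses back. A round-trip example using the ResNet cell defined at the end of this file (the snippet itself is not part of the file):

arch = Structure([(('nor_conv_3x3', 0),),
                  (('nor_conv_3x3', 1),),
                  (('skip_connect', 0), ('skip_connect', 2))])
s = arch.tostr()
# '|nor_conv_3x3~0|+|nor_conv_3x3~1|+|skip_connect~0|skip_connect~2|'
assert Structure.str2structure(s).tostr() == s
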
  def check_valid(self):
    nodes = {0: True}
    for i, node_info in enumerate(self.nodes):
      sums = []
      for op, xin in node_info:
        if op == 'none' or nodes[xin] == False: x = False
        else: x = True
        sums.append( x )
      nodes[i+1] = sum(sums) > 0
    return nodes[len(self.nodes)]

  def to_unique_str(self, consider_zero=False):
    # This is used to identify isomorphic cells, which requires prior knowledge of the operations.
    # Two operations are special, i.e., none and skip_connect.
    nodes = {0: '0'}
    for i_node, node_info in enumerate(self.nodes):
      cur_node = []
      for op, xin in node_info:
        if consider_zero:
          if op == 'none' or nodes[xin] == '#': x = '#' # zero
          elif op == 'skip_connect': x = nodes[xin]
          else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
        else:
          if op == 'skip_connect': x = nodes[xin]
          else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
        cur_node.append(x)
      nodes[i_node+1] = '+'.join( sorted(cur_node) )
    return nodes[ len(self.nodes) ]

  def check_valid_op(self, op_names):
    for node_info in self.nodes:
      for inode_edge in node_info:
        #assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
        if inode_edge[0] not in op_names: return False
    return True

  def __repr__(self):
    return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))

  def __len__(self):
    return len(self.nodes) + 1

  def __getitem__(self, index):
    return self.nodes[index]

  @staticmethod
  def str2structure(xstr):
    assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
    nodestrs = xstr.split('+')
    genotypes = []
    for i, node_str in enumerate(nodestrs):
      inputs = list(filter(lambda x: x != '', node_str.split('|')))
      for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      inputs = ( xi.split('~') for xi in inputs )
      input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
      genotypes.append( input_infos )
    return Structure( genotypes )

  @staticmethod
  def str2fullstructure(xstr, default_name='none'):
    assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
    nodestrs = xstr.split('+')
    genotypes = []
    for i, node_str in enumerate(nodestrs):
      inputs = list(filter(lambda x: x != '', node_str.split('|')))
      for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
      inputs = ( xi.split('~') for xi in inputs )
      input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
      all_in_nodes= list(x[1] for x in input_infos)
      for j in range(i):
        if j not in all_in_nodes: input_infos.append((default_name, j))
      node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
      genotypes.append( tuple(node_info) )
    return Structure( genotypes )

  @staticmethod
  def gen_all(search_space, num, return_ori):
    assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
    assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
    all_archs = get_combination(search_space, 1)
    for i, arch in enumerate(all_archs):
      all_archs[i] = [ tuple(arch) ]

    for inode in range(2, num):
      cur_nodes = get_combination(search_space, inode)
      new_all_archs = []
      for previous_arch in all_archs:
        for cur_node in cur_nodes:
          new_all_archs.append( previous_arch + [tuple(cur_node)] )
      all_archs = new_all_archs
    if return_ori:
      return all_archs
    else:
      return [Structure(x) for x in all_archs]

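gen_all enumerates the entire cell space: with |O| operations and num nodes, node k has |O|**k edge assignments, so the total is |O|**(1 + 2 + ... + (num-1)). A sanity check with an assumed three-op space and four nodes (3**6 = 729 architectures):

space = ['none', 'skip_connect', 'nor_conv_3x3']
archs = Structure.gen_all(space, 4, False)
assert len(archs) == 3 ** (1 + 2 + 3)   # 729
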
ResNet_CODE = Structure(
  [(('nor_conv_3x3', 0), ), # node-1
   (('nor_conv_3x3', 1), ), # node-2
   (('skip_connect', 0), ('skip_connect', 2))] # node-3
  )

AllConv3x3_CODE = Structure(
  [(('nor_conv_3x3', 0), ), # node-1
   (('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
   (('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
  )

AllFull_CODE = Structure(
  [(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
   (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
   (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
  )

AllConv1x1_CODE = Structure(
  [(('nor_conv_1x1', 0), ), # node-1
   (('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
   (('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
  )

AllIdentity_CODE = Structure(
  [(('skip_connect', 0), ), # node-1
   (('skip_connect', 0), ('skip_connect', 1)), # node-2
   (('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
  )

architectures = {'resnet'  : ResNet_CODE,
                 'all_c3x3': AllConv3x3_CODE,
                 'all_c1x1': AllConv1x1_CODE,
                 'all_idnt': AllIdentity_CODE,
                 'all_full': AllFull_CODE}
@@ -1,148 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random, torch
import warnings
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from ..cell_operations import OPS


class SearchCell(nn.Module):

  def __init__(self, C_in, C_out, stride, max_nodes, op_names, n_piece):
    super(SearchCell, self).__init__()

    self.op_names  = deepcopy(op_names)
    self.max_nodes = max_nodes
    self.in_dim    = C_in
    self.out_dim   = C_out
    self.n_piece   = n_piece
    self.multi_edges = nn.ModuleList()
    for i_piece in range(n_piece):
      edges          = nn.ModuleDict()
      for i in range(1, max_nodes):
        for j in range(i):
          node_str = '{:}<-{:}'.format(i, j)
          if j == 0: xlists = [OPS[op_name](C_in , C_out, stride) for op_name in op_names]
          else     : xlists = [OPS[op_name](C_in , C_out,      1) for op_name in op_names]
          edges[ node_str ] = nn.ModuleList( xlists )
      self.multi_edges.append( edges )

    self.edge_keys  = sorted(list(edges.keys()))
    self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
    self.num_edges  = len(edges)

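With max_nodes=4 the cell has the edges '1<-0', '2<-0', '2<-1', '3<-0', '3<-1', '3<-2', so num_edges is 6 and edge2index maps each key to one row of the architecture-parameter matrix. An illustrative check (not from the file):

max_nodes = 4
keys = sorted('{:}<-{:}'.format(i, j) for i in range(1, max_nodes) for j in range(i))
# ['1<-0', '2<-0', '2<-1', '3<-0', '3<-1', '3<-2'] -> rows 0..5 of arch_parameters
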
  def extra_repr(self):
    string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}, nP={n_piece}'.format(**self.__dict__)
    return string

  def forward(self, inputs, weightss):
    nodes = [inputs]
    with torch.no_grad():
      # hash the per-edge argmax choices (mixed-radix, modulo n_piece) to pick one weight copy
      xmod, xid, argmax = 1, 0, weightss.argmax(dim=1).cpu().tolist()
      for i, x in enumerate(argmax):
        xid += x * (xmod % self.n_piece)
        xmod = (xmod * len(self.op_names)) % self.n_piece
      xid = xid % self.n_piece
    edges = self.multi_edges[xid]
    for i in range(1, self.max_nodes):
      inter_nodes = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = weightss[ self.edge2index[node_str] ]
        inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(edges[node_str], weights) ) )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]

  # GDAS
  def forward_gdas(self, inputs, alphas, _tau):
    avoid_zero = 0
    while True:
      gumbels = -torch.empty_like(alphas).exponential_().log()
      logits  = (alphas.log_softmax(dim=1) + gumbels) / _tau
      probs   = nn.functional.softmax(logits, dim=1)
      index   = probs.max(-1, keepdim=True)[1]
      one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0)
      hardwts = one_h - probs.detach() + probs
      if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
        continue # avoid the numerical error
      nodes   = [inputs]
      for i in range(1, self.max_nodes):
        inter_nodes = []
        for j in range(i):
          node_str = '{:}<-{:}'.format(i, j)
          weights  = hardwts[ self.edge2index[node_str] ]
          argmaxs  = index[ self.edge2index[node_str] ].item()
          weigsum  = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
          inter_nodes.append( weigsum )
        nodes.append( sum(inter_nodes) )
      avoid_zero += 1
      if nodes[-1].sum().item() == 0:
        if avoid_zero < 10: continue
        else:
          warnings.warn('get zero outputs with avoid_zero={:}'.format(avoid_zero))
          break
      else:
        break
    return nodes[-1]

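In weigsum only the argmax operation is actually executed: the other hardwts entries are zero in the forward pass, so adding them as bare scalars leaves the value unchanged while still letting gradients flow to every alpha. That is how GDAS samples one op per edge yet trains all architecture parameters. A tiny standalone demonstration (stand-in tensors, not from the file):

import torch
hardwts = torch.tensor([0.0, 1.0, 0.0])                      # one-hot forward values
ops_out = [torch.full((2, 2), float(v)) for v in (3, 5, 7)]  # stand-ins for op outputs
k = int(hardwts.argmax().item())
weigsum = sum(hardwts[i] * ops_out[i] if i == k else hardwts[i] for i in range(3))
assert torch.equal(weigsum, ops_out[k])                      # only the sampled op contributes
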
  # joint
  def forward_joint(self, inputs, weightss):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      inter_nodes = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = weightss[ self.edge2index[node_str] ]
        #aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
        aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) )
        inter_nodes.append( aggregation )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]

  # uniform random sampling per iteration
  def forward_urs(self, inputs):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      while True: # re-sample to avoid selecting the zero op on all incoming edges
        sops, has_non_zero = [], False
        for j in range(i):
          node_str   = '{:}<-{:}'.format(i, j)
          candidates = self.edges[node_str]
          select_op  = random.choice(candidates)
          sops.append( select_op )
          if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True
        if has_non_zero: break
      inter_nodes = []
      for j, select_op in enumerate(sops):
        inter_nodes.append( select_op(nodes[j]) )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]

  # select the argmax
  def forward_select(self, inputs, weightss):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      inter_nodes = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = weightss[ self.edge2index[node_str] ]
        inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) )
        #inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]

  # forward with a specific structure
  def forward_dynamic(self, inputs, structure):
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      cur_op_node = structure.nodes[i-1]
      inter_nodes = []
      for op_name, j in cur_op_node:
        node_str = '{:}<-{:}'.format(i, j)
        op_index = self.op_names.index( op_name )
        inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
      nodes.append( sum(inter_nodes) )
    return nodes[-1]
@@ -1,93 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
########################################################
# DARTS: Differentiable Architecture Search, ICLR 2019 #
########################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells     import SearchCell
from .genotypes        import Structure


class TinyNetworkDartsV1(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space, n_piece):
    super(TinyNetworkDartsV1, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                    nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, n_piece)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )

  def get_weights(self):
    xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
    xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
    xlist+= list( self.classifier.parameters() )
    return xlist

  def get_alphas(self):
    return [self.arch_parameters]

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def genotype(self):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        with torch.no_grad():
          weights = self.arch_parameters[ self.edge2index[node_str] ]
          op_name = self.op_names[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return Structure( genotypes )

  def forward(self, inputs):
    alphas  = nn.functional.softmax(self.arch_parameters, dim=-1)

    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        feature = cell(feature, alphas)
      else:
        feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
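
The macro skeleton above (shared by all the model files in this commit) interleaves N search cells with two ResNetBasicblock reduction stages that double the channels. For example, with assumed values C=16, N=2 the per-layer (channels, reduction) schedule is:

C, N = 16, 2
layer_channels   = [C]*N + [C*2] + [C*2]*N + [C*4] + [C*4]*N
layer_reductions = [False]*N + [True] + [False]*N + [True] + [False]*N
print(list(zip(layer_channels, layer_reductions)))
# [(16, False), (16, False), (32, True), (32, False), (32, False),
#  (64, True), (64, False), (64, False)]
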
@@ -1,93 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
########################################################
# DARTS: Differentiable Architecture Search, ICLR 2019 #
########################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells     import SearchCell
from .genotypes        import Structure


class TinyNetworkDartsV2(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space):
    super(TinyNetworkDartsV2, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                    nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )

  def get_weights(self):
    xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
    xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
    xlist+= list( self.classifier.parameters() )
    return xlist

  def get_alphas(self):
    return [self.arch_parameters]

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def genotype(self):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        with torch.no_grad():
          weights = self.arch_parameters[ self.edge2index[node_str] ]
          op_name = self.op_names[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return Structure( genotypes )

  def forward(self, inputs):
    alphas  = nn.functional.softmax(self.arch_parameters, dim=-1)

    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        feature = cell(feature, alphas)
      else:
        feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
@@ -1,94 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells     import SearchCell
from .genotypes        import Structure
from .search_model_enas_utils import Controller


class TinyNetworkENAS(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space):
    super(TinyNetworkENAS, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                    nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    # to maintain the sampled architecture
    self.sampled_arch = None

  def update_arch(self, _arch):
    if _arch is None:
      self.sampled_arch = None
    elif isinstance(_arch, Structure):
      self.sampled_arch = _arch
    elif isinstance(_arch, (list, tuple)):
      genotypes = []
      for i in range(1, self.max_nodes):
        xlist = []
        for j in range(i):
          node_str = '{:}<-{:}'.format(i, j)
          op_index = _arch[ self.edge2index[node_str] ]
          op_name  = self.op_names[ op_index ]
          xlist.append((op_name, j))
        genotypes.append( tuple(xlist) )
      self.sampled_arch = Structure(genotypes)
    else:
      raise ValueError('invalid type of input architecture : {:}'.format(_arch))
    return self.sampled_arch

  def create_controller(self):
    return Controller(len(self.edge2index), len(self.op_names))

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def forward(self, inputs):
    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        feature = cell.forward_dynamic(feature, self.sampled_arch)
      else: feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
@@ -1,55 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

class Controller(nn.Module):
  # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
  def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
    super(Controller, self).__init__()
    # assign the attributes
    self.num_edge  = num_edge
    self.num_ops   = num_ops
    self.lstm_size = lstm_size
    self.lstm_N    = lstm_num_layers
    self.tanh_constant = tanh_constant
    self.temperature   = temperature
    # create parameters
    self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
    self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
    self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
    self.w_pred = nn.Linear(self.lstm_size, self.num_ops)

    nn.init.uniform_(self.input_vars         , -0.1, 0.1)
    nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
    nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
    nn.init.uniform_(self.w_embd.weight      , -0.1, 0.1)
    nn.init.uniform_(self.w_pred.weight      , -0.1, 0.1)

  def forward(self):
    inputs, h0 = self.input_vars, None
    log_probs, entropys, sampled_arch = [], [], []
    for iedge in range(self.num_edge):
      outputs, h0 = self.w_lstm(inputs, h0)

      logits = self.w_pred(outputs)
      logits = logits / self.temperature
      logits = self.tanh_constant * torch.tanh(logits)
      # distribution
      op_distribution = Categorical(logits=logits)
      op_index    = op_distribution.sample()
      sampled_arch.append( op_index.item() )

      op_log_prob = op_distribution.log_prob(op_index)
      log_probs.append( op_log_prob.view(-1) )
      op_entropy  = op_distribution.entropy()
      entropys.append( op_entropy.view(-1) )

      # obtain the input embedding for the next step
      inputs = self.w_embd(op_index)
    return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch
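
A hedged sketch of how the controller's three outputs would feed a REINFORCE update in a search loop; `reward` and `baseline` are placeholders, not values from this repository:

controller = Controller(num_edge=6, num_ops=5)
optimizer  = torch.optim.Adam(controller.parameters(), lr=1e-3)

log_prob, entropy, sampled_arch = controller()
reward   = 0.8                    # e.g., validation accuracy of sampled_arch
baseline = 0.7                    # e.g., a moving average of past rewards
loss = -log_prob * (reward - baseline) - 1e-4 * entropy  # entropy bonus encourages exploration
optimizer.zero_grad()
loss.backward()
optimizer.step()
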
@@ -1,96 +0,0 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells     import SearchCell
from .genotypes        import Structure


class TinyNetworkGDAS(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space):
    super(TinyNetworkGDAS, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                    nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
    self.tau        = 10

  def get_weights(self):
    xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
    xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
    xlist+= list( self.classifier.parameters() )
    return xlist

  def set_tau(self, tau):
    self.tau = tau

  def get_tau(self):
    return self.tau

  def get_alphas(self):
    return [self.arch_parameters]

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def genotype(self):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        with torch.no_grad():
          weights = self.arch_parameters[ self.edge2index[node_str] ]
          op_name = self.op_names[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return Structure( genotypes )

  def forward(self, inputs):
    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        feature = cell.forward_gdas(feature, self.arch_parameters, self.tau)
      else:
        feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits
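
GDAS anneals the Gumbel temperature during search: set_tau is typically called once per epoch, moving tau from a high value (soft, exploratory mixtures) toward a low one (nearly discrete choices). A hedged schedule sketch, assuming `network` is a TinyNetworkGDAS instance; the 10 -> 0.1 endpoints echo the default self.tau = 10 but are an assumption, not taken from this commit:

tau_max, tau_min, total_epochs = 10.0, 0.1, 250
for epoch in range(total_epochs):
  network.set_tau( tau_max - (tau_max - tau_min) * epoch / (total_epochs - 1) )
  # ... one epoch of alternating weight / architecture updates ...
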
@@ -1,81 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells     import SearchCell
from .genotypes        import Structure


class TinyNetworkRANDOM(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space):
    super(TinyNetworkRANDOM, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                    nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_cache = None

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def random_genotype(self, set_cache):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        op_name  = random.choice( self.op_names )
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    arch = Structure( genotypes )
    if set_cache: self.arch_cache = arch
    return arch

  def forward(self, inputs):
    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        feature = cell.forward_dynamic(feature, self.arch_cache)
      else: feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)
    return out, logits
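
With this model, random search reduces to: sample a genotype, cache it, run the shared weights. A hedged sketch of one step, assuming `network` is a TinyNetworkRANDOM instance and `inputs`/`targets` come from a dataloader:

network.random_genotype( set_cache=True )   # stores the sampled arch in self.arch_cache
out, logits = network(inputs)               # forward_dynamic uses the cached arch
loss = torch.nn.functional.cross_entropy(logits, targets)
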
@@ -1,152 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells     import SearchCell
from .genotypes        import Structure


class TinyNetworkSETN(nn.Module):

  def __init__(self, C, N, max_nodes, num_classes, search_space):
    super(TinyNetworkSETN, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = nn.Sequential(
                    nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(C))

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    self.cells = nn.ModuleList()
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      self.cells.append( cell )
      C_prev = cell.out_dim
    self.op_names   = deepcopy( search_space )
    self._Layer     = len(self.cells)
    self.edge2index = edge2index
    self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)
    self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
    self.mode       = 'urs'
    self.dynamic_cell = None

  def set_cal_mode(self, mode, dynamic_cell=None):
    assert mode in ['urs', 'joint', 'select', 'dynamic']
    self.mode = mode
    if mode == 'dynamic': self.dynamic_cell = deepcopy( dynamic_cell )
    else                : self.dynamic_cell = None

  def get_cal_mode(self):
    return self.mode

  def get_weights(self):
    xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
    xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
    xlist+= list( self.classifier.parameters() )
    return xlist

  def get_alphas(self):
    return [self.arch_parameters]

  def get_message(self):
    string = self.extra_repr()
    for i, cell in enumerate(self.cells):
      string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
    return string

  def extra_repr(self):
    return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

  def genotype(self):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        with torch.no_grad():
          weights = self.arch_parameters[ self.edge2index[node_str] ]
          op_name = self.op_names[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return Structure( genotypes )

  def dync_genotype(self, use_random=False):
    genotypes = []
    with torch.no_grad():
      alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        if use_random:
          op_name  = random.choice(self.op_names)
        else:
          weights  = alphas_cpu[ self.edge2index[node_str] ]
          op_index = torch.multinomial(weights, 1).item()
          op_name  = self.op_names[ op_index ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return Structure( genotypes )

  def get_log_prob(self, arch):
    with torch.no_grad():
      logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
    select_logits = []
    for i, node_info in enumerate(arch.nodes):
      for op, xin in node_info:
        node_str = '{:}<-{:}'.format(i+1, xin)
        op_index = self.op_names.index(op)
        select_logits.append( logits[self.edge2index[node_str], op_index] )
    return sum(select_logits).item()

  def return_topK(self, K):
    archs = Structure.gen_all(self.op_names, self.max_nodes, False)
    pairs = [(self.get_log_prob(arch), arch) for arch in archs]
    if K < 0 or K >= len(archs): K = len(archs)
    sorted_pairs = sorted(pairs, key=lambda x: -x[0])
    return_pairs = [sorted_pairs[_][1] for _ in range(K)]
    return return_pairs

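SETN scores candidate cells by their log-probability under the learned architecture distribution; return_topK ranks the whole space that way. A hedged sketch of the selection phase, assuming `network` is a TinyNetworkSETN instance:

network.set_cal_mode('urs')            # uniform sampling while training shared weights
candidates = network.return_topK(10)   # 10 most likely cells under arch_parameters
for arch in candidates:
  network.set_cal_mode('dynamic', arch)
  # ... evaluate arch with the shared weights on validation data ...
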
  def forward(self, inputs):
    alphas  = nn.functional.softmax(self.arch_parameters, dim=-1)
    with torch.no_grad():
      alphas_cpu = alphas.detach().cpu()

    feature = self.stem(inputs)
    for i, cell in enumerate(self.cells):
      if isinstance(cell, SearchCell):
        if self.mode == 'urs':
          feature = cell.forward_urs(feature)
        elif self.mode == 'select':
          feature = cell.forward_select(feature, alphas_cpu)
        elif self.mode == 'joint':
          feature = cell.forward_joint(feature, alphas)
        elif self.mode == 'dynamic':
          feature = cell.forward_dynamic(feature, self.dynamic_cell)
        else: raise ValueError('invalid mode={:}'.format(self.mode))
      else: feature = cell(feature)

    out = self.lastact(feature)
    out = self.global_pooling( out )
    out = out.view(out.size(0), -1)
    logits = self.classifier(out)

    return out, logits