xautodl/lib/models/cell_infers/cells.py

53 lines
1.8 KiB
Python
Raw Normal View History

2020-02-23 00:30:37 +01:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
2019-11-08 10:06:12 +01:00
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
# Cell for NAS-Bench-201
2019-11-08 10:06:12 +01:00
class InferCell(nn.Module):
    """Fixed-topology cell instantiated from a genotype (NAS-Bench-201).

    Every intermediate node of the cell is the sum of one operation per
    incoming edge. All operations are stored flat in ``self.layers``; for
    intermediate node ``k`` (1-based), ``self.node_IX[k - 1]`` lists the
    indices into ``self.layers`` and ``self.node_IN[k - 1]`` lists the
    matching source-node indices (node 0 is the cell input).
    """

    def __init__(self, genotype, C_in, C_out, stride):
        """Build one operation per edge described by ``genotype``.

        Args:
            genotype: Cell description. ``genotype[k]`` is iterated as
                ``(op_name, input_node_index)`` pairs defining intermediate
                node ``k + 1``. A deep copy is kept for ``extra_repr``,
                which calls ``genotype.tostr()``.
                NOTE(review): the loop visits ``genotype[0]`` through
                ``genotype[len(genotype) - 2]``, so this assumes
                ``len(genotype)`` also counts the input node (one more than
                the number of indexable node entries). Confirm against the
                genotype class.
            C_in: Size passed as the first argument to ``OPS`` constructors
                on edges leaving node 0 (presumably input channels).
            C_out: Size produced by every operation; also the first argument
                on edges leaving intermediate nodes.
            stride: Stride used only on edges that read the cell input; all
                other edges use stride 1.
        """
        super().__init__()
        self.layers = nn.ModuleList()  # every edge operation, in creation order
        self.node_IN = []  # per intermediate node: source-node index of each edge
        self.node_IX = []  # per intermediate node: index into self.layers of each edge
        self.genotype = deepcopy(genotype)

        for node_id in range(1, len(genotype)):
            edges = genotype[node_id - 1]
            layer_ids, input_ids = [], []
            for op_name, src in edges:
                # Only edges reading the cell input change resolution/width.
                first_dim, op_stride = (C_in, stride) if src == 0 else (C_out, 1)
                op = OPS[op_name](first_dim, C_out, op_stride, True, True)
                layer_ids.append(len(self.layers))
                input_ids.append(src)
                self.layers.append(op)
            self.node_IX.append(layer_ids)
            self.node_IN.append(input_ids)

        self.nodes = len(genotype)
        self.in_dim = C_in
        self.out_dim = C_out

    def extra_repr(self):
        """Describe the cell wiring for ``nn.Module.__repr__``.

        Each intermediate node is rendered as ``k<-(I<src>-L<layer>,...)``,
        followed by the genotype's ``tostr()`` output.
        """
        header = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format_map(self.__dict__)
        node_descs = []
        wiring = zip(self.node_IX, self.node_IN)
        for dst, (layer_ids, input_ids) in enumerate(wiring, start=1):
            edge_descs = ','.join(
                'I{:}-L{:}'.format(src, lid) for lid, src in zip(layer_ids, input_ids)
            )
            node_descs.append('{:}<-({:})'.format(dst, edge_descs))
        return header + ', [{:}]'.format(' | '.join(node_descs)) + ', {:}'.format(self.genotype.tostr())

    def forward(self, inputs):
        """Run the cell on ``inputs`` and return the last node's feature.

        Node 0 is ``inputs``; every later node is the sum of its edge
        operations applied to the nodes they read from.
        """
        features = [inputs]
        for layer_ids, input_ids in zip(self.node_IX, self.node_IN):
            # NOTE(review): a node with no edges would sum to the int 0.
            features.append(
                sum(self.layers[lid](features[src]) for lid, src in zip(layer_ids, input_ids))
            )
        return features[-1]