| 
									
										
										
										
											2020-02-23 10:30:37 +11:00
										 |  |  | ##################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | 
					
						
							|  |  |  | ##################################################### | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-08 20:06:12 +11:00
										 |  |  | import torch.nn as nn | 
					
						
							|  |  |  | from copy import deepcopy | 
					
						
							|  |  |  | from ..cell_operations import OPS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
										 |  |  | # Cell for NAS-Bench-201 | 
					
						
							| 
									
										
										
										
											2019-11-08 20:06:12 +11:00
										 |  |  | class InferCell(nn.Module): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def __init__(self, genotype, C_in, C_out, stride): | 
					
						
							|  |  |  |     super(InferCell, self).__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.layers  = nn.ModuleList() | 
					
						
							|  |  |  |     self.node_IN = [] | 
					
						
							|  |  |  |     self.node_IX = [] | 
					
						
							|  |  |  |     self.genotype = deepcopy(genotype) | 
					
						
							|  |  |  |     for i in range(1, len(genotype)): | 
					
						
							|  |  |  |       node_info = genotype[i-1] | 
					
						
							|  |  |  |       cur_index = [] | 
					
						
							|  |  |  |       cur_innod = [] | 
					
						
							|  |  |  |       for (op_name, op_in) in node_info: | 
					
						
							|  |  |  |         if op_in == 0: | 
					
						
							| 
									
										
										
										
											2019-12-23 13:32:20 +11:00
										 |  |  |           layer = OPS[op_name](C_in , C_out, stride, True, True) | 
					
						
							| 
									
										
										
										
											2019-11-08 20:06:12 +11:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2019-12-23 13:32:20 +11:00
										 |  |  |           layer = OPS[op_name](C_out, C_out,      1, True, True) | 
					
						
							| 
									
										
										
										
											2019-11-08 20:06:12 +11:00
										 |  |  |         cur_index.append( len(self.layers) ) | 
					
						
							|  |  |  |         cur_innod.append( op_in ) | 
					
						
							|  |  |  |         self.layers.append( layer ) | 
					
						
							|  |  |  |       self.node_IX.append( cur_index ) | 
					
						
							|  |  |  |       self.node_IN.append( cur_innod ) | 
					
						
							|  |  |  |     self.nodes   = len(genotype) | 
					
						
							|  |  |  |     self.in_dim  = C_in | 
					
						
							|  |  |  |     self.out_dim = C_out | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def extra_repr(self): | 
					
						
							|  |  |  |     string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) | 
					
						
							|  |  |  |     laystr = [] | 
					
						
							|  |  |  |     for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): | 
					
						
							|  |  |  |       y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)] | 
					
						
							|  |  |  |       x = '{:}<-({:})'.format(i+1, ','.join(y)) | 
					
						
							|  |  |  |       laystr.append( x ) | 
					
						
							|  |  |  |     return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def forward(self, inputs): | 
					
						
							|  |  |  |     nodes = [inputs] | 
					
						
							|  |  |  |     for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): | 
					
						
							|  |  |  |       node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) ) | 
					
						
							|  |  |  |       nodes.append( node_feature ) | 
					
						
							|  |  |  |     return nodes[-1] |