| 
									
										
										
										
											2020-02-23 10:30:37 +11:00
										 |  |  | ##################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | 
					
						
							|  |  |  | ##################################################### | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-06 19:29:07 +11:00
										 |  |  | import torch | 
					
						
							| 
									
										
										
										
											2019-11-08 20:06:12 +11:00
										 |  |  | import torch.nn as nn | 
					
						
							|  |  |  | from copy import deepcopy | 
					
						
							| 
									
										
										
										
											2020-07-24 12:56:34 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | from models.cell_operations import OPS | 
					
						
							| 
									
										
										
										
											2019-11-08 20:06:12 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
										 |  |  | # Cell for NAS-Bench-201 | 
					
						
							| 
									
										
										
										
											2019-11-08 20:06:12 +11:00
										 |  |  | class InferCell(nn.Module): | 
					
						
							| 
									
										
										
										
											2021-05-12 16:28:05 +08:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         super(InferCell, self).__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.layers = nn.ModuleList() | 
					
						
							|  |  |  |         self.node_IN = [] | 
					
						
							|  |  |  |         self.node_IX = [] | 
					
						
							|  |  |  |         self.genotype = deepcopy(genotype) | 
					
						
							|  |  |  |         for i in range(1, len(genotype)): | 
					
						
							|  |  |  |             node_info = genotype[i - 1] | 
					
						
							|  |  |  |             cur_index = [] | 
					
						
							|  |  |  |             cur_innod = [] | 
					
						
							|  |  |  |             for (op_name, op_in) in node_info: | 
					
						
							|  |  |  |                 if op_in == 0: | 
					
						
							|  |  |  |                     layer = OPS[op_name]( | 
					
						
							|  |  |  |                         C_in, C_out, stride, affine, track_running_stats | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats) | 
					
						
							|  |  |  |                 cur_index.append(len(self.layers)) | 
					
						
							|  |  |  |                 cur_innod.append(op_in) | 
					
						
							|  |  |  |                 self.layers.append(layer) | 
					
						
							|  |  |  |             self.node_IX.append(cur_index) | 
					
						
							|  |  |  |             self.node_IN.append(cur_innod) | 
					
						
							|  |  |  |         self.nodes = len(genotype) | 
					
						
							|  |  |  |         self.in_dim = C_in | 
					
						
							|  |  |  |         self.out_dim = C_out | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def extra_repr(self): | 
					
						
							|  |  |  |         string = "info :: nodes={nodes}, inC={in_dim}, outC={out_dim}".format( | 
					
						
							|  |  |  |             **self.__dict__ | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         laystr = [] | 
					
						
							|  |  |  |         for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)): | 
					
						
							|  |  |  |             y = [ | 
					
						
							|  |  |  |                 "I{:}-L{:}".format(_ii, _il) | 
					
						
							|  |  |  |                 for _il, _ii in zip(node_layers, node_innods) | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |             x = "{:}<-({:})".format(i + 1, ",".join(y)) | 
					
						
							|  |  |  |             laystr.append(x) | 
					
						
							|  |  |  |         return ( | 
					
						
							|  |  |  |             string | 
					
						
							|  |  |  |             + ", [{:}]".format(" | ".join(laystr)) | 
					
						
							|  |  |  |             + ", {:}".format(self.genotype.tostr()) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, inputs): | 
					
						
							|  |  |  |         nodes = [inputs] | 
					
						
							|  |  |  |         for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)): | 
					
						
							|  |  |  |             node_feature = sum( | 
					
						
							|  |  |  |                 self.layers[_il](nodes[_ii]) | 
					
						
							|  |  |  |                 for _il, _ii in zip(node_layers, node_innods) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             nodes.append(node_feature) | 
					
						
							|  |  |  |         return nodes[-1] | 
					
						
							| 
									
										
										
										
											2020-03-06 19:29:07 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 | 
					
						
							|  |  |  | class NASNetInferCell(nn.Module): | 
					
						
							| 
									
										
										
										
											2021-05-12 16:28:05 +08:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         genotype, | 
					
						
							|  |  |  |         C_prev_prev, | 
					
						
							|  |  |  |         C_prev, | 
					
						
							|  |  |  |         C, | 
					
						
							|  |  |  |         reduction, | 
					
						
							|  |  |  |         reduction_prev, | 
					
						
							|  |  |  |         affine, | 
					
						
							|  |  |  |         track_running_stats, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         super(NASNetInferCell, self).__init__() | 
					
						
							|  |  |  |         self.reduction = reduction | 
					
						
							|  |  |  |         if reduction_prev: | 
					
						
							|  |  |  |             self.preprocess0 = OPS["skip_connect"]( | 
					
						
							|  |  |  |                 C_prev_prev, C, 2, affine, track_running_stats | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             self.preprocess0 = OPS["nor_conv_1x1"]( | 
					
						
							|  |  |  |                 C_prev_prev, C, 1, affine, track_running_stats | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         self.preprocess1 = OPS["nor_conv_1x1"]( | 
					
						
							|  |  |  |             C_prev, C, 1, affine, track_running_stats | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if not reduction: | 
					
						
							|  |  |  |             nodes, concats = genotype["normal"], genotype["normal_concat"] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             nodes, concats = genotype["reduce"], genotype["reduce_concat"] | 
					
						
							|  |  |  |         self._multiplier = len(concats) | 
					
						
							|  |  |  |         self._concats = concats | 
					
						
							|  |  |  |         self._steps = len(nodes) | 
					
						
							|  |  |  |         self._nodes = nodes | 
					
						
							|  |  |  |         self.edges = nn.ModuleDict() | 
					
						
							|  |  |  |         for i, node in enumerate(nodes): | 
					
						
							|  |  |  |             for in_node in node: | 
					
						
							|  |  |  |                 name, j = in_node[0], in_node[1] | 
					
						
							|  |  |  |                 stride = 2 if reduction and j < 2 else 1 | 
					
						
							|  |  |  |                 node_str = "{:}<-{:}".format(i + 2, j) | 
					
						
							|  |  |  |                 self.edges[node_str] = OPS[name]( | 
					
						
							|  |  |  |                     C, C, stride, affine, track_running_stats | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # [TODO] to support drop_prob in this function.. | 
					
						
							|  |  |  |     def forward(self, s0, s1, unused_drop_prob): | 
					
						
							|  |  |  |         s0 = self.preprocess0(s0) | 
					
						
							|  |  |  |         s1 = self.preprocess1(s1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         states = [s0, s1] | 
					
						
							|  |  |  |         for i, node in enumerate(self._nodes): | 
					
						
							|  |  |  |             clist = [] | 
					
						
							|  |  |  |             for in_node in node: | 
					
						
							|  |  |  |                 name, j = in_node[0], in_node[1] | 
					
						
							|  |  |  |                 node_str = "{:}<-{:}".format(i + 2, j) | 
					
						
							|  |  |  |                 op = self.edges[node_str] | 
					
						
							|  |  |  |                 clist.append(op(states[j])) | 
					
						
							|  |  |  |             states.append(sum(clist)) | 
					
						
							|  |  |  |         return torch.cat([states[x] for x in self._concats], dim=1) | 
					
						
							| 
									
										
										
										
											2020-03-06 19:29:07 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class AuxiliaryHeadCIFAR(nn.Module): | 
					
						
							| 
									
										
										
										
											2021-05-12 16:28:05 +08:00
										 |  |  |     def __init__(self, C, num_classes): | 
					
						
							|  |  |  |         """assuming input size 8x8""" | 
					
						
							|  |  |  |         super(AuxiliaryHeadCIFAR, self).__init__() | 
					
						
							|  |  |  |         self.features = nn.Sequential( | 
					
						
							|  |  |  |             nn.ReLU(inplace=True), | 
					
						
							|  |  |  |             nn.AvgPool2d( | 
					
						
							|  |  |  |                 5, stride=3, padding=0, count_include_pad=False | 
					
						
							|  |  |  |             ),  # image size = 2 x 2 | 
					
						
							|  |  |  |             nn.Conv2d(C, 128, 1, bias=False), | 
					
						
							|  |  |  |             nn.BatchNorm2d(128), | 
					
						
							|  |  |  |             nn.ReLU(inplace=True), | 
					
						
							|  |  |  |             nn.Conv2d(128, 768, 2, bias=False), | 
					
						
							|  |  |  |             nn.BatchNorm2d(768), | 
					
						
							|  |  |  |             nn.ReLU(inplace=True), | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.classifier = nn.Linear(768, num_classes) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							|  |  |  |         x = self.features(x) | 
					
						
							|  |  |  |         x = self.classifier(x.view(x.size(0), -1)) | 
					
						
							|  |  |  |         return x |