##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
					
						
							| 
									
										
										
										
											2019-11-05 23:35:28 +11:00
										 |  |  | import math, random, torch | 
					
						
							|  |  |  | import warnings | 
					
						
							|  |  |  | import torch.nn as nn | 
					
						
							|  |  |  | import torch.nn.functional as F | 
					
						
							|  |  |  | from copy import deepcopy | 
					
						
							|  |  |  | from ..cell_operations import OPS | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# This module is used for NAS-Bench-201; it represents a small search space with a complete DAG.
					
						
							|  |  |  | class NAS201SearchCell(nn.Module): | 
					
						
							| 
									
										
										
										
											2019-11-05 23:35:28 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-23 13:32:20 +11:00
										 |  |  |   def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True): | 
					
						
							| 
									
										
										
										
											2020-01-15 00:52:06 +11:00
										 |  |  |     super(NAS201SearchCell, self).__init__() | 
					
						
							| 
									
										
										
										
											2019-11-05 23:35:28 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  |     self.op_names  = deepcopy(op_names) | 
					
						
							|  |  |  |     self.edges     = nn.ModuleDict() | 
					
						
							|  |  |  |     self.max_nodes = max_nodes | 
					
						
							|  |  |  |     self.in_dim    = C_in | 
					
						
							|  |  |  |     self.out_dim   = C_out | 
					
						
							|  |  |  |     for i in range(1, max_nodes): | 
					
						
							|  |  |  |       for j in range(i): | 
					
						
							|  |  |  |         node_str = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |         if j == 0: | 
					
						
							| 
									
										
										
										
											2019-12-23 13:32:20 +11:00
										 |  |  |           xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names] | 
					
						
							| 
									
										
										
										
											2019-11-05 23:35:28 +11:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2019-12-23 13:32:20 +11:00
										 |  |  |           xlists = [OPS[op_name](C_in , C_out,      1, affine, track_running_stats) for op_name in op_names] | 
					
						
							| 
									
										
										
										
											2019-11-05 23:35:28 +11:00
										 |  |  |         self.edges[ node_str ] = nn.ModuleList( xlists ) | 
					
						
							|  |  |  |     self.edge_keys  = sorted(list(self.edges.keys())) | 
					
						
							|  |  |  |     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} | 
					
						
							|  |  |  |     self.num_edges  = len(self.edges) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def extra_repr(self): | 
					
						
							|  |  |  |     string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) | 
					
						
							|  |  |  |     return string | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def forward(self, inputs, weightss): | 
					
						
							|  |  |  |     nodes = [inputs] | 
					
						
							|  |  |  |     for i in range(1, self.max_nodes): | 
					
						
							|  |  |  |       inter_nodes = [] | 
					
						
							|  |  |  |       for j in range(i): | 
					
						
							|  |  |  |         node_str = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |         weights  = weightss[ self.edge2index[node_str] ] | 
					
						
							|  |  |  |         inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) | 
					
						
							|  |  |  |       nodes.append( sum(inter_nodes) ) | 
					
						
							|  |  |  |     return nodes[-1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # GDAS | 
					
						
							| 
									
										
										
										
											2019-11-19 11:58:04 +11:00
										 |  |  |   def forward_gdas(self, inputs, hardwts, index): | 
					
						
							|  |  |  |     nodes   = [inputs] | 
					
						
							|  |  |  |     for i in range(1, self.max_nodes): | 
					
						
							|  |  |  |       inter_nodes = [] | 
					
						
							|  |  |  |       for j in range(i): | 
					
						
							|  |  |  |         node_str = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |         weights  = hardwts[ self.edge2index[node_str] ] | 
					
						
							|  |  |  |         argmaxs  = index[ self.edge2index[node_str] ].item() | 
					
						
							|  |  |  |         weigsum  = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) | 
					
						
							|  |  |  |         inter_nodes.append( weigsum ) | 
					
						
							|  |  |  |       nodes.append( sum(inter_nodes) ) | 
					
						
							| 
									
										
										
										
											2019-11-05 23:35:28 +11:00
										 |  |  |     return nodes[-1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # joint | 
					
						
							|  |  |  |   def forward_joint(self, inputs, weightss): | 
					
						
							|  |  |  |     nodes = [inputs] | 
					
						
							|  |  |  |     for i in range(1, self.max_nodes): | 
					
						
							|  |  |  |       inter_nodes = [] | 
					
						
							|  |  |  |       for j in range(i): | 
					
						
							|  |  |  |         node_str = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |         weights  = weightss[ self.edge2index[node_str] ] | 
					
						
							| 
									
										
										
										
											2019-11-09 16:50:13 +11:00
										 |  |  |         #aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() | 
					
						
							|  |  |  |         aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) | 
					
						
							| 
									
										
										
										
											2019-11-05 23:35:28 +11:00
										 |  |  |         inter_nodes.append( aggregation ) | 
					
						
							|  |  |  |       nodes.append( sum(inter_nodes) ) | 
					
						
							|  |  |  |     return nodes[-1] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-12 01:42:17 +11:00
										 |  |  |   # uniform random sampling per iteration, SETN | 
					
						
							| 
									
										
										
										
											2019-11-05 23:35:28 +11:00
										 |  |  |   def forward_urs(self, inputs): | 
					
						
							|  |  |  |     nodes = [inputs] | 
					
						
							|  |  |  |     for i in range(1, self.max_nodes): | 
					
						
							|  |  |  |       while True: # to avoid select zero for all ops | 
					
						
							|  |  |  |         sops, has_non_zero = [], False | 
					
						
							|  |  |  |         for j in range(i): | 
					
						
							|  |  |  |           node_str   = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |           candidates = self.edges[node_str] | 
					
						
							|  |  |  |           select_op  = random.choice(candidates) | 
					
						
							|  |  |  |           sops.append( select_op ) | 
					
						
							| 
									
										
										
										
											2020-01-09 22:26:23 +11:00
										 |  |  |           if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True | 
					
						
							| 
									
										
										
										
											2019-11-05 23:35:28 +11:00
										 |  |  |         if has_non_zero: break | 
					
						
							|  |  |  |       inter_nodes = [] | 
					
						
							|  |  |  |       for j, select_op in enumerate(sops): | 
					
						
							|  |  |  |         inter_nodes.append( select_op(nodes[j]) ) | 
					
						
							|  |  |  |       nodes.append( sum(inter_nodes) ) | 
					
						
							|  |  |  |     return nodes[-1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # select the argmax | 
					
						
							|  |  |  |   def forward_select(self, inputs, weightss): | 
					
						
							|  |  |  |     nodes = [inputs] | 
					
						
							|  |  |  |     for i in range(1, self.max_nodes): | 
					
						
							|  |  |  |       inter_nodes = [] | 
					
						
							|  |  |  |       for j in range(i): | 
					
						
							|  |  |  |         node_str = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |         weights  = weightss[ self.edge2index[node_str] ] | 
					
						
							|  |  |  |         inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) ) | 
					
						
							|  |  |  |         #inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) | 
					
						
							|  |  |  |       nodes.append( sum(inter_nodes) ) | 
					
						
							|  |  |  |     return nodes[-1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # forward with a specific structure | 
					
						
							|  |  |  |   def forward_dynamic(self, inputs, structure): | 
					
						
							|  |  |  |     nodes = [inputs] | 
					
						
							|  |  |  |     for i in range(1, self.max_nodes): | 
					
						
							|  |  |  |       cur_op_node = structure.nodes[i-1] | 
					
						
							|  |  |  |       inter_nodes = [] | 
					
						
							|  |  |  |       for op_name, j in cur_op_node: | 
					
						
							|  |  |  |         node_str = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |         op_index = self.op_names.index( op_name ) | 
					
						
							|  |  |  |         inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) ) | 
					
						
							|  |  |  |       nodes.append( sum(inter_nodes) ) | 
					
						
							|  |  |  |     return nodes[-1] | 
					
						
							| 
									
										
										
										
											2020-01-12 01:42:17 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MixedOp(nn.Module): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def __init__(self, space, C, stride, affine, track_running_stats): | 
					
						
							|  |  |  |     super(MixedOp, self).__init__() | 
					
						
							|  |  |  |     self._ops = nn.ModuleList() | 
					
						
							|  |  |  |     for primitive in space: | 
					
						
							|  |  |  |       op = OPS[primitive](C, C, stride, affine, track_running_stats) | 
					
						
							|  |  |  |       self._ops.append(op) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-17 22:14:47 +11:00
										 |  |  |   def forward_gdas(self, x, weights, index): | 
					
						
							| 
									
										
										
										
											2020-01-12 01:42:17 +11:00
										 |  |  |     return self._ops[index](x) * weights[index] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-17 22:14:47 +11:00
										 |  |  |   def forward_darts(self, x, weights): | 
					
						
							|  |  |  |     return sum(w * op(x) for w, op in zip(weights, self._ops)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-12 01:42:17 +11:00
										 |  |  | 
 | 
					
						
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
					
						
							|  |  |  | class NASNetSearchCell(nn.Module): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def __init__(self, space, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats): | 
					
						
							|  |  |  |     super(NASNetSearchCell, self).__init__() | 
					
						
							|  |  |  |     self.reduction = reduction | 
					
						
							|  |  |  |     self.op_names  = deepcopy(space) | 
					
						
							|  |  |  |     if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats) | 
					
						
							|  |  |  |     else             : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats) | 
					
						
							|  |  |  |     self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats) | 
					
						
							|  |  |  |     self._steps = steps | 
					
						
							|  |  |  |     self._multiplier = multiplier | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self._ops = nn.ModuleList() | 
					
						
							|  |  |  |     self.edges     = nn.ModuleDict() | 
					
						
							|  |  |  |     for i in range(self._steps): | 
					
						
							|  |  |  |       for j in range(2+i): | 
					
						
							|  |  |  |         node_str = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |         stride = 2 if reduction and j < 2 else 1 | 
					
						
							|  |  |  |         op = MixedOp(space, C, stride, affine, track_running_stats) | 
					
						
							|  |  |  |         self.edges[ node_str ] = op | 
					
						
							|  |  |  |     self.edge_keys  = sorted(list(self.edges.keys())) | 
					
						
							|  |  |  |     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} | 
					
						
							|  |  |  |     self.num_edges  = len(self.edges) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def forward_gdas(self, s0, s1, weightss, indexs): | 
					
						
							|  |  |  |     s0 = self.preprocess0(s0) | 
					
						
							|  |  |  |     s1 = self.preprocess1(s1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     states = [s0, s1] | 
					
						
							|  |  |  |     for i in range(self._steps): | 
					
						
							|  |  |  |       clist = [] | 
					
						
							|  |  |  |       for j, h in enumerate(states): | 
					
						
							|  |  |  |         node_str = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |         op = self.edges[ node_str ] | 
					
						
							|  |  |  |         weights = weightss[ self.edge2index[node_str] ] | 
					
						
							|  |  |  |         index   = indexs[ self.edge2index[node_str] ].item() | 
					
						
							| 
									
										
										
										
											2020-01-17 22:14:47 +11:00
										 |  |  |         clist.append( op.forward_gdas(h, weights, index) ) | 
					
						
							|  |  |  |       states.append( sum(clist) ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return torch.cat(states[-self._multiplier:], dim=1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def forward_darts(self, s0, s1, weightss): | 
					
						
							|  |  |  |     s0 = self.preprocess0(s0) | 
					
						
							|  |  |  |     s1 = self.preprocess1(s1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     states = [s0, s1] | 
					
						
							|  |  |  |     for i in range(self._steps): | 
					
						
							|  |  |  |       clist = [] | 
					
						
							|  |  |  |       for j, h in enumerate(states): | 
					
						
							|  |  |  |         node_str = '{:}<-{:}'.format(i, j) | 
					
						
							|  |  |  |         op = self.edges[ node_str ] | 
					
						
							|  |  |  |         weights = weightss[ self.edge2index[node_str] ] | 
					
						
							|  |  |  |         clist.append( op.forward_darts(h, weights) ) | 
					
						
							| 
									
										
										
										
											2020-01-12 01:42:17 +11:00
										 |  |  |       states.append( sum(clist) ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return torch.cat(states[-self._multiplier:], dim=1) |