update

nasbench201/DownsampledImageNet.py (new file, 110 lines)

							| @@ -0,0 +1,110 @@ | ||||
| import os, sys, hashlib, torch | ||||
| import numpy as np | ||||
| from PIL import Image | ||||
| import torch.utils.data as data | ||||
| import pickle | ||||
|  | ||||
|  | ||||
| def calculate_md5(fpath, chunk_size=1024 * 1024): | ||||
|   md5 = hashlib.md5() | ||||
|   with open(fpath, 'rb') as f: | ||||
|     for chunk in iter(lambda: f.read(chunk_size), b''): | ||||
|       md5.update(chunk) | ||||
|   return md5.hexdigest() | ||||
|  | ||||
|  | ||||
| def check_md5(fpath, md5, **kwargs): | ||||
|   return md5 == calculate_md5(fpath, **kwargs) | ||||
|  | ||||
|  | ||||
| def check_integrity(fpath, md5=None): | ||||
|   print(fpath) | ||||
|   if not os.path.isfile(fpath): return False | ||||
|   if md5 is None: return True | ||||
|   else          : return check_md5(fpath, md5) | ||||
|  | ||||
|  | ||||
| class ImageNet16(data.Dataset): | ||||
|   # http://image-net.org/download-images | ||||
|   # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets | ||||
|   # https://arxiv.org/pdf/1707.08819.pdf | ||||
|    | ||||
|   train_list = [ | ||||
|         ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'], | ||||
|         ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'], | ||||
|         ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'], | ||||
|         ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'], | ||||
|         ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'], | ||||
|         ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'], | ||||
|         ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'], | ||||
|         ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'], | ||||
|         ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'], | ||||
|         ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'], | ||||
|     ] | ||||
|   valid_list = [ | ||||
|         ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'], | ||||
|     ] | ||||
|  | ||||
|   def __init__(self, root, train, transform, use_num_of_class_only=None): | ||||
|     self.root      = root | ||||
|     self.transform = transform | ||||
|     self.train     = train  # training set or valid set | ||||
|     if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.') | ||||
|  | ||||
|     if self.train: downloaded_list = self.train_list | ||||
|     else         : downloaded_list = self.valid_list | ||||
|     self.data    = [] | ||||
|     self.targets = [] | ||||
|    | ||||
|     # now load the pickled numpy arrays | ||||
|     for i, (file_name, checksum) in enumerate(downloaded_list): | ||||
|       file_path = os.path.join(self.root, file_name) | ||||
|       #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) | ||||
|       with open(file_path, 'rb') as f: | ||||
|         if sys.version_info[0] == 2: | ||||
|           entry = pickle.load(f) | ||||
|         else: | ||||
|           entry = pickle.load(f, encoding='latin1') | ||||
|         self.data.append(entry['data']) | ||||
|         self.targets.extend(entry['labels']) | ||||
|     self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) | ||||
|     self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC | ||||
|     if use_num_of_class_only is not None: | ||||
|       assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only) | ||||
|       new_data, new_targets = [], [] | ||||
|       for I, L in zip(self.data, self.targets): | ||||
|         if 1 <= L <= use_num_of_class_only: | ||||
|           new_data.append( I ) | ||||
|           new_targets.append( L ) | ||||
|       self.data    = new_data | ||||
|       self.targets = new_targets | ||||
|     #    self.mean.append(entry['mean']) | ||||
|     #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) | ||||
|     #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) | ||||
|     #print ('Mean : {:}'.format(self.mean)) | ||||
|     #temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3)) | ||||
|     #std_data  = np.std(temp, axis=0) | ||||
|     #std_data  = np.mean(np.mean(std_data, axis=0), axis=0) | ||||
|     #print ('Std  : {:}'.format(std_data)) | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|     img, target = self.data[index], self.targets[index] - 1 | ||||
|  | ||||
|     img = Image.fromarray(img) | ||||
|  | ||||
|     if self.transform is not None: | ||||
|       img = self.transform(img) | ||||
|  | ||||
|     return img, target | ||||
|  | ||||
|   def __len__(self): | ||||
|     return len(self.data) | ||||
|  | ||||
|   def _check_integrity(self): | ||||
|     root = self.root | ||||
|     for fentry in (self.train_list + self.valid_list): | ||||
|       filename, md5 = fentry[0], fentry[1] | ||||
|       fpath = os.path.join(root, filename) | ||||
|       if not check_integrity(fpath, md5): | ||||
|         return False | ||||
|     return True | ||||
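
For reference, a minimal usage sketch for the ImageNet16 dataset class above. The data directory, the normalization constants, and the 120-class restriction are illustrative assumptions, not part of this commit:

```python
import torchvision.transforms as transforms
from nasbench201.DownsampledImageNet import ImageNet16

# Normalization constants often used for 16x16 downsampled ImageNet (assumed here).
mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std  = [x / 255 for x in [63.22, 61.26, 65.09]]
xform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

# The root directory must already contain the pickled batches named in train_list/valid_list.
train_data = ImageNet16('./data/ImageNet16', True, xform, use_num_of_class_only=120)
valid_data = ImageNet16('./data/ImageNet16', False, xform, use_num_of_class_only=120)

img, label = train_data[0]   # labels are shifted to start at 0 in __getitem__
print(len(train_data), len(valid_data), img.shape, label)
```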

nasbench201/architect_ig.py (new file, 52 lines)

							| @@ -0,0 +1,52 @@ | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class Architect(object): | ||||
|     def __init__(self, model, args): | ||||
|         self.network_momentum = args.momentum | ||||
|         self.network_weight_decay = args.weight_decay | ||||
|         self.model = model | ||||
|         self.optimizer = torch.optim.Adam(self.model.arch_parameters(), | ||||
|                                         lr=args.arch_learning_rate, betas=(0.5, 0.999), | ||||
|                                         weight_decay=args.arch_weight_decay) | ||||
|  | ||||
|         self._init_arch_parameters = [] | ||||
|         for alpha in self.model.arch_parameters(): | ||||
|             alpha_init = torch.zeros_like(alpha) | ||||
|             alpha_init.data.copy_(alpha) | ||||
|             self._init_arch_parameters.append(alpha_init) | ||||
|  | ||||
|         #### mode | ||||
|         if args.method in ['darts', 'darts-proj', 'sdarts', 'sdarts-proj']: | ||||
|             self.method = 'fo' # first order update | ||||
|         elif 'so' in args.method: | ||||
|             print('ERROR: PLEASE USE architect.py for second order darts'); exit(0) | ||||
|         elif args.method in ['blank', 'blank-proj']: | ||||
|             self.method = 'blank' | ||||
|         else: | ||||
|             print('ERROR: WRONG ARCH UPDATE METHOD', args.method); exit(0) | ||||
|  | ||||
|     def reset_arch_parameters(self): | ||||
|         for alpha, alpha_init in zip(self.model.arch_parameters(), self._init_arch_parameters): | ||||
|             alpha.data.copy_(alpha_init.data) | ||||
|  | ||||
|     def step(self, input_train, target_train, input_valid, target_valid, *args, **kwargs): | ||||
|         if self.method == 'fo': | ||||
|             shared = self._step_fo(input_train, target_train, input_valid, target_valid) | ||||
|         elif self.method == 'so': | ||||
|             raise NotImplementedError | ||||
|         elif self.method == 'blank': ## do not update alpha | ||||
|             shared = None | ||||
|  | ||||
|         return shared | ||||
|  | ||||
|     #### first order | ||||
|     def _step_fo(self, input_train, target_train, input_valid, target_valid): | ||||
|         loss = self.model._loss(input_valid, target_valid) | ||||
|         loss.backward() | ||||
|         self.optimizer.step() | ||||
|         return None | ||||
|  | ||||
|     #### darts 2nd order | ||||
|     def _step_darts_so(self, input_train, target_train, input_valid, target_valid, eta, model_optimizer): | ||||
|         raise NotImplementedError | ||||
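
A sketch of how this Architect might be driven from a search loop. The ToySupernet below and the hyper-parameter values are hypothetical stand-ins; the real supernet only needs to expose arch_parameters() and _loss():

```python
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from nasbench201.architect_ig import Architect


class ToySupernet(nn.Module):
    """Hypothetical stand-in exposing the interface Architect expects."""
    def __init__(self, num_edge=6, num_op=5):
        super().__init__()
        self.alphas = nn.Parameter(1e-3 * torch.randn(num_edge, num_op))
        self.head = nn.Linear(8, 10)

    def arch_parameters(self):
        return [self.alphas]

    def _loss(self, inputs, targets):
        logits = self.head(inputs) * self.alphas.sum()  # ties the loss to the alphas
        return F.cross_entropy(logits, targets)


args = argparse.Namespace(momentum=0.9, weight_decay=3e-4,
                          arch_learning_rate=3e-4, arch_weight_decay=1e-3,
                          method='darts')
model = ToySupernet()
architect = Architect(model, args)

x_val, y_val = torch.randn(4, 8), torch.randint(0, 10, (4,))
architect.optimizer.zero_grad()
architect.step(None, None, x_val, y_val)   # 'fo' mode: one Adam step on the alphas
architect.reset_arch_parameters()          # restore the initial alphas if desired
```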

nasbench201/cell_infers/cells.py (new file, 120 lines)

							| @@ -0,0 +1,120 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import OPS | ||||
|  | ||||
|  | ||||
| # Cell for NAS-Bench-201 | ||||
| class InferCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, genotype, C_in, C_out, stride): | ||||
|     super(InferCell, self).__init__() | ||||
|  | ||||
|     self.layers  = nn.ModuleList() | ||||
|     self.node_IN = [] | ||||
|     self.node_IX = [] | ||||
|     self.genotype = deepcopy(genotype) | ||||
|     for i in range(1, len(genotype)): | ||||
|       node_info = genotype[i-1] | ||||
|       cur_index = [] | ||||
|       cur_innod = [] | ||||
|       for (op_name, op_in) in node_info: | ||||
|         if op_in == 0: | ||||
|           layer = OPS[op_name](C_in , C_out, stride, True, True) | ||||
|         else: | ||||
|           layer = OPS[op_name](C_out, C_out,      1, True, True) | ||||
|         cur_index.append( len(self.layers) ) | ||||
|         cur_innod.append( op_in ) | ||||
|         self.layers.append( layer ) | ||||
|       self.node_IX.append( cur_index ) | ||||
|       self.node_IN.append( cur_innod ) | ||||
|     self.nodes   = len(genotype) | ||||
|     self.in_dim  = C_in | ||||
|     self.out_dim = C_out | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) | ||||
|     laystr = [] | ||||
|     for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): | ||||
|       y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)] | ||||
|       x = '{:}<-({:})'.format(i+1, ','.join(y)) | ||||
|       laystr.append( x ) | ||||
|     return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr()) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     nodes = [inputs] | ||||
|     for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)): | ||||
|       node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) ) | ||||
|       nodes.append( node_feature ) | ||||
|     return nodes[-1] | ||||
|  | ||||
|  | ||||
|  | ||||
| # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 | ||||
| class NASNetInferCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats): | ||||
|     super(NASNetInferCell, self).__init__() | ||||
|     self.reduction = reduction | ||||
|     if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats) | ||||
|     else             : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats) | ||||
|     self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats) | ||||
|  | ||||
|     if not reduction: | ||||
|       nodes, concats = genotype['normal'], genotype['normal_concat'] | ||||
|     else: | ||||
|       nodes, concats = genotype['reduce'], genotype['reduce_concat'] | ||||
|     self._multiplier = len(concats) | ||||
|     self._concats = concats | ||||
|     self._steps = len(nodes) | ||||
|     self._nodes = nodes | ||||
|     self.edges = nn.ModuleDict() | ||||
|     for i, node in enumerate(nodes): | ||||
|       for in_node in node: | ||||
|         name, j = in_node[0], in_node[1] | ||||
|         stride = 2 if reduction and j < 2 else 1 | ||||
|         node_str = '{:}<-{:}'.format(i+2, j) | ||||
|         self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats) | ||||
|  | ||||
|   # [TODO] to support drop_prob in this function.. | ||||
|   def forward(self, s0, s1, unused_drop_prob): | ||||
|     s0 = self.preprocess0(s0) | ||||
|     s1 = self.preprocess1(s1) | ||||
|  | ||||
|     states = [s0, s1] | ||||
|     for i, node in enumerate(self._nodes): | ||||
|       clist = [] | ||||
|       for in_node in node: | ||||
|         name, j = in_node[0], in_node[1] | ||||
|         node_str = '{:}<-{:}'.format(i+2, j) | ||||
|         op = self.edges[ node_str ] | ||||
|         clist.append( op(states[j]) ) | ||||
|       states.append( sum(clist) ) | ||||
|     return torch.cat([states[x] for x in self._concats], dim=1) | ||||
|  | ||||
|  | ||||
| class AuxiliaryHeadCIFAR(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, num_classes): | ||||
|     """assuming input size 8x8""" | ||||
|     super(AuxiliaryHeadCIFAR, self).__init__() | ||||
|     self.features = nn.Sequential( | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 | ||||
|       nn.Conv2d(C, 128, 1, bias=False), | ||||
|       nn.BatchNorm2d(128), | ||||
|       nn.ReLU(inplace=True), | ||||
|       nn.Conv2d(128, 768, 2, bias=False), | ||||
|       nn.BatchNorm2d(768), | ||||
|       nn.ReLU(inplace=True) | ||||
|     ) | ||||
|     self.classifier = nn.Linear(768, num_classes) | ||||
|  | ||||
|   def forward(self, x): | ||||
|     x = self.features(x) | ||||
|     x = self.classifier(x.view(x.size(0),-1)) | ||||
|     return x | ||||
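
A sketch of building an InferCell from a NAS-Bench-201 architecture string. It assumes the nasbench201 package layout of this commit is importable and that the external Layers module required by cell_operations.py is on the path; the architecture string is only an example:

```python
import torch
from nasbench201.genotypes import Structure
from nasbench201.cell_infers.cells import InferCell

arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'
genotype = Structure.str2structure(arch)

cell = InferCell(genotype, C_in=16, C_out=16, stride=1)
print(cell.extra_repr())                # per node: which layers feed it and from where
out = cell(torch.randn(2, 16, 32, 32))
print(out.shape)                        # torch.Size([2, 16, 32, 32])
```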

nasbench201/cell_infers/tiny_network.py (new file, 82 lines)

							| @@ -0,0 +1,82 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||
| ##################################################### | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from .cells import InferCell | ||||
|  | ||||
|  | ||||
| # The macro structure for architectures in NAS-Bench-201 | ||||
| class TinyNetwork(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, genotype, num_classes): | ||||
|     super(TinyNetwork, self).__init__() | ||||
|     self._C               = C | ||||
|     self._layerN          = N | ||||
|  | ||||
|     self.stem = nn.Sequential( | ||||
|                     nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), | ||||
|                     nn.BatchNorm2d(C)) | ||||
|    | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|     C_prev = C | ||||
|     self.cells = nn.ModuleList() | ||||
|     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       if reduction: | ||||
|         cell = ResNetBasicblock(C_prev, C_curr, 2, True) | ||||
|       else: | ||||
|         cell = InferCell(genotype, C_prev, C_curr, 1) | ||||
|       self.cells.append( cell ) | ||||
|       C_prev = cell.out_dim | ||||
|     self._Layer= len(self.cells) | ||||
|  | ||||
|     self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier = nn.Linear(C_prev, num_classes) | ||||
|      | ||||
|     self.requires_feature = True | ||||
|  | ||||
|   def get_message(self): | ||||
|     string = self.extra_repr() | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) | ||||
|     return string | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     feature = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       feature = cell(feature) | ||||
|  | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling( out ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     if self.requires_feature: | ||||
|       return logits, out | ||||
|     else: | ||||
|       return logits | ||||
|    | ||||
|   def _loss(self, input, target, return_logits=False): | ||||
|     logits, _ = self(input) | ||||
|     loss = self._criterion(logits, target) | ||||
|      | ||||
|     return (loss, logits) if return_logits else loss | ||||
|  | ||||
|   def step(self, input, target, args, shared=None, return_grad=False): | ||||
|     Lt, logit_t = self._loss(input, target, return_logits=True) | ||||
|     Lt.backward() | ||||
|     if args.grad_clip != 0:  | ||||
|       nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip) | ||||
|     self.optimizer.step() | ||||
|  | ||||
|     if return_grad: | ||||
|       grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()]) | ||||
|       return logit_t, Lt, grad | ||||
|     else: | ||||
|       return logit_t, Lt | ||||
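
A sketch of assembling the macro network above and running one forward pass. Note that _loss and step rely on self._criterion, self.optimizer and self.get_weights(), which are expected to be attached by the training script; the architecture string and sizes below are illustrative:

```python
import torch
from nasbench201.genotypes import Structure
from nasbench201.cell_infers.tiny_network import TinyNetwork

arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'
net = TinyNetwork(C=16, N=5, genotype=Structure.str2structure(arch), num_classes=10)
print(net.get_message())

logits, feature = net(torch.randn(2, 3, 32, 32))   # requires_feature is True by default
print(logits.shape, feature.shape)                  # [2, 10] and [2, 64]
```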

nasbench201/cell_operations.py (new file, 289 lines)

							| @@ -0,0 +1,289 @@ | ||||
| import sys | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| sys.path.insert(0, '../') | ||||
| from Layers import layers | ||||
| __all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames'] | ||||
|  | ||||
| OPS = { | ||||
|   'noise'        : lambda C_in, C_out, stride, affine, track_running_stats: NoiseOp(stride, 0., 1.), # C_in, C_out not needed | ||||
|   'none'         : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride), | ||||
|   'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats), | ||||
|   'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats), | ||||
|   'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats), | ||||
|   'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), | ||||
|   'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats), | ||||
|   'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats), | ||||
|   'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats), | ||||
|   'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats), | ||||
|   'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats), | ||||
|   'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats), | ||||
| } | ||||
|  | ||||
| CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3'] | ||||
| NAS_BENCH_201         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||
| DARTS_SPACE           = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3'] | ||||
| #### wrc modified | ||||
| NAS_BENCH_201_SKIP    = ['none', 'skip_connect', 'nor_conv_1x1_skip', 'nor_conv_3x3_skip', 'avg_pool_3x3'] | ||||
| NAS_BENCH_201_SIMPLE  = ['skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] | ||||
| NAS_BENCH_201_S2      = ['skip_connect', 'nor_conv_3x3'] | ||||
| NAS_BENCH_201_S4      = ['noise', 'nor_conv_3x3'] | ||||
| NAS_BENCH_201_S10     = ['none', 'nor_conv_3x3'] | ||||
|  | ||||
| SearchSpaceNames = {'connect-nas'  : CONNECT_NAS_BENCHMARK, | ||||
|                     'nas-bench-201': NAS_BENCH_201, | ||||
|                     'nas-bench-201-simple': NAS_BENCH_201_SIMPLE, | ||||
|                     'nas-bench-201-s2': NAS_BENCH_201_S2, | ||||
|                     'nas-bench-201-s4': NAS_BENCH_201_S4, | ||||
|                     'nas-bench-201-s10': NAS_BENCH_201_S10, | ||||
|                     'darts'        : DARTS_SPACE} | ||||
|  | ||||
| class NoiseOp(nn.Module): | ||||
|     def __init__(self, stride, mean, std): | ||||
|         super(NoiseOp, self).__init__() | ||||
|         self.stride = stride | ||||
|         self.mean = mean | ||||
|         self.std = std | ||||
|  | ||||
|     def forward(self, x, block_input=False): | ||||
|       if block_input: | ||||
|         x = x * 0 | ||||
|       if self.stride != 1: | ||||
|         x_new = x[:,:,::self.stride,::self.stride] | ||||
|       else: | ||||
|         x_new = x | ||||
|       noise = x_new.data.new(x_new.size()).normal_(self.mean, self.std) | ||||
|       return noise | ||||
|  | ||||
| class ReLUConvBN(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): | ||||
|     super(ReLUConvBN, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x, block_input=False): | ||||
|     if block_input: | ||||
|       x = x * 0 | ||||
|     return self.op(x) | ||||
|  | ||||
|   def score(self): | ||||
|     score = 0  | ||||
|     for l in self.op: | ||||
|         if hasattr(l, 'score'): | ||||
|             score += torch.sum(l.score).cpu().numpy() | ||||
|     return score | ||||
|    | ||||
| #### wrc modified | ||||
| class ReLUConvBNSkip(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): | ||||
|     super(ReLUConvBNSkip, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       layers.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) | ||||
|     ) | ||||
|  | ||||
|   def forward(self, x, block_input=False): | ||||
|     if block_input: | ||||
|       x = x * 0 | ||||
|     return self.op(x) + x | ||||
|    | ||||
|   def score(self): | ||||
|     score = 0  | ||||
|     for l in self.op: | ||||
|         if hasattr(l, 'score'): | ||||
|             score += torch.sum(l.score).cpu().numpy() | ||||
|     return score | ||||
| #### | ||||
|  | ||||
| class SepConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): | ||||
|     super(SepConv, self).__init__() | ||||
|     self.op = nn.Sequential( | ||||
|       nn.ReLU(inplace=False), | ||||
|       layers.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), | ||||
|       layers.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), | ||||
|       nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats), | ||||
|       ) | ||||
|  | ||||
|   def forward(self, x, block_input=False): | ||||
|     if block_input: | ||||
|       x = x * 0 | ||||
|     return self.op(x) | ||||
|  | ||||
|   def score(self): | ||||
|     score = 0  | ||||
|     for l in self.op: | ||||
|         if hasattr(l, 'score'): | ||||
|             score += torch.sum(l.score).cpu().numpy() | ||||
|     return score | ||||
|  | ||||
|  | ||||
| class DualSepConv(nn.Module): | ||||
|      | ||||
|   def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True): | ||||
|     super(DualSepConv, self).__init__() | ||||
|     self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats) | ||||
|     self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats) | ||||
|  | ||||
|   def forward(self, x, block_input=False): | ||||
|     if block_input: | ||||
|       x = x * 0 | ||||
|     x = self.op_a(x) | ||||
|     x = self.op_b(x) | ||||
|     return x | ||||
|  | ||||
|   def score(self): | ||||
|     score = self.op_a.score() + self.op_b.score() | ||||
|     return score | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|  | ||||
|   def __init__(self, inplanes, planes, stride, affine=True): | ||||
|     super(ResNetBasicblock, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|     self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine) | ||||
|     self.conv_b = ReLUConvBN(  planes, planes, 3,      1, 1, 1, affine) | ||||
|     if stride == 2: | ||||
|       self.downsample = nn.Sequential( | ||||
|                            nn.AvgPool2d(kernel_size=2, stride=2, padding=0), | ||||
|                            nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False)) | ||||
|     elif inplanes != planes: | ||||
|       self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine) | ||||
|     else: | ||||
|       self.downsample = None | ||||
|     self.in_dim  = inplanes | ||||
|     self.out_dim = planes | ||||
|     self.stride  = stride | ||||
|     self.num_conv = 2 | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__) | ||||
|     return string | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     basicblock = self.conv_a(inputs) | ||||
|     basicblock = self.conv_b(basicblock) | ||||
|  | ||||
|     if self.downsample is not None: | ||||
|       residual = self.downsample(inputs) | ||||
|     else: | ||||
|       residual = inputs | ||||
|     return residual + basicblock | ||||
|    | ||||
|   def score(self): | ||||
|     return self.conv_a.score() + self.conv_b.score() | ||||
|      | ||||
|  | ||||
|  | ||||
|  | ||||
| class POOLING(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True): | ||||
|     super(POOLING, self).__init__() | ||||
|     if C_in == C_out: | ||||
|       self.preprocess = None | ||||
|     else: | ||||
|       self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats) | ||||
|     if mode == 'avg'  : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) | ||||
|     elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1) | ||||
|     else              : raise ValueError('Invalid mode={:} in POOLING'.format(mode)) | ||||
|  | ||||
|   def forward(self, inputs, block_input=False): | ||||
|     if block_input: | ||||
|       inputs = inputs * 0 | ||||
|     if self.preprocess: x = self.preprocess(inputs) | ||||
|     else              : x = inputs | ||||
|     return self.op(x) | ||||
|    | ||||
|   def score(self): | ||||
|     if self.preprocess : | ||||
|       return self.preprocess.score() | ||||
|     else: | ||||
|       return 0 | ||||
|  | ||||
|  | ||||
| class Identity(nn.Module): | ||||
|  | ||||
|   def __init__(self): | ||||
|     super(Identity, self).__init__() | ||||
|  | ||||
|   def forward(self, x, block_input=False): | ||||
|     if block_input: | ||||
|       x = x * 0 | ||||
|     return x | ||||
|  | ||||
|  | ||||
| class Zero(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride): | ||||
|     super(Zero, self).__init__() | ||||
|     self.C_in   = C_in | ||||
|     self.C_out  = C_out | ||||
|     self.stride = stride | ||||
|     self.is_zero = True | ||||
|  | ||||
|   def forward(self, x, block_input=False): | ||||
|     if block_input: | ||||
|       x = x*0 | ||||
|     if self.C_in == self.C_out: | ||||
|       if self.stride == 1: return x.mul(0.) | ||||
|       else               : return x[:,:,::self.stride,::self.stride].mul(0.) | ||||
|     else: ## this is never called in nasbench201 | ||||
|       shape = list(x.shape) | ||||
|       shape[1] = self.C_out | ||||
|       zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device) | ||||
|       return zeros | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) | ||||
|  | ||||
|  | ||||
| class FactorizedReduce(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, affine, track_running_stats): | ||||
|     super(FactorizedReduce, self).__init__() | ||||
|     self.stride = stride | ||||
|     self.C_in   = C_in | ||||
|     self.C_out  = C_out | ||||
|     self.relu   = nn.ReLU(inplace=False) | ||||
|     if stride == 2: | ||||
|       #assert C_out % 2 == 0, 'C_out : {:}'.format(C_out) | ||||
|       C_outs = [C_out // 2, C_out - C_out // 2] | ||||
|       self.convs = nn.ModuleList() | ||||
|       for i in range(2): | ||||
|         self.convs.append(layers.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) ) | ||||
|       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) | ||||
|     elif stride == 1: | ||||
|       self.conv = layers.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False) | ||||
|     else: | ||||
|       raise ValueError('Invalid stride : {:}'.format(stride)) | ||||
|     self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats) | ||||
|  | ||||
|   def forward(self, x, block_input=False): | ||||
|     if block_input: | ||||
|       x = x * 0 | ||||
|     if self.stride == 2: | ||||
|       x = self.relu(x) | ||||
|       y = self.pad(x) | ||||
|       out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1) | ||||
|     else: | ||||
|       out = self.conv(x) | ||||
|     out = self.bn(out) | ||||
|     return out | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) | ||||
|  | ||||
|   def score(self): | ||||
|     if self.stride == 1: | ||||
|       return self.conv.score() | ||||
|     else: | ||||
|       return self.convs[0].score()+self.convs[1].score() | ||||
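
A sketch of instantiating operations through the OPS table for a named search space. It assumes the external Layers module imported at the top of this file is available; otherwise the convolution-based entries cannot be constructed:

```python
import torch
from nasbench201.cell_operations import OPS, SearchSpaceNames

space = SearchSpaceNames['nas-bench-201']     # ['none', 'skip_connect', 'nor_conv_1x1', ...]
x = torch.randn(2, 16, 32, 32)

for name in space:
    # signature: OPS[name](C_in, C_out, stride, affine, track_running_stats)
    op = OPS[name](16, 16, 1, True, True)
    print('{:14s} -> {:}'.format(name, tuple(op(x).shape)))
```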

nasbench201/genotypes.py (new file, 194 lines)

							| @@ -0,0 +1,194 @@ | ||||
| from copy import deepcopy | ||||
|  | ||||
|  | ||||
| def get_combination(space, num): | ||||
|   combs = [] | ||||
|   for i in range(num): | ||||
|     if i == 0: | ||||
|       for func in space: | ||||
|         combs.append( [(func, i)] ) | ||||
|     else: | ||||
|       new_combs = [] | ||||
|       for string in combs: | ||||
|         for func in space: | ||||
|           xstring = string + [(func, i)] | ||||
|           new_combs.append( xstring ) | ||||
|       combs = new_combs | ||||
|   return combs | ||||
|    | ||||
|  | ||||
| class Structure: | ||||
|  | ||||
|   def __init__(self, genotype): | ||||
|     assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype)) | ||||
|     self.node_num = len(genotype) + 1 | ||||
|     self.nodes    = [] | ||||
|     self.node_N   = [] | ||||
|     for idx, node_info in enumerate(genotype): | ||||
|       assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info)) | ||||
|       assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info)) | ||||
|       for node_in in node_info: | ||||
|         assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in)) | ||||
|         assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in) | ||||
|       self.node_N.append( len(node_info) ) | ||||
|       self.nodes.append( tuple(deepcopy(node_info)) ) | ||||
|  | ||||
|   def tolist(self, remove_str): | ||||
|     # convert this class to a list; if remove_str is 'none', remove the 'none' operations. | ||||
|     # note that we re-order the input nodes in this function | ||||
|     # return (genotype-list, success); if unsuccessful, the genotype is not a valid connectivity | ||||
|     genotypes = [] | ||||
|     for node_info in self.nodes: | ||||
|       node_info = list( node_info ) | ||||
|       node_info = sorted(node_info, key=lambda x: (x[1], x[0])) | ||||
|       node_info = tuple(filter(lambda x: x[0] != remove_str, node_info)) | ||||
|       if len(node_info) == 0: return None, False | ||||
|       genotypes.append( node_info ) | ||||
|     return genotypes, True | ||||
|  | ||||
|   def node(self, index): | ||||
|     assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self)) | ||||
|     return self.nodes[index] | ||||
|  | ||||
|   def tostr(self): | ||||
|     strings = [] | ||||
|     for node_info in self.nodes: | ||||
|       string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info]) | ||||
|       string = '|{:}|'.format(string) | ||||
|       strings.append( string ) | ||||
|     return '+'.join(strings) | ||||
|  | ||||
|   def check_valid(self): | ||||
|     nodes = {0: True} | ||||
|     for i, node_info in enumerate(self.nodes): | ||||
|       sums = [] | ||||
|       for op, xin in node_info: | ||||
|         if op == 'none' or nodes[xin] is False: x = False | ||||
|         else: x = True | ||||
|         sums.append( x ) | ||||
|       nodes[i+1] = sum(sums) > 0 | ||||
|     return nodes[len(self.nodes)] | ||||
|  | ||||
|   def to_unique_str(self, consider_zero=False): | ||||
|     # this is used to identify isomorphic cells, which requires prior knowledge of the operations | ||||
|     # two operations are special, i.e., none and skip_connect | ||||
|     nodes = {0: '0'} | ||||
|     for i_node, node_info in enumerate(self.nodes): | ||||
|       cur_node = [] | ||||
|       for op, xin in node_info: | ||||
|         if consider_zero is None: | ||||
|           x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||
|         elif consider_zero: | ||||
|           if op == 'none' or nodes[xin] == '#': x = '#' # zero | ||||
|           elif op == 'skip_connect': x = nodes[xin] | ||||
|           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||
|         else: | ||||
|           if op == 'skip_connect': x = nodes[xin] | ||||
|           else: x = '('+nodes[xin]+')' + '@{:}'.format(op) | ||||
|         cur_node.append(x) | ||||
|       nodes[i_node+1] = '+'.join( sorted(cur_node) ) | ||||
|     return nodes[ len(self.nodes) ] | ||||
|  | ||||
|   def check_valid_op(self, op_names): | ||||
|     for node_info in self.nodes: | ||||
|       for inode_edge in node_info: | ||||
|         #assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0]) | ||||
|         if inode_edge[0] not in op_names: return False | ||||
|     return True | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__)) | ||||
|  | ||||
|   def __len__(self): | ||||
|     return len(self.nodes) + 1 | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|     return self.nodes[index] | ||||
|  | ||||
|   @staticmethod | ||||
|   def str2structure(xstr): | ||||
|     assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) | ||||
|     nodestrs = xstr.split('+') | ||||
|     genotypes = [] | ||||
|     for i, node_str in enumerate(nodestrs): | ||||
|       inputs = list(filter(lambda x: x != '', node_str.split('|'))) | ||||
|       for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) | ||||
|       inputs = ( xi.split('~') for xi in inputs ) | ||||
|       input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs) | ||||
|       genotypes.append( input_infos ) | ||||
|     return Structure( genotypes ) | ||||
|  | ||||
|   @staticmethod | ||||
|   def str2fullstructure(xstr, default_name='none'): | ||||
|     assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) | ||||
|     nodestrs = xstr.split('+') | ||||
|     genotypes = [] | ||||
|     for i, node_str in enumerate(nodestrs): | ||||
|       inputs = list(filter(lambda x: x != '', node_str.split('|'))) | ||||
|       for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) | ||||
|       inputs = ( xi.split('~') for xi in inputs ) | ||||
|       input_infos = list( (op, int(IDX)) for (op, IDX) in inputs) | ||||
|       all_in_nodes= list(x[1] for x in input_infos) | ||||
|       for j in range(i): | ||||
|         if j not in all_in_nodes: input_infos.append((default_name, j)) | ||||
|       node_info = sorted(input_infos, key=lambda x: (x[1], x[0])) | ||||
|       genotypes.append( tuple(node_info) ) | ||||
|     return Structure( genotypes ) | ||||
|  | ||||
|   @staticmethod | ||||
|   def gen_all(search_space, num, return_ori): | ||||
|     assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space)) | ||||
|     assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num) | ||||
|     all_archs = get_combination(search_space, 1) | ||||
|     for i, arch in enumerate(all_archs): | ||||
|       all_archs[i] = [ tuple(arch) ] | ||||
|    | ||||
|     for inode in range(2, num): | ||||
|       cur_nodes = get_combination(search_space, inode) | ||||
|       new_all_archs = [] | ||||
|       for previous_arch in all_archs: | ||||
|         for cur_node in cur_nodes: | ||||
|           new_all_archs.append( previous_arch + [tuple(cur_node)] ) | ||||
|       all_archs = new_all_archs | ||||
|     if return_ori: | ||||
|       return all_archs | ||||
|     else: | ||||
|       return [Structure(x) for x in all_archs] | ||||
|  | ||||
|  | ||||
|  | ||||
| ResNet_CODE = Structure( | ||||
|   [(('nor_conv_3x3', 0), ), # node-1  | ||||
|    (('nor_conv_3x3', 1), ), # node-2 | ||||
|    (('skip_connect', 0), ('skip_connect', 2))] # node-3 | ||||
|   ) | ||||
|  | ||||
| AllConv3x3_CODE = Structure( | ||||
|   [(('nor_conv_3x3', 0), ), # node-1  | ||||
|    (('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2 | ||||
|    (('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3 | ||||
|   ) | ||||
|  | ||||
| AllFull_CODE = Structure( | ||||
|   [(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1  | ||||
|    (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2 | ||||
|    (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3 | ||||
|   ) | ||||
|  | ||||
| AllConv1x1_CODE = Structure( | ||||
|   [(('nor_conv_1x1', 0), ), # node-1  | ||||
|    (('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2 | ||||
|    (('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3 | ||||
|   ) | ||||
|  | ||||
| AllIdentity_CODE = Structure( | ||||
|   [(('skip_connect', 0), ), # node-1  | ||||
|    (('skip_connect', 0), ('skip_connect', 1)), # node-2 | ||||
|    (('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3 | ||||
|   ) | ||||
|  | ||||
| architectures = {'resnet'  : ResNet_CODE, | ||||
|                  'all_c3x3': AllConv3x3_CODE, | ||||
|                  'all_c1x1': AllConv1x1_CODE, | ||||
|                  'all_idnt': AllIdentity_CODE, | ||||
|                  'all_full': AllFull_CODE} | ||||
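
A sketch of round-tripping an architecture string through Structure and exercising a few of its helpers (the string is only an example):

```python
from nasbench201.genotypes import Structure, architectures

arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'
s = Structure.str2structure(arch)

print(len(s))               # 4 nodes, counting the input node
print(s.tostr() == arch)    # True: tostr() reproduces the string
print(s.check_valid())      # True: the output node is reachable via non-'none' ops
print(architectures['resnet'].tostr())
```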

nasbench201/init_projection.py (new file, 619 lines)

							| @@ -0,0 +1,619 @@ | ||||
| import os | ||||
| import sys | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn.functional as f | ||||
| sys.path.insert(0, '../') | ||||
| import nasbench201.utils as ig_utils | ||||
| import logging | ||||
| import torch.utils | ||||
| import copy | ||||
| import scipy.stats as ss | ||||
| from collections import OrderedDict | ||||
| from foresight.pruners import * | ||||
| from op_score import Jocab_Score, get_ntk_n | ||||
| import gc | ||||
| from nasbench201.linear_region import Linear_Region_Collector | ||||
|  | ||||
| torch.set_printoptions(precision=4, sci_mode=False) | ||||
| np.set_printoptions(precision=4, suppress=True) | ||||
|  | ||||
| # global-op-iter: iteratively evaluates the score of A-(e,o) for every remaining (edge, operation) pair and greedily discretizes the best pair | ||||
| def global_op_greedy_pt_project(proj_queue, model, args):  | ||||
|     def project(model, args): | ||||
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|  | ||||
|         ## get the ids of the edges that have not been discretized yet | ||||
|         remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] | ||||
|         compare = lambda x, y : x < y | ||||
|  | ||||
|         crit_extrema = None | ||||
|         best_eid = None | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for eid in remain_eids: | ||||
|             for opid in range(num_op): | ||||
|                 # projection | ||||
|                 weights = model.get_projected_weights() | ||||
|                 proj_mask = torch.ones_like(weights[eid]) | ||||
|                 proj_mask[opid] = 0 | ||||
|                 weights[eid] = weights[eid] * proj_mask | ||||
|  | ||||
|                 ## proj evaluation | ||||
|                 if args.proj_crit == 'jacob': | ||||
|                     valid_stats = Jocab_Score(model, input, target, weights=weights) | ||||
|                     crit = valid_stats | ||||
|  | ||||
|                 if crit_extrema is None or compare(crit, crit_extrema): | ||||
|                     crit_extrema = crit | ||||
|                     best_opid = opid | ||||
|                     best_eid = eid | ||||
|  | ||||
|         logging.info('best opid %d', best_opid) | ||||
|         return best_eid, best_opid | ||||
|  | ||||
|     tune_epochs = model.arch_parameters()[0].shape[0] | ||||
|  | ||||
|     for epoch in range(tune_epochs): | ||||
|         logging.info('epoch %d', epoch)  | ||||
|         logging.info('project') | ||||
|         selected_eid, best_opid = project(model, args) | ||||
|         model.project_op(selected_eid, best_opid) | ||||
|  | ||||
|     return | ||||
|  | ||||
| # global-edge-iter: similar to global-op-iter but iteratively selects the edge e from E with the best average score over its operations, then discretizes it | ||||
| def global_edge_greedy_pt_project(proj_queue, model, args): | ||||
|     def select_eid(model, args): | ||||
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|  | ||||
|         ## get the ids of the edges that have not been discretized yet | ||||
|         remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] | ||||
|         compare = lambda x, y : x < y | ||||
|  | ||||
|         crit_extrema = None | ||||
|         best_eid = None | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for eid in remain_eids: | ||||
|             eid_score = [] | ||||
|             for opid in range(num_op): | ||||
|                 # projection | ||||
|                 weights = model.get_projected_weights() | ||||
|                 proj_mask = torch.ones_like(weights[eid]) | ||||
|                 proj_mask[opid] = 0 | ||||
|                 weights[eid] = weights[eid] * proj_mask | ||||
|  | ||||
|                 ## proj evaluation | ||||
|                 if args.proj_crit == 'jacob': | ||||
|                     valid_stats = Jocab_Score(model, input,  target, weights=weights) | ||||
|                     crit = valid_stats | ||||
|                 eid_score.append(crit) | ||||
|             eid_score = np.mean(eid_score) | ||||
|  | ||||
|             if crit_extrema is None or compare(eid_score, crit_extrema): | ||||
|                 crit_extrema = eid_score | ||||
|                 best_eid = eid | ||||
|         return best_eid | ||||
|      | ||||
|     def project(model, args, selected_eid): | ||||
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|  | ||||
|         ## select the best operation | ||||
|         if args.proj_crit == 'jacob': | ||||
|             crit_idx = 3 | ||||
|             compare = lambda x, y: x < y | ||||
|         else: | ||||
|             crit_idx = 4 | ||||
|             compare = lambda x, y: x < y | ||||
|          | ||||
|         best_opid = 0 | ||||
|         crit_list = [] | ||||
|         op_ids = [] | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for opid in range(num_op): | ||||
|             ## projection | ||||
|             weights = model.get_projected_weights() | ||||
|             proj_mask = torch.ones_like(weights[selected_eid]) | ||||
|             proj_mask[opid] = 0 | ||||
|             weights[selected_eid] = weights[selected_eid] * proj_mask | ||||
|  | ||||
|             ## proj evaluation | ||||
|             if args.proj_crit == 'jacob': | ||||
|                 valid_stats = Jocab_Score(model, input,  target, weights=weights) | ||||
|                 crit = valid_stats | ||||
|             | ||||
|             crit_list.append(crit) | ||||
|             op_ids.append(opid) | ||||
|              | ||||
|         best_opid = op_ids[np.nanargmin(crit_list)] | ||||
|  | ||||
|         logging.info('best opid %d', best_opid) | ||||
|         logging.info(crit_list) | ||||
|         return selected_eid, best_opid | ||||
|  | ||||
|     num_edges = model.arch_parameters()[0].shape[0] | ||||
|  | ||||
|     for epoch in range(num_edges): | ||||
|         logging.info('epoch %d', epoch) | ||||
|          | ||||
|         logging.info('project') | ||||
|         selected_eid = select_eid(model, args) | ||||
|         selected_eid, best_opid = project(model, args, selected_eid) | ||||
|         model.project_op(selected_eid, best_opid) | ||||
|     return | ||||
|  | ||||
| # global-op-once: only evaluates S(A-(e,o)) for all operations once to obtain a ranking order of the operations, and discretizes the edges E according to this order | ||||
| def global_op_once_pt_project(proj_queue, model, args): | ||||
|     def order(model, args): | ||||
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|  | ||||
|         ## get the ids of the edges that have not been discretized yet | ||||
|         remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] | ||||
|         compare = lambda x, y : x < y | ||||
|  | ||||
|         edge_score = OrderedDict() | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for eid in remain_eids:        | ||||
|             crit_list = [] | ||||
|             for opid in range(num_op): | ||||
|                 # projection | ||||
|                 weights = model.get_projected_weights() | ||||
|                 proj_mask = torch.ones_like(weights[eid]) | ||||
|                 proj_mask[opid] = 0 | ||||
|                 weights[eid] = weights[eid] * proj_mask | ||||
|  | ||||
|                 ## proj evaluation | ||||
|                 if args.proj_crit == 'jacob': | ||||
|                     valid_stats = Jocab_Score(model, input,  target, weights=weights) | ||||
|                     crit = valid_stats | ||||
|  | ||||
|                 crit_list.append(crit) | ||||
|             edge_score[eid] = np.nanargmin(crit_list) | ||||
|         return edge_score | ||||
|  | ||||
|     def project(model, args, selected_eid): | ||||
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|         ## select the best operation | ||||
|         if args.proj_crit == 'jacob': | ||||
|             crit_idx = 3 | ||||
|             compare = lambda x, y: x < y | ||||
|         else: | ||||
|             crit_idx = 4 | ||||
|             compare = lambda x, y: x < y | ||||
|          | ||||
|         best_opid = 0 | ||||
|         crit_list = [] | ||||
|         op_ids = [] | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for opid in range(num_op): | ||||
|             ## projection | ||||
|             weights = model.get_projected_weights() | ||||
|             proj_mask = torch.ones_like(weights[selected_eid]) | ||||
|             proj_mask[opid] = 0 | ||||
|             weights[selected_eid] = weights[selected_eid] * proj_mask | ||||
|  | ||||
|             ## proj evaluation | ||||
|             if args.proj_crit == 'jacob': | ||||
|                 crit = Jocab_Score(model, input,  target, weights=weights) | ||||
|             crit_list.append(crit) | ||||
|             op_ids.append(opid) | ||||
|              | ||||
|         best_opid = op_ids[np.nanargmin(crit_list)] | ||||
|  | ||||
|         logging.info('best opid %d', best_opid) | ||||
|         logging.info(crit_list) | ||||
|         return selected_eid, best_opid | ||||
|      | ||||
|     num_edges = model.arch_parameters()[0].shape[0] | ||||
|  | ||||
|     eid_order = order(model, args) | ||||
|     for epoch in range(num_edges): | ||||
|         logging.info('epoch %d', epoch) | ||||
|         logging.info('project') | ||||
|         selected_eid, _ = eid_order.popitem() | ||||
|         selected_eid, best_opid = project(model, args, selected_eid) | ||||
|         model.project_op(selected_eid, best_opid) | ||||
|  | ||||
|     return | ||||
|  | ||||
| # global-edge-once: similar to global-op-once but uses the average score of operations on edges to obtain the edge discretization order | ||||
| def global_edge_once_pt_project(proj_queue, model, args): | ||||
|     def order(model, args): | ||||
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|  | ||||
|         ## get the ids of the edges that have not been discretized yet | ||||
|         remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] | ||||
|         compare = lambda x, y : x < y | ||||
|  | ||||
|         edge_score = OrderedDict() | ||||
|         crit_extrema = None | ||||
|         best_eid = None | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for eid in remain_eids:        | ||||
|             crit_list = [] | ||||
|             for opid in range(num_op): | ||||
|                 # projection | ||||
|                 weights = model.get_projected_weights() | ||||
|                 proj_mask = torch.ones_like(weights[eid]) | ||||
|                 proj_mask[opid] = 0 | ||||
|                 weights[eid] = weights[eid] * proj_mask | ||||
|  | ||||
|                 ## proj evaluation | ||||
|                 if args.proj_crit == 'jacob': | ||||
|                     crit = Jocab_Score(model, input,  target, weights=weights) | ||||
|  | ||||
|                 crit_list.append(crit) | ||||
|             edge_score[eid] = np.mean(crit_list) | ||||
|         return edge_score | ||||
|  | ||||
|     def project(model, args, selected_eid): | ||||
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|         ## select the best operation | ||||
|         if args.proj_crit == 'jacob': | ||||
|             crit_idx = 3 | ||||
|             compare = lambda x, y: x < y | ||||
|         else: | ||||
|             crit_idx = 4 | ||||
|             compare = lambda x, y: x < y | ||||
|          | ||||
|         best_opid = 0 | ||||
|         crit_extrema = None | ||||
|         crit_list = [] | ||||
|         op_ids = [] | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for opid in range(num_op): | ||||
|             ## projection | ||||
|             weights = model.get_projected_weights() | ||||
|             proj_mask = torch.ones_like(weights[selected_eid]) | ||||
|             proj_mask[opid] = 0 | ||||
|             weights[selected_eid] = weights[selected_eid] * proj_mask | ||||
|  | ||||
|             ## proj evaluation | ||||
|             if args.proj_crit == 'jacob': | ||||
|                 crit = Jocab_Score(model, input,  target, weights=weights)       | ||||
|             crit_list.append(crit) | ||||
|             op_ids.append(opid) | ||||
|              | ||||
|         best_opid = op_ids[np.nanargmin(crit_list)] | ||||
|  | ||||
|         logging.info('best opid %d', best_opid) | ||||
|         logging.info(crit_list) | ||||
|         return selected_eid, best_opid | ||||
|      | ||||
|     num_edges = model.arch_parameters()[0].shape[0] | ||||
|  | ||||
|     eid_order = order(model, args) | ||||
|     for epoch in range(num_edges): | ||||
|         logging.info('epoch %d', epoch) | ||||
|         logging.info('project') | ||||
|         selected_eid, _ = eid_order.popitem() | ||||
|         selected_eid, best_opid = project(model, args, selected_eid) | ||||
|         model.project_op(selected_eid, best_opid) | ||||
|  | ||||
|     return | ||||
|  | ||||
| # fixed [reverse, order]: discretizes the edges in a fixed order; in our experiments we discretize from the input towards the output of the cell structure | ||||
| # random: discretizes the edges in a random order (DARTS-PT) | ||||
| # NOTE: only this method allows using other zero-cost proxy metrics | ||||
| def pt_project(proj_queue, model, args): | ||||
|     def project(model, args): | ||||
|         ## macros: there are 6 edges in total, each with 5 candidate operations | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|  | ||||
|         ## select an edge | ||||
|         remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] | ||||
|         # print('candidate_flags:', model.candidate_flags) | ||||
|         # print(model.candidate_flags) | ||||
|         # how to choose the edge to discretize | ||||
|         if args.edge_decision == "random": | ||||
|             # np.random.choice returns an array; take its single element | ||||
|             selected_eid = np.random.choice(remain_eids, size=1)[0] | ||||
|         elif args.edge_decision == "reverse": | ||||
|             selected_eid = remain_eids[-1] | ||||
|         else: | ||||
|             selected_eid = remain_eids[0] | ||||
|  | ||||
|         ## select the best operation | ||||
|         if args.proj_crit == 'jacob': | ||||
|             crit_idx = 3 | ||||
|             compare = lambda x, y: x < y | ||||
|         else: | ||||
|             crit_idx = 4 | ||||
|             compare = lambda x, y: x < y | ||||
|  | ||||
|         if args.dataset == 'cifar100': | ||||
|             n_classes = 100 | ||||
|         elif args.dataset == 'imagenet16-120': | ||||
|             n_classes = 120 | ||||
|         else: | ||||
|             n_classes = 10 | ||||
|  | ||||
|         best_opid = 0 | ||||
|         crit_extrema = None | ||||
|         crit_list = [] | ||||
|         op_ids = [] | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for opid in range(num_op): | ||||
|             ## projection | ||||
|             weights = model.get_projected_weights() | ||||
|             proj_mask = torch.ones_like(weights[selected_eid]) | ||||
|             # print(selected_eid, weights[selected_eid]) | ||||
|             proj_mask[opid] = 0 | ||||
|             weights[selected_eid] = weights[selected_eid] * proj_mask | ||||
|  | ||||
|  | ||||
|             ## proj evaluation | ||||
|             if args.proj_crit == 'jacob': | ||||
|                 crit = Jocab_Score(model, input,  target, weights=weights) | ||||
|             else: | ||||
|                 cache_weight = model.proj_weights[selected_eid] | ||||
|                 cache_flag =  model.candidate_flags[selected_eid] | ||||
|  | ||||
|  | ||||
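|                 # temporarily project the selected edge: zero the masked op and spread | ||||
|                 # uniform weight over the remaining ops so that find_measures scores | ||||
|                 # the supernet without that op; the edge is reset again further below | ||||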
|                 for idx in range(num_op): | ||||
|                     if idx == opid: | ||||
|                         model.proj_weights[selected_eid][opid] = 0 | ||||
|                     else: | ||||
|                         model.proj_weights[selected_eid][idx] = 1.0/num_op | ||||
|  | ||||
|  | ||||
|                 model.candidate_flags[selected_eid] = False | ||||
|                 # print(model.get_projected_weights()) | ||||
|  | ||||
|                 if args.proj_crit == 'comb': | ||||
|                     synflow = predictive.find_measures(model, | ||||
|                                         proj_queue, | ||||
|                                         ('random', 1, n_classes), | ||||
|                                         torch.device("cuda"), | ||||
|                                         measure_names=['synflow']) | ||||
|                     var = predictive.find_measures(model, | ||||
|                                         proj_queue, | ||||
|                                         ('random', 1, n_classes), | ||||
|                                         torch.device("cuda"), | ||||
|                                         measure_names=['var']) | ||||
|                     # print(synflow, var) | ||||
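|                     # combined proxy: log-damped synflow divided by the activation | ||||
|                     # variance (the +1 and +0.1 guard against log(0) and division by zero) | ||||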
|                     comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1) | ||||
|                     measures = {'comb': comb} | ||||
|                 else: | ||||
|                     measures = predictive.find_measures(model, | ||||
|                                              proj_queue, | ||||
|                                              ('random', 1, n_classes), | ||||
|                                              torch.device("cuda"), | ||||
|                                              measure_names=[args.proj_crit]) | ||||
|  | ||||
|                 # print(measures) | ||||
|                 for idx in range(num_op): | ||||
|                     model.proj_weights[selected_eid][idx] = 0 | ||||
|                 model.candidate_flags[selected_eid] = cache_flag | ||||
|                 crit = measures[args.proj_crit] | ||||
|  | ||||
|             crit_list.append(crit) | ||||
|             op_ids.append(opid) | ||||
|  | ||||
|  | ||||
|         best_opid = op_ids[np.nanargmin(crit_list)] | ||||
|         # best_opid = op_ids[np.nanargmax(crit_list)] | ||||
|  | ||||
|         logging.info('best opid %d', best_opid) | ||||
|         logging.info('current edge id %d', selected_eid) | ||||
|         logging.info(crit_list) | ||||
|         return selected_eid, best_opid | ||||
|      | ||||
|     num_edges = model.arch_parameters()[0].shape[0] | ||||
|  | ||||
|     for epoch in range(num_edges): | ||||
|         logging.info('epoch %d', epoch)         | ||||
|         logging.info('project') | ||||
|         selected_eid, best_opid = project(model, args) | ||||
|         model.project_op(selected_eid, best_opid) | ||||
|  | ||||
|     return | ||||
|  | ||||
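| # TE-NAS-style projection: each remaining (edge, op) pair is masked out in turn, the | ||||
| # masked supernet is scored by the rank-sum of its NTK condition number and its | ||||
| # linear-region count, and the pair with the minimum combined rank is projected. | ||||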
| def tenas_project(proj_queue, model, model_thin, args): | ||||
|     def project(model, args): | ||||
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|  | ||||
|         ## get the ids of the remaining (not yet discretized) edges | ||||
|         remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] | ||||
|         compare = lambda x, y : x < y | ||||
|  | ||||
|         ntks = [] | ||||
|         lrs = [] | ||||
|         edge_op_id = [] | ||||
|         best_eid = None | ||||
|          | ||||
|         if args.proj_crit == 'tenas': | ||||
|             lrc_model = Linear_Region_Collector(input_size=(1000, 1, 3, 3), sample_batch=3, dataset=args.dataset, data_path=args.data, seed=args.seed) | ||||
|         for eid in remain_eids: | ||||
|             for opid in range(num_op): | ||||
|                 # projection | ||||
|                 weights = model.get_projected_weights() | ||||
|                 proj_mask = torch.ones_like(weights[eid]) | ||||
|                 proj_mask[opid] = 0 | ||||
|                 weights[eid] = weights[eid] * proj_mask | ||||
|  | ||||
|                 ## proj evaluation | ||||
|                 if args.proj_crit == 'tenas': | ||||
|                     lrc_model.reinit(ori_models=[model_thin], seed=args.seed, weights=weights) | ||||
|                     lr = lrc_model.forward_batch_sample() | ||||
|                     lrc_model.clear() | ||||
|                     ntk = get_ntk_n(proj_queue, [model], recalbn=0, train_mode=True, num_batch=1, weights=weights) | ||||
|                     ntks.append(ntk) | ||||
|                     lrs.append(lr) | ||||
|                     edge_op_id.append('{}:{}'.format(eid, opid)) | ||||
|         print('ntks', ntks) | ||||
|         print('lrs', lrs) | ||||
|         ntks_ranks = ss.rankdata(ntks) | ||||
|         lrs_ranks = ss.rankdata(lrs) | ||||
|         ntks_ranks = len(ntks_ranks) - ntks_ranks.astype(int) | ||||
|         op_ranks = [] | ||||
|         for i in range(len(edge_op_id)): | ||||
|             op_ranks.append(ntks_ranks[i]+lrs_ranks[i]) | ||||
|          | ||||
|         best_op_index = edge_op_id[np.nanargmin(op_ranks[0:num_op])] | ||||
|         best_eid, best_opid = [int(x) for x in best_op_index.split(':')] | ||||
|  | ||||
|         logging.info(op_ranks) | ||||
|         logging.info('best eid %d', best_eid) | ||||
|         logging.info('best opid %d', best_opid) | ||||
|         return best_eid, best_opid | ||||
|     num_edges = model.arch_parameters()[0].shape[0] | ||||
|  | ||||
|     for epoch in range(num_edges): | ||||
|         logging.info('epoch %d', epoch)         | ||||
|         logging.info('project') | ||||
|         selected_eid, best_opid = project(model, args) | ||||
|         model.project_op(selected_eid, best_opid) | ||||
|  | ||||
|     return | ||||
|  | ||||
| # new method | ||||
| # Randomly propose candidate networks and transfer them to the supernet, then perform a global op selection within this subspace | ||||
| def shrink_pt_project(proj_queue, model, args): | ||||
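|     # Two stages: (1) run the random projection 20 times and take the union of the | ||||
|     # selected ops as a shrunk subspace; (2) greedily pick the best (edge, op) inside | ||||
|     # that subspace with the jacob score (global_project below). | ||||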
|     def project(model, args): | ||||
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|  | ||||
|         ## select an edge | ||||
|         remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0] | ||||
|         selected_eid = np.random.choice(remain_eids, size=1)[0] | ||||
|  | ||||
|  | ||||
|         ## select the best operation | ||||
|         if args.proj_crit == 'jacob': | ||||
|             crit_idx = 3 | ||||
|             compare = lambda x, y: x < y | ||||
|         else: | ||||
|             crit_idx = 4 | ||||
|             compare = lambda x, y: x < y | ||||
|  | ||||
|         if args.dataset == 'cifar100': | ||||
|             n_classes = 100 | ||||
|         elif args.dataset == 'imagenet16-120': | ||||
|             n_classes = 120 | ||||
|         else: | ||||
|             n_classes = 10 | ||||
|  | ||||
|         best_opid = 0 | ||||
|         crit_extrema = None | ||||
|         crit_list = [] | ||||
|         op_ids = [] | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for opid in range(num_op): | ||||
|             ## projection | ||||
|             weights = model.get_projected_weights() | ||||
|             proj_mask = torch.ones_like(weights[selected_eid]) | ||||
|             proj_mask[opid] = 0 | ||||
|             weights[selected_eid] = weights[selected_eid] * proj_mask | ||||
|  | ||||
|             ## proj evaluation | ||||
|             if args.proj_crit == 'jacob': | ||||
|                 crit = Jocab_Score(model, input,  target, weights=weights) | ||||
|             else: | ||||
|                 cache_weight = model.proj_weights[selected_eid] | ||||
|                 cache_flag =  model.candidate_flags[selected_eid] | ||||
|  | ||||
|                 for idx in range(num_op): | ||||
|                     if idx == opid: | ||||
|                         model.proj_weights[selected_eid][opid] = 0 | ||||
|                     else: | ||||
|                         model.proj_weights[selected_eid][idx] = 1.0/num_op | ||||
|                 model.candidate_flags[selected_eid] = False | ||||
|                  | ||||
|                 measures = predictive.find_measures(model, | ||||
|                                     proj_queue, | ||||
|                                     ('random', 1, n_classes),  | ||||
|                                     torch.device("cuda"), | ||||
|                                     measure_names=[args.proj_crit]) | ||||
|                 for idx in range(num_op): | ||||
|                     model.proj_weights[selected_eid][idx] = 0 | ||||
|                 model.candidate_flags[selected_eid] = cache_flag | ||||
|                 crit = measures[args.proj_crit] | ||||
|  | ||||
|             crit_list.append(crit) | ||||
|             op_ids.append(opid) | ||||
|              | ||||
|         best_opid = op_ids[np.nanargmin(crit_list)] | ||||
|  | ||||
|         logging.info('best opid %d', best_opid) | ||||
|         logging.info('current edge id %d', selected_eid) | ||||
|         logging.info(crit_list) | ||||
|         return selected_eid, best_opid | ||||
|      | ||||
|     def global_project(model, args): | ||||
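|         # over all remaining edges and their surviving ops, mask each op and keep the | ||||
|         # (edge, op) pair whose removal gives the lowest jacob score; that edge is then | ||||
|         # fixed to the chosen op via one-hot proj_weights | ||||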
|         ## macros | ||||
|         num_edge, num_op = model.num_edge, model.num_op | ||||
|  | ||||
|         ## get the ids of the remaining edges in the proposed subspace | ||||
|         remain_eids = torch.nonzero(model.subspace_candidate_flags).cpu().numpy().T[0] | ||||
|         compare = lambda x, y : x < y | ||||
|  | ||||
|         crit_extrema = None | ||||
|         best_eid = None | ||||
|         best_opid = None | ||||
|         input, target = next(iter(proj_queue)) | ||||
|         for eid in remain_eids: | ||||
|             remain_oids = torch.nonzero(model.proj_weights[eid]).cpu().numpy().T[0] | ||||
|             for opid in remain_oids: | ||||
|                 # projection | ||||
|                 weights = model.get_projected_weights() | ||||
|                 proj_mask = torch.ones_like(weights[eid]) | ||||
|                 proj_mask[opid] = 0 | ||||
|                 weights[eid] = weights[eid] * proj_mask | ||||
|                 ## proj evaluation | ||||
|                 if args.proj_crit == 'jacob': | ||||
|                     valid_stats = Jocab_Score(model, input, target, weights=weights) | ||||
|                     crit = valid_stats | ||||
|  | ||||
|                 if crit_extrema is None or compare(crit, crit_extrema): | ||||
|                     crit_extrema = crit | ||||
|                     best_opid = opid | ||||
|                     best_eid = eid | ||||
|  | ||||
|  | ||||
|         logging.info('best eid %d', best_eid) | ||||
|         logging.info('best opid %d', best_opid) | ||||
|         model.subspace_candidate_flags[best_eid] = False | ||||
|         proj_mask = torch.zeros_like(model.proj_weights[best_eid]) | ||||
|         model.proj_weights[best_eid] = model.proj_weights[best_eid] * proj_mask | ||||
|         model.proj_weights[best_eid][best_opid] = 1 | ||||
|         return best_eid, best_opid | ||||
|  | ||||
|     num_edges = model.arch_parameters()[0].shape[0] | ||||
|  | ||||
|     #subspace | ||||
|     logging.info('Start subspace proposal') | ||||
|     subspace = copy.deepcopy(model.proj_weights) | ||||
|     for i in range(20): | ||||
|         model.reset_arch_parameters() | ||||
|         for epoch in range(num_edges): | ||||
|             logging.info('epoch %d', epoch)         | ||||
|             logging.info('project') | ||||
|             selected_eid, best_opid = project(model, args) | ||||
|             model.project_op(selected_eid, best_opid) | ||||
|         subspace += model.proj_weights | ||||
|      | ||||
|     model.reset_arch_parameters() | ||||
|     subspace = torch.gt(subspace, 0).int().float() | ||||
|     subspace = f.normalize(subspace, p=1, dim=1) | ||||
|     model.proj_weights += subspace | ||||
|     for i in range(num_edges): | ||||
|         model.candidate_flags[i] = False | ||||
|     logging.info('Start final search in subspace') | ||||
|     logging.info(subspace) | ||||
|  | ||||
|     model.subspace_candidate_flags = torch.tensor(len(model._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda() | ||||
|     for epoch in range(num_edges): | ||||
|         logging.info('epoch %d', epoch)  | ||||
|         logging.info('project') | ||||
|         selected_eid, best_opid = global_project(model, args) | ||||
|         model.printing(logging) | ||||
|         #model.project_op(selected_eid, best_opid) | ||||
|     return | ||||
270  nasbench201/linear_region.py  Normal file
							| @@ -0,0 +1,270 @@ | ||||
| import os.path as osp | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.utils.data | ||||
| import torchvision.transforms as transforms | ||||
| import torchvision.datasets as dset | ||||
| from pdb import set_trace as bp | ||||
| from operator import mul | ||||
| from functools import reduce | ||||
| import copy | ||||
| Dataset2Class = {'cifar10': 10, | ||||
|                  'cifar100': 100, | ||||
|                  'imagenet-1k-s': 1000, | ||||
|                  'imagenet-1k': 1000, | ||||
| } | ||||
|  | ||||
|  | ||||
| class CUTOUT(object): | ||||
|  | ||||
|     def __init__(self, length): | ||||
|         self.length = length | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|     def __call__(self, img): | ||||
|         h, w = img.size(1), img.size(2) | ||||
|         mask = np.ones((h, w), np.float32) | ||||
|         y = np.random.randint(h) | ||||
|         x = np.random.randint(w) | ||||
|  | ||||
|         y1 = np.clip(y - self.length // 2, 0, h) | ||||
|         y2 = np.clip(y + self.length // 2, 0, h) | ||||
|         x1 = np.clip(x - self.length // 2, 0, w) | ||||
|         x2 = np.clip(x + self.length // 2, 0, w) | ||||
|  | ||||
|         mask[y1: y2, x1: x2] = 0. | ||||
|         mask = torch.from_numpy(mask) | ||||
|         mask = mask.expand_as(img) | ||||
|         img *= mask | ||||
|         return img | ||||
|  | ||||
|  | ||||
| imagenet_pca = { | ||||
|         'eigval': np.asarray([0.2175, 0.0188, 0.0045]), | ||||
|         'eigvec': np.asarray([ | ||||
|                 [-0.5675, 0.7192, 0.4009], | ||||
|                 [-0.5808, -0.0045, -0.8140], | ||||
|                 [-0.5836, -0.6948, 0.4203], | ||||
|         ]) | ||||
| } | ||||
|  | ||||
|  | ||||
| class RandChannel(object): | ||||
|     # randomly pick channels from input | ||||
|     def __init__(self, num_channel): | ||||
|         self.num_channel = num_channel | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return ('{name}(num_channel={num_channel})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|     def __call__(self, img): | ||||
|         channel = img.size(0) | ||||
|         channel_choice = sorted(np.random.choice(list(range(channel)), size=self.num_channel, replace=False)) | ||||
|         return torch.index_select(img, 0, torch.Tensor(channel_choice).long()) | ||||
|  | ||||
|  | ||||
| def get_datasets(name, root, input_size, cutout=-1): | ||||
|     assert len(input_size) in [3, 4] | ||||
|     if len(input_size) == 4: | ||||
|         input_size = input_size[1:] | ||||
|     assert input_size[1] == input_size[2] | ||||
|  | ||||
|     if name == 'cifar10': | ||||
|         mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|         std  = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|     elif name == 'cifar100': | ||||
|         mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|         std  = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|     elif name.startswith('imagenet-1k'): | ||||
|         mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||
|     elif name.startswith('ImageNet16'): | ||||
|         mean = [x / 255 for x in [122.68, 116.66, 104.01]] | ||||
|         std  = [x / 255 for x in [63.22,  61.26 , 65.09]] | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|     # print(input_size) | ||||
|     # Data Augmentation | ||||
|     if name == 'cifar10' or name == 'cifar100': | ||||
|         lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])] | ||||
|         if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|     elif name.startswith('ImageNet16'): | ||||
|         lists = [transforms.RandomCrop(input_size[1], padding=4), transforms.ToTensor(), transforms.Normalize(mean, std), RandChannel(input_size[0])] | ||||
|         if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|     elif name.startswith('imagenet-1k'): | ||||
|         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | ||||
|         if name == 'imagenet-1k': | ||||
|             xlists    = [] | ||||
|             xlists.append(transforms.Resize((32, 32), interpolation=2)) | ||||
|             xlists.append(transforms.RandomCrop(input_size[1], padding=0)) | ||||
|         elif name == 'imagenet-1k-s': | ||||
|             xlists = [transforms.RandomResizedCrop(32, scale=(0.2, 1.0))] | ||||
|         else: raise ValueError('invalid name : {:}'.format(name)) | ||||
|         xlists.append(transforms.ToTensor()) | ||||
|         xlists.append(normalize) | ||||
|         xlists.append(RandChannel(input_size[0])) | ||||
|         train_transform = transforms.Compose(xlists) | ||||
|         test_transform = transforms.Compose([transforms.Resize(40), transforms.CenterCrop(32), transforms.ToTensor(), normalize]) | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|     if name == 'cifar10': | ||||
|         train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True) | ||||
|         test_data  = dset.CIFAR10 (root, train=False, transform=test_transform , download=True) | ||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|     elif name == 'cifar100': | ||||
|         train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) | ||||
|         test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True) | ||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|     elif name.startswith('imagenet-1k'): | ||||
|         train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) | ||||
|         test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform) | ||||
|     else: raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|     class_num = Dataset2Class[name] | ||||
|     return train_data, test_data, class_num | ||||
|  | ||||
|  | ||||
| class LinearRegionCount(object): | ||||
|     """Computes and stores the average and current value""" | ||||
|     def __init__(self, n_samples): | ||||
|         self.ActPattern = {} | ||||
|         self.n_LR = -1 | ||||
|         self.n_samples = n_samples | ||||
|         self.ptr = 0 | ||||
|         self.activations = None | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def update2D(self, activations): | ||||
|         n_batch = activations.size()[0] | ||||
|         n_neuron = activations.size()[1] | ||||
|         self.n_neuron = n_neuron | ||||
|         if self.activations is None: | ||||
|             self.activations = torch.zeros(self.n_samples, n_neuron).cuda() | ||||
|         self.activations[self.ptr:self.ptr+n_batch] = torch.sign(activations)  # after ReLU | ||||
|         self.ptr += n_batch | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def calc_LR(self): | ||||
|         res = torch.matmul(self.activations.half(), (1-self.activations).T.half()) # each element of res: A * (1 - B) | ||||
|         res += res.T # make symmetric; each element is now A * (1 - B) + (1 - A) * B, so a non-zero entry indicates a pair of different linear regions | ||||
|         res = 1 - torch.sign(res) # a non-zero entry now indicates that the two linear regions are identical | ||||
|         res = res.sum(1) # for each sample's linear region: how many samples share an identical region | ||||
|         res = 1. / res.float() # contribution of each redundant (repeated) linear region | ||||
|         self.n_LR = res.sum().item() # number of unique regions (aggregating the contributions of all regions) | ||||
|         del self.activations, res | ||||
|         self.activations = None | ||||
|         torch.cuda.empty_cache() | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def update1D(self, activationList): | ||||
|         code_string = '' | ||||
|         for key, value in activationList.items(): | ||||
|             n_neuron = value.size()[0] | ||||
|             for i in range(n_neuron): | ||||
|                 if value[i] > 0: | ||||
|                     code_string += '1' | ||||
|                 else: | ||||
|                     code_string += '0' | ||||
|         if code_string not in self.ActPattern: | ||||
|             self.ActPattern[code_string] = 1 | ||||
|  | ||||
|     def getLinearReginCount(self): | ||||
|         if self.n_LR == -1: | ||||
|             self.calc_LR() | ||||
|         return self.n_LR | ||||
|  | ||||
|  | ||||
| class Linear_Region_Collector: | ||||
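|     # Counts the activation patterns (linear regions) that the networks induce on | ||||
|     # random training batches, following TE-NAS; a larger count indicates higher | ||||
|     # expressivity. | ||||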
|     def __init__(self, models=[], input_size=(64, 3, 32, 32), sample_batch=100, dataset='cifar100', data_path=None, seed=0): | ||||
|         self.models = [] | ||||
|         self.input_size = input_size  # BCHW | ||||
|         self.sample_batch = sample_batch | ||||
|         self.input_numel = reduce(mul, self.input_size, 1) | ||||
|         self.interFeature = [] | ||||
|         self.dataset = dataset | ||||
|         self.data_path = data_path | ||||
|         self.seed = seed | ||||
|         self.reinit(models, input_size, sample_batch, seed) | ||||
|          | ||||
|     def reinit(self, ori_models=None, input_size=None, sample_batch=None, seed=None, weights=None): | ||||
|         models = [] | ||||
|         for network in ori_models: | ||||
|             network = network.cuda() | ||||
|             net = copy.deepcopy(network) | ||||
|             net.proj_weights = weights | ||||
|             num_edge, num_op = net.num_edge, net.num_op | ||||
|             for i in range(num_edge): | ||||
|                 net.candidate_flags[i] = False | ||||
|                 net.eval() | ||||
|             models.append(net) | ||||
|  | ||||
|         if models is not None: | ||||
|             assert isinstance(models, list) | ||||
|             del self.models | ||||
|             self.models = models | ||||
|             for model in self.models: | ||||
|                 self.register_hook(model) | ||||
|                 device = torch.cuda.current_device() | ||||
|                 model = model.cuda(device=device) | ||||
|             self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(models))] | ||||
|         if input_size is not None or sample_batch is not None: | ||||
|             if input_size is not None: | ||||
|                 self.input_size = input_size  # BCHW | ||||
|                 self.input_numel = reduce(mul, self.input_size, 1) | ||||
|             if sample_batch is not None: | ||||
|                 self.sample_batch = sample_batch | ||||
|             if self.data_path is not None: | ||||
|                 self.train_data, _, class_num = get_datasets(self.dataset, self.data_path, self.input_size, -1) | ||||
|                 self.train_loader = torch.utils.data.DataLoader(self.train_data, batch_size=self.input_size[0], num_workers=16, pin_memory=True, drop_last=True, shuffle=True) | ||||
|                 self.loader = iter(self.train_loader) | ||||
|         if seed is not None and seed != self.seed: | ||||
|             self.seed = seed | ||||
|             torch.manual_seed(seed) | ||||
|             torch.cuda.manual_seed(seed) | ||||
|         del self.interFeature | ||||
|         self.interFeature = [] | ||||
|         torch.cuda.empty_cache() | ||||
|  | ||||
|     def clear(self): | ||||
|         self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch) for _ in range(len(self.models))] | ||||
|         del self.interFeature | ||||
|         self.interFeature = [] | ||||
|         torch.cuda.empty_cache() | ||||
|  | ||||
|     def register_hook(self, model): | ||||
|         for m in model.modules(): | ||||
|             if isinstance(m, nn.ReLU): | ||||
|                 m.register_forward_hook(hook=self.hook_in_forward) | ||||
|  | ||||
|     def hook_in_forward(self, module, input, output): | ||||
|         if isinstance(input, tuple) and len(input[0].size()) == 4: | ||||
|             self.interFeature.append(output.detach())  # for ReLU | ||||
|  | ||||
|     def forward_batch_sample(self): | ||||
|         for _ in range(self.sample_batch): | ||||
|             try: | ||||
|                 inputs, targets = next(self.loader)  # Python 3 iterators have no .next() | ||||
|             except Exception: | ||||
|                 del self.loader | ||||
|                 self.loader = iter(self.train_loader) | ||||
|                 inputs, targets = next(self.loader) | ||||
|             for model, LRCount in zip(self.models, self.LRCounts): | ||||
|                 self.forward(model, LRCount, inputs) | ||||
|         output = [LRCount.getLinearReginCount() for LRCount in self.LRCounts] | ||||
|         return output | ||||
|  | ||||
|     def forward(self, model, LRCount, input_data): | ||||
|         self.interFeature = [] | ||||
|         with torch.no_grad(): | ||||
|             model.forward(input_data.cuda()) | ||||
|             if len(self.interFeature) == 0: return | ||||
|             feature_data = torch.cat([f.view(input_data.size(0), -1) for f in self.interFeature], 1) | ||||
|             LRCount.update2D(feature_data) | ||||
245  nasbench201/networks_proposal.py  Normal file
							| @@ -0,0 +1,245 @@ | ||||
| import os | ||||
| import sys | ||||
| sys.path.insert(0, '../') | ||||
| import time | ||||
| import glob | ||||
| import json | ||||
| import shutil | ||||
| import logging | ||||
| import argparse | ||||
| import numpy as np | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.utils | ||||
| import torchvision.datasets as dset | ||||
| import torch.backends.cudnn as cudnn | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| from torch.autograd import Variable | ||||
|  | ||||
| import nasbench201.utils as ig_utils | ||||
| from nasbench201.search_model_darts_proj import TinyNetworkDartsProj | ||||
| from nasbench201.cell_operations import SearchSpaceNames | ||||
| from nasbench201.init_projection import pt_project, global_op_greedy_pt_project, global_op_once_pt_project, global_edge_greedy_pt_project, global_edge_once_pt_project, shrink_pt_project, tenas_project | ||||
| from nas_201_api import NASBench201API as API | ||||
|  | ||||
| torch.set_printoptions(precision=4, sci_mode=False) | ||||
| np.set_printoptions(precision=4, suppress=True) | ||||
|  | ||||
|  | ||||
| parser = argparse.ArgumentParser("sota") | ||||
| # data related  | ||||
| parser.add_argument('--data', type=str, default='../data', help='location of the data corpus') | ||||
| parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'imagenet16-120'], help='choose dataset') | ||||
| parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data') | ||||
| parser.add_argument('--batch_size', type=int, default=64, help='batch size for alpha') | ||||
| parser.add_argument('--cutout', action='store_true', default=True, help='use cutout') | ||||
| parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') | ||||
| parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability') | ||||
| parser.add_argument('--seed', type=int, default=2, help='random seed') | ||||
|  | ||||
| #search space setting | ||||
| parser.add_argument('--search_space', type=str, default='nas-bench-201') | ||||
|  | ||||
| parser.add_argument('--pool_size', type=int, default=100, help='number of model to proposed') | ||||
| parser.add_argument('--init_channels', type=int, default=16, help='num of init channels') | ||||
| parser.add_argument('--layers', type=int, default=8, help='total number of layers') | ||||
|  | ||||
| #system configurations | ||||
| parser.add_argument('--gpu', type=str, default='auto', help='gpu device id') | ||||
| parser.add_argument('--save', type=str, default='exp', help='experiment name') | ||||
|  | ||||
| #default opt setting for model | ||||
| parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate') | ||||
| parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate') | ||||
| parser.add_argument('--momentum', type=float, default=0.9, help='momentum') | ||||
| parser.add_argument('--nesterov', action='store_true', default=True, help='using nestrov momentum for SGD') | ||||
| parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay') | ||||
| parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping') | ||||
|  | ||||
| #### common | ||||
| parser.add_argument('--fast', action='store_true', default=True, help='skip loading api which is slow') | ||||
|  | ||||
| #### projection | ||||
| parser.add_argument('--edge_decision', type=str, default='random', choices=['random','reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once', 'shrink_pt_project'], help='which edge to be projected next') | ||||
| parser.add_argument('--proj_crit', type=str, default="comb", choices=['loss', 'acc', 'jacob', 'snip', 'fisher', 'synflow', 'grad_norm', 'grasp', 'jacob_cov','tenas', 'var', 'cor', 'norm', 'comb', 'meco'], help='criteria for projection') | ||||
| args = parser.parse_args() | ||||
|  | ||||
| #### args augment | ||||
| expid = args.save | ||||
| args.save = '../experiments/nas-bench-201/prop-{}-{}-{}'.format(args.save, args.seed, args.pool_size) | ||||
| if not args.dataset == 'cifar10': | ||||
|     args.save += '-' + args.dataset | ||||
| if not args.edge_decision == 'random': | ||||
|     args.save += '-' + args.edge_decision | ||||
| if not args.proj_crit == 'jacob': | ||||
|     args.save += '-' + args.proj_crit | ||||
|  | ||||
| #### logging | ||||
| scripts_to_save = glob.glob('*.py') \ | ||||
|                   # + ['../exp_scripts/{}.sh'.format(expid)] | ||||
| if os.path.exists(args.save): | ||||
|     if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y': | ||||
|         print('proceed to override saving directory') | ||||
|         shutil.rmtree(args.save) | ||||
|     else: | ||||
|         exit(0) | ||||
| ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save) | ||||
|  | ||||
| log_format = '%(asctime)s %(message)s' | ||||
| logging.basicConfig(stream=sys.stdout, level=logging.INFO, | ||||
|     format=log_format, datefmt='%m/%d %I:%M:%S %p') | ||||
|  | ||||
| log_file = 'log.txt' | ||||
| log_path = os.path.join(args.save, log_file) | ||||
| logging.info('======> log filename: %s', log_file) | ||||
|  | ||||
| if os.path.exists(log_path): | ||||
|     if input("WARNING: {} exists, override?[y/n]".format(log_file)) == 'y': | ||||
|         print('proceed to override log file directory') | ||||
|     else: | ||||
|         exit(0) | ||||
|  | ||||
| fh = logging.FileHandler(log_path, mode='w') | ||||
| fh.setFormatter(logging.Formatter(log_format)) | ||||
| logging.getLogger().addHandler(fh) | ||||
| writer = SummaryWriter(args.save + '/runs') | ||||
|  | ||||
| #### macros | ||||
| if args.dataset == 'cifar100': | ||||
|     n_classes = 100 | ||||
| elif args.dataset == 'imagenet16-120': | ||||
|     n_classes = 120 | ||||
| else: | ||||
|     n_classes = 10 | ||||
|  | ||||
| def main(): | ||||
|     torch.set_num_threads(3) | ||||
|     if not torch.cuda.is_available(): | ||||
|         logging.info('no gpu device available') | ||||
|         sys.exit(1) | ||||
|  | ||||
|     np.random.seed(args.seed) | ||||
|     gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu) | ||||
|     torch.cuda.set_device(gpu) | ||||
|     cudnn.benchmark = True | ||||
|     torch.manual_seed(args.seed) | ||||
|     cudnn.enabled = True | ||||
|     torch.cuda.manual_seed(args.seed) | ||||
|     logging.info("args = %s", args) | ||||
|     logging.info('gpu device = %d' % gpu) | ||||
|  | ||||
|     #### model | ||||
|     criterion = nn.CrossEntropyLoss() | ||||
|     search_space = SearchSpaceNames[args.search_space] | ||||
|  | ||||
|     # initialize the supernet; model_thin is a copy with a 1-channel stem (used by the TE-NAS linear-region collector) | ||||
|     model = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args) | ||||
|     model_thin = TinyNetworkDartsProj(C=args.init_channels, N=5, max_nodes=4, num_classes=n_classes, criterion=criterion, search_space=search_space, args=args, stem_channels=1) | ||||
|     model = model.cuda() | ||||
|     model_thin = model_thin.cuda() | ||||
|     logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model)) | ||||
|  | ||||
|     #### data | ||||
|     if args.dataset == 'cifar10': | ||||
|         train_transform, valid_transform = ig_utils._data_transforms_cifar10(args) | ||||
|         train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform) | ||||
|         valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform) | ||||
|     elif args.dataset == 'cifar100': | ||||
|         train_transform, valid_transform = ig_utils._data_transforms_cifar100(args) | ||||
|         train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform) | ||||
|         valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform) | ||||
|     elif args.dataset == 'imagenet16-120': | ||||
|         import torchvision.transforms as transforms | ||||
|         from nasbench201.DownsampledImageNet import ImageNet16 | ||||
|         mean = [x / 255 for x in [122.68, 116.66, 104.01]] | ||||
|         std = [x / 255 for x in [63.22,  61.26, 65.09]] | ||||
|         lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         train_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120) | ||||
|         valid_data = ImageNet16(root=os.path.join(args.data,'imagenet16'), train=False, transform=train_transform, use_num_of_class_only=120) | ||||
|         assert len(train_data) == 151700 | ||||
|  | ||||
|     num_train = len(train_data) | ||||
|     indices = list(range(num_train)) | ||||
|     split = int(np.floor(args.train_portion * num_train)) | ||||
|  | ||||
|     train_queue = torch.utils.data.DataLoader( | ||||
|         train_data, batch_size=args.batch_size, | ||||
|         sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]), | ||||
|         pin_memory=True) | ||||
|  | ||||
|  | ||||
|     #format network pool diction | ||||
|     networks_pool={} | ||||
|     networks_pool['search_space'] = args.search_space | ||||
|     networks_pool['dataset'] = args.dataset | ||||
|     networks_pool['networks'] = [] | ||||
|     networks_pool['pool_size'] = args.pool_size  | ||||
|     #### architecture selection / projection | ||||
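|     # each iteration projects the supernet down to one discrete architecture with the | ||||
|     # chosen edge-decision rule and proxy, records its genotype in the pool, and then | ||||
|     # resets the architecture parameters for the next proposal | ||||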
|     for i in range(args.pool_size): | ||||
|         network_info={} | ||||
|         logging.info('{} MODEL HAS SEARCHED'.format(i+1)) | ||||
|         if args.edge_decision == 'global_op_greedy': | ||||
|             global_op_greedy_pt_project(train_queue, model, args) | ||||
|         elif args.edge_decision == 'global_op_once':  | ||||
|             global_op_once_pt_project(train_queue, model, args) | ||||
|         elif args.edge_decision == 'global_edge_greedy': | ||||
|             global_edge_greedy_pt_project(train_queue, model, args) | ||||
|         elif args.edge_decision == 'global_edge_once': | ||||
|             global_edge_once_pt_project(train_queue, model, args) | ||||
|         elif args.edge_decision == 'shrink_pt_project': | ||||
|             shrink_pt_project(train_queue, model, args) | ||||
|             api = API('../data/NAS-Bench-201-v1_0-e61699.pth') | ||||
|             cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \ | ||||
|                 cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = query(api, model.genotype().tostr(), logging) | ||||
|         else: | ||||
|             if args.proj_crit == 'jacob': | ||||
|                 pt_project(train_queue, model, args) | ||||
|             else: | ||||
|                 pt_project(train_queue, model, args) | ||||
|                 # tenas_project(train_queue, model, model_thin, args) | ||||
|  | ||||
|         network_info['id'] = str(i) | ||||
|         network_info['genotype'] = model.genotype().tostr() | ||||
|         networks_pool['networks'].append(network_info) | ||||
|         model.reset_arch_parameters() | ||||
|      | ||||
|     with open(os.path.join(args.save,'networks_pool.json'), 'w') as save_file: | ||||
|         json.dump(networks_pool, save_file) | ||||
|  | ||||
|  | ||||
| #### util functions | ||||
| def distill(result): | ||||
|     result = result.split('\n') | ||||
|     cifar10 = result[5].replace(' ', '').split(':') | ||||
|     cifar100 = result[7].replace(' ', '').split(':') | ||||
|     imagenet16 = result[9].replace(' ', '').split(':') | ||||
|  | ||||
|     cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('=')) | ||||
|     cifar10_test = float(cifar10[2][-7:-2].strip('=')) | ||||
|     cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('=')) | ||||
|     cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('=')) | ||||
|     cifar100_test = float(cifar100[3][-7:-2].strip('=')) | ||||
|     imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('=')) | ||||
|     imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('=')) | ||||
|     imagenet16_test = float(imagenet16[3][-7:-2].strip('=')) | ||||
|  | ||||
|     return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \ | ||||
|         cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test | ||||
|  | ||||
|  | ||||
| def query(api, genotype, logging): | ||||
|     result = api.query_by_arch(genotype, hp='200') | ||||
|     logging.info('{:}'.format(result)) | ||||
|     cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \ | ||||
|         cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result) | ||||
|     logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test) | ||||
|     logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test) | ||||
|     logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test) | ||||
|     return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \ | ||||
|            cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
113  nasbench201/op_score.py  Normal file
							| @@ -0,0 +1,113 @@ | ||||
| import gc | ||||
| import numpy as np | ||||
| import os | ||||
| import sys | ||||
| import torch | ||||
| import torch.nn.functional as f | ||||
| from operator import mul | ||||
| from functools import reduce | ||||
| import copy | ||||
| sys.path.insert(0, '../') | ||||
|  | ||||
| def Jocab_Score(ori_model, input, target, weights=None): | ||||
|     model = copy.deepcopy(ori_model) | ||||
|     model.eval() | ||||
|     model.proj_weights = weights | ||||
|     num_edge, num_op = model.num_edge, model.num_op | ||||
|     for i in range(num_edge): | ||||
|         model.candidate_flags[i] = False | ||||
|     batch_size = input.shape[0] | ||||
|     model.K = torch.zeros(batch_size, batch_size).cuda() | ||||
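|     # NASWOT-style score: K accumulates, over all ReLU layers, the number of | ||||
|     # positions at which two inputs share the same activation sign; the score is | ||||
|     # the log-determinant of this kernel (hooklogdet below) | ||||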
|  | ||||
|     def counting_forward_hook(module, inp, out): | ||||
|         try: | ||||
|             if isinstance(inp, tuple): | ||||
|                 inp = inp[0] | ||||
|             inp = inp.view(inp.size(0), -1) | ||||
|             x = (inp > 0).float() | ||||
|             K = x @ x.t() | ||||
|             K2 = (1.-x) @ (1.-x.t()) | ||||
|             model.K = model.K + K + K2 | ||||
|         except Exception: | ||||
|             pass | ||||
|  | ||||
|     for name, module in model.named_modules(): | ||||
|         if 'ReLU' in str(type(module)): | ||||
|             module.register_forward_hook(counting_forward_hook) | ||||
|      | ||||
|     input = input.cuda() | ||||
|     model(input) | ||||
|     score = hooklogdet(model.K.cpu().numpy()) | ||||
|     del model | ||||
|     del input | ||||
|     return score | ||||
|  | ||||
| def hooklogdet(K, labels=None): | ||||
|     s, ld = np.linalg.slogdet(K) | ||||
|     return ld | ||||
|  | ||||
| # NTK | ||||
| #------------------------------------------------------------ | ||||
| #https://github.com/VITA-Group/TENAS/blob/main/lib/procedures/ntk.py | ||||
| # | ||||
| def recal_bn(network, xloader, recalbn, device): | ||||
|     for m in network.modules(): | ||||
|         if isinstance(m, torch.nn.BatchNorm2d): | ||||
|             m.running_mean.data.fill_(0) | ||||
|             m.running_var.data.fill_(0) | ||||
|             m.num_batches_tracked.data.zero_() | ||||
|             m.momentum = None | ||||
|     network.train() | ||||
|     with torch.no_grad(): | ||||
|         for i, (inputs, targets) in enumerate(xloader): | ||||
|             if i >= recalbn: break | ||||
|             inputs = inputs.cuda(device=device, non_blocking=True) | ||||
|             _, _ = network(inputs) | ||||
|     return network | ||||
|  | ||||
| def get_ntk_n(xloader, networks, recalbn=0, train_mode=False, num_batch=-1, weights=None): | ||||
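|     # empirical NTK condition number (TE-NAS): stack per-sample gradients of the | ||||
|     # logits w.r.t. the weights into J, form Theta = J J^T, and return | ||||
|     # lambda_max / lambda_min for each network (NaNs mapped to a large constant) | ||||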
|     device = torch.cuda.current_device() | ||||
|     ntks = [] | ||||
|     copied_networks = [] | ||||
|     for network in networks: | ||||
|         network = network.cuda(device=device) | ||||
|         net = copy.deepcopy(network) | ||||
|         net.proj_weights = weights | ||||
|         num_edge, num_op = net.num_edge, net.num_op | ||||
|         for i in range(num_edge): | ||||
|             net.candidate_flags[i] = False | ||||
|         if train_mode: | ||||
|             net.train() | ||||
|         else: | ||||
|             net.eval() | ||||
|         copied_networks.append(net) | ||||
|     ###### | ||||
|     grads = [[] for _ in range(len(copied_networks))] | ||||
|     for i, (inputs, targets) in enumerate(xloader): | ||||
|         if num_batch > 0 and i >= num_batch: break | ||||
|         inputs = inputs.cuda(device=device, non_blocking=True) | ||||
|         for net_idx, network in enumerate(copied_networks): | ||||
|             network.zero_grad() | ||||
|             inputs_ = inputs.clone().cuda(device=device, non_blocking=True) | ||||
|             logit = network(inputs_) | ||||
|             if isinstance(logit, tuple): | ||||
|                 logit = logit[1]  # 201 networks: return features and logits | ||||
|             for _idx in range(len(inputs_)): | ||||
|                 logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True) | ||||
|                 grad = [] | ||||
|                 for name, W in network.named_parameters(): | ||||
|                     if 'weight' in name and W.grad is not None: | ||||
|                         grad.append(W.grad.view(-1).detach()) | ||||
|                 grads[net_idx].append(torch.cat(grad, -1)) | ||||
|                 network.zero_grad() | ||||
|                 torch.cuda.empty_cache() | ||||
|     ###### | ||||
|     grads = [torch.stack(_grads, 0) for _grads in grads] | ||||
|     ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads] | ||||
|     conds = [] | ||||
|     for ntk in ntks: | ||||
|         eigenvalues, _ = torch.symeig(ntk)  # ascending | ||||
|         conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True, nan=100000.0)) | ||||
|      | ||||
|     del copied_networks | ||||
|     return conds | ||||
182  nasbench201/search_cells.py  Normal file
							| @@ -0,0 +1,182 @@ | ||||
| import math, random, torch | ||||
| import warnings | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from copy import deepcopy | ||||
| import sys | ||||
| sys.path.insert(0, '../') | ||||
| from nasbench201.cell_operations import OPS | ||||
|  | ||||
|  | ||||
| # This module is used for NAS-Bench-201; it represents a small search space with a complete DAG | ||||
| class NAS201SearchCell(nn.Module): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True): | ||||
|     super(NAS201SearchCell, self).__init__() | ||||
|  | ||||
|     self.op_names  = deepcopy(op_names) | ||||
|     self.edges     = nn.ModuleDict() | ||||
|     self.max_nodes = max_nodes | ||||
|     self.in_dim    = C_in | ||||
|     self.out_dim   = C_out | ||||
|     for i in range(1, max_nodes): | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         if j == 0: | ||||
|           xlists = [OPS[op_name](C_in , C_out, stride, affine, track_running_stats) for op_name in op_names] | ||||
|         else: | ||||
|           xlists = [OPS[op_name](C_in , C_out,      1, affine, track_running_stats) for op_name in op_names] | ||||
|         self.edges[ node_str ] = nn.ModuleList( xlists ) | ||||
|     self.edge_keys  = sorted(list(self.edges.keys())) | ||||
|     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} | ||||
|     self.num_edges  = len(self.edges) | ||||
|      | ||||
|   def extra_repr(self): | ||||
|     string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) | ||||
|     return string | ||||
|  | ||||
|   def forward(self, inputs, weightss): | ||||
|     return self._forward(inputs, weightss) | ||||
|  | ||||
|   def _forward(self, inputs, weightss): | ||||
|     with torch.autograd.set_detect_anomaly(True): | ||||
|       nodes = [inputs] | ||||
|       for i in range(1, self.max_nodes): | ||||
|         inter_nodes = [] | ||||
|         for j in range(i): | ||||
|           node_str = '{:}<-{:}'.format(i, j) | ||||
|           weights  = weightss[ self.edge2index[node_str] ] | ||||
|           inter_nodes.append(sum(layer(nodes[j], block_input=True)*w if w==0 else layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights)) ) | ||||
|         nodes.append( sum(inter_nodes) ) | ||||
|       return nodes[-1] | ||||
|  | ||||
|   # GDAS | ||||
|   def forward_gdas(self, inputs, hardwts, index): | ||||
|     nodes   = [inputs] | ||||
|     for i in range(1, self.max_nodes): | ||||
|       inter_nodes = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         weights  = hardwts[ self.edge2index[node_str] ] | ||||
|         argmaxs  = index[ self.edge2index[node_str] ].item() | ||||
|         weigsum  = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) | ||||
|         inter_nodes.append( weigsum ) | ||||
|       nodes.append( sum(inter_nodes) ) | ||||
|     return nodes[-1] | ||||
|  | ||||
|   # joint | ||||
|   def forward_joint(self, inputs, weightss): | ||||
|     nodes = [inputs] | ||||
|     for i in range(1, self.max_nodes): | ||||
|       inter_nodes = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         weights  = weightss[ self.edge2index[node_str] ] | ||||
|         #aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() | ||||
|         aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) | ||||
|         inter_nodes.append( aggregation ) | ||||
|       nodes.append( sum(inter_nodes) ) | ||||
|     return nodes[-1] | ||||
|  | ||||
|   # uniform random sampling per iteration, SETN | ||||
|   def forward_urs(self, inputs): | ||||
|     nodes = [inputs] | ||||
|     for i in range(1, self.max_nodes): | ||||
|       while True: # to avoid select zero for all ops | ||||
|         sops, has_non_zero = [], False | ||||
|         for j in range(i): | ||||
|           node_str   = '{:}<-{:}'.format(i, j) | ||||
|           candidates = self.edges[node_str] | ||||
|           select_op  = random.choice(candidates) | ||||
|           sops.append( select_op ) | ||||
|           if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True | ||||
|         if has_non_zero: break | ||||
|       inter_nodes = [] | ||||
|       for j, select_op in enumerate(sops): | ||||
|         inter_nodes.append( select_op(nodes[j]) ) | ||||
|       nodes.append( sum(inter_nodes) ) | ||||
|     return nodes[-1] | ||||
|  | ||||
|   # select the argmax | ||||
|   def forward_select(self, inputs, weightss): | ||||
|     nodes = [inputs] | ||||
|     for i in range(1, self.max_nodes): | ||||
|       inter_nodes = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         weights  = weightss[ self.edge2index[node_str] ] | ||||
|         inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) ) | ||||
|         #inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) | ||||
|       nodes.append( sum(inter_nodes) ) | ||||
|     return nodes[-1] | ||||
|  | ||||
|   # forward with a specific structure | ||||
|   def forward_dynamic(self, inputs, structure): | ||||
|     nodes = [inputs] | ||||
|     for i in range(1, self.max_nodes): | ||||
|       cur_op_node = structure.nodes[i-1] | ||||
|       inter_nodes = [] | ||||
|       for op_name, j in cur_op_node: | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         op_index = self.op_names.index( op_name ) | ||||
|         inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) ) | ||||
|       nodes.append( sum(inter_nodes) ) | ||||
|     return nodes[-1] | ||||
|  | ||||
|  | ||||
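| # channel shuffle (as in ShuffleNet / PC-DARTS partial-channel connections): mix the | ||||
| # k channel groups so that processed and bypassed channels are interleaved | ||||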
| def channel_shuffle(x, groups): | ||||
|   batchsize, num_channels, height, width = x.data.size() | ||||
|   channels_per_group = num_channels // groups | ||||
|   # reshape | ||||
|   x = x.view(batchsize, groups,  | ||||
|     channels_per_group, height, width) | ||||
|   x = torch.transpose(x, 1, 2).contiguous() | ||||
|   # flatten | ||||
|   x = x.view(batchsize, -1, height, width) | ||||
|   return x | ||||
|  | ||||
|  | ||||
| class NAS201SearchCell_PartialChannel(NAS201SearchCell): | ||||
|  | ||||
|   def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True, k=4): | ||||
|     super(NAS201SearchCell, self).__init__() | ||||
|  | ||||
|     self.k = k | ||||
|     self.op_names  = deepcopy(op_names) | ||||
|     self.edges     = nn.ModuleDict() | ||||
|     self.max_nodes = max_nodes | ||||
|     self.in_dim    = C_in | ||||
|     self.out_dim   = C_out | ||||
|     for i in range(1, max_nodes): | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         if j == 0: | ||||
|           xlists = [OPS[op_name](C_in//self.k , C_out//self.k, stride, affine, track_running_stats) for op_name in op_names] | ||||
|         else: | ||||
|           xlists = [OPS[op_name](C_in//self.k , C_out//self.k,      1, affine, track_running_stats) for op_name in op_names] | ||||
|         self.edges[ node_str ] = nn.ModuleList( xlists ) | ||||
|     self.edge_keys  = sorted(list(self.edges.keys())) | ||||
|     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} | ||||
|     self.num_edges  = len(self.edges) | ||||
|    | ||||
|   def MixedOp(self, x, ops, weights): | ||||
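|     # partial-channel mixed op: only the first 1/k of the channels pass through the | ||||
|     # weighted candidate ops; the rest bypass them and are concatenated back, then | ||||
|     # shuffled across groups | ||||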
|     dim_2 = x.shape[1] | ||||
|     xtemp = x[ : , :  dim_2//self.k, :, :] | ||||
|     xtemp2 = x[ : ,  dim_2//self.k:, :, :] | ||||
|     temp1 = sum(w * op(xtemp) for w, op in zip(weights, ops)) | ||||
|     ans = torch.cat([temp1,xtemp2],dim=1) | ||||
|     ans = channel_shuffle(ans,self.k) | ||||
|     return ans | ||||
|    | ||||
|   def forward(self, inputs, weightss): | ||||
|     nodes = [inputs] | ||||
|     for i in range(1, self.max_nodes): | ||||
|       inter_nodes = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         weights  = weightss[ self.edge2index[node_str] ] | ||||
|         # inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) | ||||
|         inter_nodes.append(self.MixedOp(x=nodes[j], ops=self.edges[node_str], weights=weights)) | ||||
|       nodes.append( sum(inter_nodes) ) | ||||
|     return nodes[-1] | ||||
|  | ||||
							
								
								
									
202  nasbench201/search_model.py  Normal file
							| @@ -0,0 +1,202 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from copy import deepcopy | ||||
| from .cell_operations import ResNetBasicblock | ||||
| from .search_cells     import NAS201SearchCell as SearchCell | ||||
| from .genotypes        import Structure | ||||
| from torch.autograd import Variable | ||||
|  | ||||
| class TinyNetwork(nn.Module): | ||||
|  | ||||
|   def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args, affine=False, track_running_stats=True, stem_channels=3): | ||||
|     super(TinyNetwork, self).__init__() | ||||
|     self._C        = C | ||||
|     self._layerN   = N | ||||
|     self.max_nodes = max_nodes | ||||
|     self._num_classes = num_classes | ||||
|     self._criterion = criterion | ||||
|     self._args = args | ||||
|     self._affine = affine | ||||
|     self._track_running_stats = track_running_stats | ||||
|     self.stem = nn.Sequential( | ||||
|                     nn.Conv2d(stem_channels, C, kernel_size=3, padding=1, bias=False), | ||||
|                     nn.BatchNorm2d(C)) | ||||
|  | ||||
|     layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N     | ||||
|     layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N | ||||
|  | ||||
|     C_prev, num_edge, edge2index = C, None, None | ||||
|     self.cells = nn.ModuleList() | ||||
|     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): | ||||
|       if reduction: | ||||
|         cell = ResNetBasicblock(C_prev, C_curr, 2) | ||||
|       else: | ||||
|         cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats) | ||||
|         if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index | ||||
|         else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) | ||||
|       self.cells.append( cell ) | ||||
|       C_prev = cell.out_dim | ||||
|     self.num_edge   = num_edge | ||||
|     self.num_op     = len(search_space) | ||||
|     self.op_names   = deepcopy( search_space ) | ||||
|     self._Layer     = len(self.cells) | ||||
|     self.edge2index = edge2index | ||||
|     self.lastact    = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) | ||||
|     self.global_pooling = nn.AdaptiveAvgPool2d(1) | ||||
|     self.classifier = nn.Linear(C_prev, num_classes) | ||||
|     # self._arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) | ||||
|     self._arch_parameters = Variable(1e-3*torch.randn(num_edge, len(search_space)).cuda(), requires_grad=True) | ||||
|  | ||||
|     ## optimizer | ||||
|     ## use the in-memory ids of the architecture parameters to separate them from the model weights | ||||
|     arch_params = set(id(m) for m in self.arch_parameters()) | ||||
|     self._model_params = [m for m in self.parameters() if id(m) not in arch_params] | ||||
|  | ||||
|     # optimizer for the model (weight) parameters | ||||
|     self.optimizer = torch.optim.SGD( | ||||
|         self._model_params, | ||||
|         args.learning_rate, | ||||
|         momentum=args.momentum, | ||||
|         weight_decay=args.weight_decay, | ||||
|         nesterov=args.nesterov) | ||||
|  | ||||
|  | ||||
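|   # Mean (over the batch) entropy of the softmax distribution induced by p_logit. | ||||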
|   def entropy_y_x(self, p_logit): | ||||
|     p = F.softmax(p_logit, dim=1) | ||||
|     return - torch.sum(p * F.log_softmax(p_logit, dim=1)) / p_logit.shape[0] | ||||
|  | ||||
|   def _loss(self, input, target, return_logits=False): | ||||
|     logits = self(input) | ||||
|     loss = self._criterion(logits, target) | ||||
|      | ||||
|     return (loss, logits) if return_logits else loss | ||||
|  | ||||
|   def get_weights(self): | ||||
|     xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) | ||||
|     xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) | ||||
|     xlist+= list( self.classifier.parameters() ) | ||||
|     return xlist | ||||
|  | ||||
|   def arch_parameters(self): | ||||
|     return [self._arch_parameters] | ||||
|  | ||||
|   def get_theta(self): | ||||
|     return nn.functional.softmax(self._arch_parameters, dim=-1).cpu() | ||||
|  | ||||
|   def get_message(self): | ||||
|     string = self.extra_repr() | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) | ||||
|     return string | ||||
|  | ||||
|   def extra_repr(self): | ||||
|     return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
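|   # Discretize the architecture: on every edge pick the operation with the largest alpha. | ||||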
|   def genotype(self): | ||||
|     genotypes = [] | ||||
|     for i in range(1, self.max_nodes): | ||||
|       xlist = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         with torch.no_grad(): | ||||
|           weights = self._arch_parameters[ self.edge2index[node_str] ] | ||||
|           op_name = self.op_names[ weights.argmax().item() ] | ||||
|         xlist.append((op_name, j)) | ||||
|       genotypes.append( tuple(xlist) ) | ||||
|     return Structure( genotypes ) | ||||
|  | ||||
|   def forward(self, inputs, weights=None): | ||||
|     sim_nn = [] | ||||
|  | ||||
|     weights = nn.functional.softmax(self._arch_parameters, dim=-1) if weights is None else weights | ||||
|      | ||||
|     if getattr(self, 'slim', False):  # 'slim' is not set in __init__, so only act on it when set elsewhere | ||||
|       weights[1].data.fill_(0) | ||||
|       weights[3].data.fill_(0) | ||||
|       weights[4].data.fill_(0) | ||||
|  | ||||
|     feature = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       if isinstance(cell, SearchCell): | ||||
|         feature = cell(feature, weights) | ||||
|       else: | ||||
|         feature = cell(feature) | ||||
|  | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling( out ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     return logits | ||||
|  | ||||
|   def _save_arch_parameters(self): | ||||
|     self._saved_arch_parameters = [p.clone() for p in self.arch_parameters()] | ||||
|  | ||||
|   def project_arch(self): | ||||
|     self._save_arch_parameters() | ||||
|     for p in self.arch_parameters(): | ||||
|       m, n = p.size() | ||||
|       maxIndexs = p.data.cpu().numpy().argmax(axis=1) | ||||
|       p.data = self.proximal_step(p, maxIndexs) | ||||
|  | ||||
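|   # One-hot projection: for each edge, set the entry chosen by maxIndexs to 1 and all others to 0. | ||||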
|   def proximal_step(self, var, maxIndexs=None): | ||||
|     values = var.data.cpu().numpy() | ||||
|     m, n = values.shape | ||||
|     alphas = [] | ||||
|     for i in range(m): | ||||
|       for j in range(n): | ||||
|         if j == maxIndexs[i]: | ||||
|           alphas.append(values[i][j].copy()) | ||||
|           values[i][j] = 1 | ||||
|         else: | ||||
|           values[i][j] = 0 | ||||
|     return torch.Tensor(values).cuda() | ||||
|  | ||||
|   def restore_arch_parameters(self): | ||||
|     for i, p in enumerate(self.arch_parameters()): | ||||
|       p.data.copy_(self._saved_arch_parameters[i]) | ||||
|     del self._saved_arch_parameters | ||||
|  | ||||
|   def new(self): | ||||
|     model_new = TinyNetwork(self._C, self._layerN, self.max_nodes, self._num_classes, self._criterion, | ||||
|                             self.op_names, self._args, self._affine, self._track_running_stats).cuda() | ||||
|     for x, y in zip(model_new.arch_parameters(), self.arch_parameters()): | ||||
|       x.data.copy_(y.data) | ||||
|  | ||||
|     return model_new | ||||
|  | ||||
|   def step(self, input, target, args, shared=None, return_grad=False): | ||||
|     Lt, logit_t = self._loss(input, target, return_logits=True) | ||||
|     Lt.backward() | ||||
|     if args.grad_clip != 0:  | ||||
|       nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip) | ||||
|     self.optimizer.step() | ||||
|  | ||||
|     if return_grad: | ||||
|       grad = torch.nn.utils.parameters_to_vector([p.grad for p in self.get_weights()]) | ||||
|       return logit_t, Lt, grad | ||||
|     else: | ||||
|       return logit_t, Lt | ||||
|  | ||||
|   def printing(self, logging): | ||||
|     logging.info(self.get_theta()) | ||||
|    | ||||
|   def set_arch_parameters(self, new_alphas): | ||||
|     for alpha, new_alpha in zip(self.arch_parameters(), new_alphas): | ||||
|         alpha.data.copy_(new_alpha.data) | ||||
|  | ||||
|   def save_arch_parameters(self): | ||||
|     # keep the same list-of-tensors format as _save_arch_parameters so restoring via set_arch_parameters works | ||||
|     self._saved_arch_parameters = [p.clone() for p in self.arch_parameters()] | ||||
|    | ||||
|   # NOTE: this definition shadows the restore_arch_parameters defined above | ||||
|   def restore_arch_parameters(self): | ||||
|     self.set_arch_parameters(self._saved_arch_parameters) | ||||
|      | ||||
|   def reset_optimizer(self, lr, momentum, weight_decay): | ||||
|     del self.optimizer | ||||
|     self.optimizer = torch.optim.SGD( | ||||
|       self.get_weights(), | ||||
|       lr, | ||||
|       momentum=momentum, | ||||
|       weight_decay=weight_decay, | ||||
|       nesterov=self._args.nesterov)  # 'args' is not in scope here; use the args stored in __init__ | ||||
							
								
								
									
33  nasbench201/search_model_darts.py  Normal file
							| @@ -0,0 +1,33 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .search_model import TinyNetwork as TinyNetwork | ||||
|  | ||||
|  | ||||
| class TinyNetworkDarts(TinyNetwork): | ||||
|   def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args, | ||||
|                affine=False, track_running_stats=True, stem_channels=3): | ||||
|     super(TinyNetworkDarts, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args, | ||||
|           affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels) | ||||
|  | ||||
|     self.theta_map = lambda x: torch.softmax(x, dim=-1) | ||||
|    | ||||
|   def get_theta(self): | ||||
|     return self.theta_map(self._arch_parameters).cpu() | ||||
|  | ||||
|   def forward(self, inputs): | ||||
|     weights = self.theta_map(self._arch_parameters) | ||||
|     feature = self.stem(inputs) | ||||
|  | ||||
|     for i, cell in enumerate(self.cells): | ||||
|       if isinstance(cell, SearchCell): | ||||
|         feature = cell(feature, weights) | ||||
|       else: | ||||
|         feature = cell(feature) | ||||
|  | ||||
|     out = self.lastact(feature) | ||||
|     out = self.global_pooling( out ) | ||||
|     out = out.view(out.size(0), -1) | ||||
|     logits = self.classifier(out) | ||||
|  | ||||
|     return logits | ||||
							
								
								
									
80  nasbench201/search_model_darts_proj.py  Normal file
							| @@ -0,0 +1,80 @@ | ||||
| import torch | ||||
| from .search_cells import NAS201SearchCell as SearchCell | ||||
| from .search_model import TinyNetwork as TinyNetwork | ||||
| from .genotypes        import Structure | ||||
| from torch.autograd import Variable | ||||
|  | ||||
| class TinyNetworkDartsProj(TinyNetwork): | ||||
|   def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, args, | ||||
|                affine=False, track_running_stats=True, stem_channels=3): | ||||
|     super(TinyNetworkDartsProj, self).__init__(C, N, max_nodes, num_classes, criterion, search_space, args, | ||||
|           affine=affine, track_running_stats=track_running_stats, stem_channels=stem_channels) | ||||
|     self.theta_map = lambda x: torch.softmax(x, dim=-1) | ||||
|  | ||||
|     #### for edgewise projection | ||||
|     self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda() | ||||
|     self.proj_weights = torch.zeros_like(self._arch_parameters) | ||||
|  | ||||
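|   ## project edge eid onto the single operation opid (hard one-hot) and mark the edge as decided | ||||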
|   def project_op(self, eid, opid): | ||||
|       self.proj_weights[eid][opid] = 1 ## hard by default | ||||
|       self.candidate_flags[eid] = False | ||||
|  | ||||
|   def get_projected_weights(self): | ||||
|       weights = self.theta_map(self._arch_parameters) | ||||
|  | ||||
|       ## proj op | ||||
|       for eid in range(len(self._arch_parameters)): | ||||
|         if not self.candidate_flags[eid]: | ||||
|           weights[eid].data.copy_(self.proj_weights[eid]) | ||||
|  | ||||
|       return weights | ||||
|  | ||||
|   def forward(self, inputs, weights=None): | ||||
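|     # anomaly detection helps debug the backward pass but noticeably slows training | ||||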
|     with torch.autograd.set_detect_anomaly(True): | ||||
|       if weights is None: | ||||
|         weights = self.get_projected_weights() | ||||
|  | ||||
|       feature = self.stem(inputs) | ||||
|       for i, cell in enumerate(self.cells): | ||||
|         if isinstance(cell, SearchCell): | ||||
|           feature = cell(feature, weights) | ||||
|         else: | ||||
|           feature = cell(feature) | ||||
|  | ||||
|       out = self.lastact(feature) | ||||
|       out = self.global_pooling( out ) | ||||
|       out = out.view(out.size(0), -1) | ||||
|       logits = self.classifier(out) | ||||
|  | ||||
|       return logits | ||||
|  | ||||
|   #### utils | ||||
|   def get_theta(self): | ||||
|     return self.get_projected_weights() | ||||
|  | ||||
|   def arch_parameters(self): | ||||
|     return [self._arch_parameters] | ||||
|  | ||||
|   def set_arch_parameters(self, new_alphas): | ||||
|     for eid, alpha in enumerate(self.arch_parameters()): | ||||
|       alpha.data.copy_(new_alphas[eid]) | ||||
|    | ||||
|   def reset_arch_parameters(self): | ||||
|     self._arch_parameters = Variable(1e-3*torch.randn(self.num_edge, len(self.op_names)).cuda(), requires_grad=True) | ||||
|     self.candidate_flags = torch.tensor(len(self._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda() | ||||
|     self.proj_weights = torch.zeros_like(self._arch_parameters) | ||||
|    | ||||
|   def genotype(self): | ||||
|     proj_weights = self.get_projected_weights() | ||||
|  | ||||
|     genotypes = [] | ||||
|     for i in range(1, self.max_nodes): | ||||
|       xlist = [] | ||||
|       for j in range(i): | ||||
|         node_str = '{:}<-{:}'.format(i, j) | ||||
|         with torch.no_grad(): | ||||
|           weights = proj_weights[ self.edge2index[node_str] ] | ||||
|           op_name = self.op_names[ weights.argmax().item() ] | ||||
|         xlist.append((op_name, j)) | ||||
|       genotypes.append( tuple(xlist) ) | ||||
|     return Structure( genotypes ) | ||||
							
								
								
									
494  nasbench201/utils.py  Normal file
							| @@ -0,0 +1,494 @@ | ||||
| from __future__ import print_function | ||||
|  | ||||
| import numpy as np | ||||
| import os | ||||
| import os.path | ||||
| import sys | ||||
| import shutil | ||||
| import torch | ||||
| import torchvision.transforms as transforms | ||||
|  | ||||
| from PIL import Image | ||||
| from torch.autograd import Variable | ||||
| from torchvision.datasets import VisionDataset | ||||
| from torchvision.datasets import utils | ||||
|  | ||||
| if sys.version_info[0] == 2: | ||||
|     import cPickle as pickle | ||||
| else: | ||||
|     import pickle | ||||
|  | ||||
|  | ||||
| class AvgrageMeter(object): | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.reset() | ||||
|  | ||||
|     def reset(self): | ||||
|         self.avg = 0 | ||||
|         self.sum = 0 | ||||
|         self.cnt = 0 | ||||
|  | ||||
|     def update(self, val, n=1): | ||||
|         self.sum += val * n | ||||
|         self.cnt += n | ||||
|         self.avg = self.sum / self.cnt | ||||
|  | ||||
|  | ||||
| def accuracy(output, target, topk=(1,)): | ||||
|     maxk = max(topk) | ||||
|     batch_size = target.size(0) | ||||
|  | ||||
|     _, pred = output.topk(maxk, 1, True, True) | ||||
|     pred = pred.t() | ||||
|     correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||||
|  | ||||
|     res = [] | ||||
|     for k in topk:         | ||||
|         correct_k = correct[:k].contiguous().view(-1).float().sum(0) | ||||
|         res.append(correct_k.mul_(100.0 / batch_size)) | ||||
|     return res | ||||
|  | ||||
|  | ||||
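| # Cutout augmentation: with probability prob, zero out a random length x length patch of the image. | ||||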
| class Cutout(object): | ||||
|     def __init__(self, length, prob=1.0): | ||||
|         self.length = length | ||||
|         self.prob = prob | ||||
|  | ||||
|     def __call__(self, img): | ||||
|         if np.random.binomial(1, self.prob): | ||||
|             h, w = img.size(1), img.size(2) | ||||
|             mask = np.ones((h, w), np.float32) | ||||
|             y = np.random.randint(h) | ||||
|             x = np.random.randint(w) | ||||
|  | ||||
|             y1 = np.clip(y - self.length // 2, 0, h) | ||||
|             y2 = np.clip(y + self.length // 2, 0, h) | ||||
|             x1 = np.clip(x - self.length // 2, 0, w) | ||||
|             x2 = np.clip(x + self.length // 2, 0, w) | ||||
|  | ||||
|             mask[y1: y2, x1: x2] = 0. | ||||
|             mask = torch.from_numpy(mask) | ||||
|             mask = mask.expand_as(img) | ||||
|             img *= mask | ||||
|         return img | ||||
|  | ||||
| def _data_transforms_svhn(args): | ||||
|     SVHN_MEAN = [0.4377, 0.4438, 0.4728] | ||||
|     SVHN_STD = [0.1980, 0.2010, 0.1970] | ||||
|  | ||||
|     train_transform = transforms.Compose([ | ||||
|         transforms.RandomCrop(32, padding=4), | ||||
|         transforms.RandomHorizontalFlip(), | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Normalize(SVHN_MEAN, SVHN_STD), | ||||
|     ]) | ||||
|     if args.cutout: | ||||
|         train_transform.transforms.append(Cutout(args.cutout_length, | ||||
|                                           args.cutout_prob)) | ||||
|  | ||||
|     valid_transform = transforms.Compose([ | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Normalize(SVHN_MEAN, SVHN_STD), | ||||
|         ]) | ||||
|     return train_transform, valid_transform | ||||
|  | ||||
|  | ||||
| def _data_transforms_cifar100(args): | ||||
|     CIFAR_MEAN = [0.5071, 0.4865, 0.4409] | ||||
|     CIFAR_STD = [0.2673, 0.2564, 0.2762] | ||||
|  | ||||
|     train_transform = transforms.Compose([ | ||||
|         transforms.RandomCrop(32, padding=4), | ||||
|         transforms.RandomHorizontalFlip(), | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Normalize(CIFAR_MEAN, CIFAR_STD), | ||||
|     ]) | ||||
|     if args.cutout: | ||||
|         train_transform.transforms.append(Cutout(args.cutout_length, | ||||
|                                           args.cutout_prob)) | ||||
|  | ||||
|     valid_transform = transforms.Compose([ | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Normalize(CIFAR_MEAN, CIFAR_STD), | ||||
|         ]) | ||||
|     return train_transform, valid_transform | ||||
|  | ||||
|  | ||||
| def _data_transforms_cifar10(args): | ||||
|     CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] | ||||
|     CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] | ||||
|  | ||||
|     train_transform = transforms.Compose([ | ||||
|         transforms.RandomCrop(32, padding=4), | ||||
|         transforms.RandomHorizontalFlip(), | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Normalize(CIFAR_MEAN, CIFAR_STD), | ||||
|     ]) | ||||
|     if args.cutout: | ||||
|         train_transform.transforms.append(Cutout(args.cutout_length, | ||||
|                                                  args.cutout_prob)) | ||||
|  | ||||
|     valid_transform = transforms.Compose([ | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Normalize(CIFAR_MEAN, CIFAR_STD), | ||||
|     ]) | ||||
|     return train_transform, valid_transform | ||||
|  | ||||
|  | ||||
| def count_parameters_in_MB(model): | ||||
|     return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6 | ||||
|  | ||||
|  | ||||
| def count_parameters_in_Compact(model): | ||||
|     from sota.cnn.model import Network as CompactModel | ||||
|     genotype = model.genotype() | ||||
|     compact_model = CompactModel(36, model._num_classes, 20, True, genotype) | ||||
|     num_params = count_parameters_in_MB(compact_model) | ||||
|     return num_params | ||||
|  | ||||
|  | ||||
| def save_checkpoint(state, is_best, save, per_epoch=False, prefix=''): | ||||
|     filename = prefix | ||||
|     if per_epoch: | ||||
|         epoch = state['epoch'] | ||||
|         filename += 'checkpoint_{}.pth.tar'.format(epoch) | ||||
|     else: | ||||
|         filename += 'checkpoint.pth.tar' | ||||
|     filename = os.path.join(save, filename) | ||||
|     torch.save(state, filename) | ||||
|     if is_best: | ||||
|         best_filename = os.path.join(save, 'model_best.pth.tar') | ||||
|         shutil.copyfile(filename, best_filename) | ||||
|  | ||||
|  | ||||
| def load_checkpoint(model, optimizer, save, epoch=None): | ||||
|     if epoch is None: | ||||
|         filename = 'checkpoint.pth.tar' | ||||
|     else: | ||||
|         filename = 'checkpoint_{}.pth.tar'.format(epoch) | ||||
|     filename = os.path.join(save, filename) | ||||
|     start_epoch = 0 | ||||
|     best_acc_top1 = 0 | ||||
|     if os.path.isfile(filename): | ||||
|         print("=> loading checkpoint '{}'".format(filename)) | ||||
|         checkpoint = torch.load(filename) | ||||
|         start_epoch = checkpoint['epoch'] | ||||
|         best_acc_top1 = checkpoint['best_acc_top1'] | ||||
|         model.load_state_dict(checkpoint['state_dict']) | ||||
|         optimizer.load_state_dict(checkpoint['optimizer']) | ||||
|         print("=> loaded checkpoint '{}' (epoch {})" | ||||
|               .format(filename, checkpoint['epoch'])) | ||||
|     else: | ||||
|         print("=> no checkpoint found at '{}'".format(filename)) | ||||
|      | ||||
|     return model, optimizer, start_epoch, best_acc_top1 | ||||
|  | ||||
|  | ||||
| def save(model, model_path): | ||||
|     torch.save(model.state_dict(), model_path) | ||||
|  | ||||
|  | ||||
| def load(model, model_path): | ||||
|     model.load_state_dict(torch.load(model_path)) | ||||
|  | ||||
|  | ||||
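| # Drop-path: with probability drop_prob, zero an entire sample in the batch and rescale the survivors. | ||||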
| def drop_path(x, drop_prob): | ||||
|     if drop_prob > 0.: | ||||
|         keep_prob = 1. - drop_prob | ||||
|         mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) | ||||
|         x.div_(keep_prob) | ||||
|         x.mul_(mask) | ||||
|     return x | ||||
|  | ||||
|  | ||||
| def create_exp_dir(path, scripts_to_save=None): | ||||
|     if not os.path.exists(path): | ||||
|         os.makedirs(path) | ||||
|     print('Experiment dir : {}'.format(path)) | ||||
|  | ||||
|     if scripts_to_save is not None: | ||||
|         os.mkdir(os.path.join(path, 'scripts')) | ||||
|         for script in scripts_to_save: | ||||
|             dst_file = os.path.join(path, 'scripts', os.path.basename(script)) | ||||
|             shutil.copyfile(script, dst_file) | ||||
|  | ||||
|  | ||||
| class CIFAR10(VisionDataset): | ||||
|     """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. | ||||
|  | ||||
|     Args: | ||||
|         root (string): Root directory of dataset where directory | ||||
|             ``cifar-10-batches-py`` exists or will be saved to if download is set to True. | ||||
|         train (bool, optional): If True, creates dataset from training set, otherwise | ||||
|             creates from test set. | ||||
|         transform (callable, optional): A function/transform that takes in an PIL image | ||||
|             and returns a transformed version. E.g, ``transforms.RandomCrop`` | ||||
|         target_transform (callable, optional): A function/transform that takes in the | ||||
|             target and transforms it. | ||||
|         download (bool, optional): If true, downloads the dataset from the internet and | ||||
|             puts it in root directory. If dataset is already downloaded, it is not | ||||
|             downloaded again. | ||||
|  | ||||
|     """ | ||||
|     base_folder = 'cifar-10-batches-py' | ||||
|     url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" | ||||
|     filename = "cifar-10-python.tar.gz" | ||||
|     tgz_md5 = 'c58f30108f718f92721af3b95e74349a' | ||||
|     train_list = [ | ||||
|         ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], | ||||
|         ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], | ||||
|         ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], | ||||
|         ['data_batch_4', '634d18415352ddfa80567beed471001a'], | ||||
|         #['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], | ||||
|     ] | ||||
|  | ||||
|     test_list = [ | ||||
|         ['test_batch', '40351d587109b95175f43aff81a1287e'], | ||||
|     ] | ||||
|     meta = { | ||||
|         'filename': 'batches.meta', | ||||
|         'key': 'label_names', | ||||
|         'md5': '5ff9c542aee3614f3951f8cda6e48888', | ||||
|     } | ||||
|  | ||||
|     def __init__(self, root, train=True, transform=None, target_transform=None, | ||||
|                  download=False): | ||||
|  | ||||
|         super(CIFAR10, self).__init__(root, transform=transform, | ||||
|                                       target_transform=target_transform) | ||||
|  | ||||
|         self.train = train  # training set or test set | ||||
|  | ||||
|         if download: | ||||
|             self.download() | ||||
|  | ||||
|         if not self._check_integrity(): | ||||
|             raise RuntimeError('Dataset not found or corrupted.' + | ||||
|                                ' You can use download=True to download it') | ||||
|  | ||||
|         if self.train: | ||||
|             downloaded_list = self.train_list | ||||
|         else: | ||||
|             downloaded_list = self.test_list | ||||
|  | ||||
|         self.data = [] | ||||
|         self.targets = [] | ||||
|  | ||||
|         # now load the picked numpy arrays | ||||
|         for file_name, checksum in downloaded_list: | ||||
|             file_path = os.path.join(self.root, self.base_folder, file_name) | ||||
|             with open(file_path, 'rb') as f: | ||||
|                 if sys.version_info[0] == 2: | ||||
|                     entry = pickle.load(f) | ||||
|                 else: | ||||
|                     entry = pickle.load(f, encoding='latin1') | ||||
|                 self.data.append(entry['data']) | ||||
|                 if 'labels' in entry: | ||||
|                     self.targets.extend(entry['labels']) | ||||
|                 else: | ||||
|                     self.targets.extend(entry['fine_labels']) | ||||
|  | ||||
|         self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) | ||||
|         self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC | ||||
|  | ||||
|         self._load_meta() | ||||
|  | ||||
|     def _load_meta(self): | ||||
|         path = os.path.join(self.root, self.base_folder, self.meta['filename']) | ||||
|         if not utils.check_integrity(path, self.meta['md5']): | ||||
|             raise RuntimeError('Dataset metadata file not found or corrupted.' + | ||||
|                                ' You can use download=True to download it') | ||||
|         with open(path, 'rb') as infile: | ||||
|             if sys.version_info[0] == 2: | ||||
|                 data = pickle.load(infile) | ||||
|             else: | ||||
|                 data = pickle.load(infile, encoding='latin1') | ||||
|             self.classes = data[self.meta['key']] | ||||
|         self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         """ | ||||
|         Args: | ||||
|             index (int): Index | ||||
|  | ||||
|         Returns: | ||||
|             tuple: (image, target) where target is index of the target class. | ||||
|         """ | ||||
|         img, target = self.data[index], self.targets[index] | ||||
|  | ||||
|         # doing this so that it is consistent with all other datasets | ||||
|         # to return a PIL Image | ||||
|         img = Image.fromarray(img) | ||||
|  | ||||
|         if self.transform is not None: | ||||
|             img = self.transform(img) | ||||
|  | ||||
|         if self.target_transform is not None: | ||||
|             target = self.target_transform(target) | ||||
|  | ||||
|         return img, target | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.data) | ||||
|  | ||||
|     def _check_integrity(self): | ||||
|         root = self.root | ||||
|         for fentry in (self.train_list + self.test_list): | ||||
|             filename, md5 = fentry[0], fentry[1] | ||||
|             fpath = os.path.join(root, self.base_folder, filename) | ||||
|             if not utils.check_integrity(fpath, md5): | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
|     def download(self): | ||||
|         if self._check_integrity(): | ||||
|             print('Files already downloaded and verified') | ||||
|             return | ||||
|         utils.download_and_extract_archive(self.url, self.root, | ||||
|                                            filename=self.filename, | ||||
|                                            md5=self.tgz_md5) | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return "Split: {}".format("Train" if self.train is True else "Test") | ||||
|  | ||||
|  | ||||
| def pick_gpu_lowest_memory(): | ||||
|     import gpustat | ||||
|     stats = gpustat.GPUStatCollection.new_query() | ||||
|     ids = map(lambda gpu: int(gpu.entry['index']), stats) | ||||
|     ratios = map(lambda gpu: float(gpu.memory_used)/float(gpu.memory_total), stats) | ||||
|     bestGPU = min(zip(ids, ratios), key=lambda x: x[1])[0] | ||||
|     return bestGPU | ||||
|  | ||||
|  | ||||
| #### early stopping (from RobustNAS) | ||||
| class EVLocalAvg(object): | ||||
|     def __init__(self, window=5, ev_freq=2, total_epochs=50): | ||||
|         """ Keep track of the eigenvalues local average. | ||||
|         Args: | ||||
|             window (int): number of elements used to compute local average. | ||||
|                 Default: 5 | ||||
|             ev_freq (int): frequency used to compute eigenvalues. Default: | ||||
|                 every 2 epochs | ||||
|             total_epochs (int): total number of epochs that DARTS runs. | ||||
|                 Default: 50 | ||||
|         """ | ||||
|         self.window = window | ||||
|         self.ev_freq = ev_freq | ||||
|         self.epochs = total_epochs | ||||
|  | ||||
|         self.stop_search = False | ||||
|         self.stop_epoch = total_epochs - 1 | ||||
|         self.stop_genotype = None | ||||
|         self.stop_numparam = 0 | ||||
|  | ||||
|         self.ev = [] | ||||
|         self.ev_local_avg = [] | ||||
|         self.genotypes = {} | ||||
|         self.numparams = {} | ||||
|         self.la_epochs = {} | ||||
|  | ||||
|         # start and end index of the local average window | ||||
|         self.la_start_idx = 0 | ||||
|         self.la_end_idx = self.window | ||||
|  | ||||
|     def reset(self): | ||||
|         self.ev = [] | ||||
|         self.ev_local_avg = [] | ||||
|         self.genotypes = {} | ||||
|         self.numparams = {} | ||||
|         self.la_epochs = {} | ||||
|  | ||||
|     def update(self, epoch, ev, genotype, numparam=0): | ||||
|         """ Method to update the local average list. | ||||
|  | ||||
|         Args: | ||||
|             epoch (int): current epoch | ||||
|             ev (float): current dominant eigenvalue | ||||
|             genotype (namedtuple): current genotype | ||||
|  | ||||
|         """ | ||||
|         self.ev.append(ev) | ||||
|         self.genotypes.update({epoch: genotype}) | ||||
|         self.numparams.update({epoch: numparam}) | ||||
|         # set the stop_genotype to the current genotype in case the early stop | ||||
|         # procedure decides not to early stop | ||||
|         self.stop_genotype = genotype | ||||
|  | ||||
|         # since the local average computation starts after the dominant | ||||
|         # eigenvalue in the first epoch is already computed we have to wait | ||||
|         # at least until we have 3 eigenvalues in the list. | ||||
|         if (len(self.ev) >= int(np.ceil(self.window/2))) and (epoch < | ||||
|                                                               self.epochs - 1): | ||||
|             # start sliding the window as soon as the number of eigenvalues in | ||||
|             # the list becomes equal to the window size | ||||
|             if len(self.ev) < self.window: | ||||
|                 self.ev_local_avg.append(np.mean(self.ev)) | ||||
|             else: | ||||
|                 assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window | ||||
|                 self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx: | ||||
|                                                          self.la_end_idx])) | ||||
|                 self.la_start_idx += 1 | ||||
|                 self.la_end_idx += 1 | ||||
|  | ||||
|             # keep track of the offset between the current epoch and the epoch | ||||
|             # corresponding to the local average. NOTE: in the end the size of | ||||
|             # self.ev and self.ev_local_avg should be equal | ||||
|             self.la_epochs.update({epoch: int(epoch - | ||||
|                                               int(self.ev_freq*np.floor(self.window/2)))}) | ||||
|  | ||||
|         elif len(self.ev) < int(np.ceil(self.window/2)): | ||||
|             self.la_epochs.update({epoch: -1}) | ||||
|  | ||||
|         # since there is an offset between the current epoch and the local | ||||
|         # average epoch, loop in the last epoch to compute the local average of | ||||
|         # these number of elements: window, window - 1, window - 2, ..., ceil(window/2) | ||||
|         elif epoch == self.epochs - 1: | ||||
|             for i in range(int(np.ceil(self.window/2))): | ||||
|                 assert len(self.ev[self.la_start_idx: self.la_end_idx]) == self.window - i | ||||
|                 self.ev_local_avg.append(np.mean(self.ev[self.la_start_idx: | ||||
|                                                          self.la_end_idx + 1])) | ||||
|                 self.la_start_idx += 1 | ||||
|  | ||||
|     def early_stop(self, epoch, factor=1.3, es_start_epoch=10, delta=4, criteria='local_avg'): | ||||
|         """ Early stopping criterion | ||||
|  | ||||
|         Args: | ||||
|             epoch (int): current epoch | ||||
|             factor (float): threshold factor for the ratio between the current | ||||
|                 and previous eigenvalue. Default: 1.3 | ||||
|             es_start_epoch (int): do not consider early stopping before this | ||||
|                 epoch. Default: 10 | ||||
|             delta (int): factor influencing which previous local average we | ||||
|                 consider for early stopping. Default: 4 | ||||
|         """ | ||||
|         if criteria == 'local_avg': | ||||
|             if int(self.la_epochs[epoch] - self.ev_freq*delta) >= es_start_epoch: | ||||
|                 current_la = self.ev_local_avg[-1] | ||||
|                 previous_la = self.ev_local_avg[-1 - delta] | ||||
|                 self.stop_search = current_la / previous_la > factor | ||||
|                 if self.stop_search: | ||||
|                     self.stop_epoch = int(self.la_epochs[epoch] - self.ev_freq*delta) | ||||
|                     self.stop_genotype = self.genotypes[self.stop_epoch] | ||||
|                     self.stop_numparam = self.numparams[self.stop_epoch] | ||||
|         elif criteria == 'exact': | ||||
|             if epoch > es_start_epoch: | ||||
|                 current_la = self.ev[-1] | ||||
|                 previous_la = self.ev[-1 - delta] | ||||
|                 self.stop_search = current_la / previous_la > factor | ||||
|                 if self.stop_search: | ||||
|                     self.stop_epoch = epoch - delta | ||||
|                     self.stop_genotype = self.genotypes[self.stop_epoch] | ||||
|                     self.stop_numparam = self.numparams[self.stop_epoch] | ||||
|         else: | ||||
|             print('ERROR IN EARLY STOP: WRONG CRITERIA:', criteria); sys.exit(1) | ||||
|  | ||||
|  | ||||
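| # All unordered pairs (i < j) of the given edge ids. | ||||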
| def gen_comb(eids): | ||||
|     comb = [] | ||||
|     for r in range(len(eids)): | ||||
|         for c in range(r + 1, len(eids)): | ||||
|             comb.append((eids[r], eids[c])) | ||||
|  | ||||
|     return comb | ||||