update API
This commit is contained in:
		| @@ -170,10 +170,28 @@ class NASBench201API(object): | ||||
|     return archresult.get_comput_costs(dataset) | ||||
|  | ||||
|   # obtain the metric for the `index`-th architecture | ||||
|   # `dataset` indicates the dataset: | ||||
|   #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set | ||||
|   #   'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set | ||||
|   #   'cifar100'       : using the proposed train set of CIFAR-100 as the training set | ||||
|   #   'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set | ||||
|   # `iepoch` indicates the index of training epochs from 0 to 11/199. | ||||
|   #   When iepoch=None, it will return the metric for the last training epoch | ||||
|   #   When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0) | ||||
|   # `use_12epochs_result` indicates different hyper-parameters for training | ||||
|   #   When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs | ||||
|   #   When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs | ||||
|   # `is_random` | ||||
|   #   When is_random=True, the performance of a random architecture will be returned | ||||
|   #   When is_random=False, the performanceo of all trials will be averaged. | ||||
|   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
|     # if randomly select one trial, select the seed at first | ||||
|     if isinstance(is_random, bool) and is_random: | ||||
|       seeds = archresult.get_dataset_seeds(dataset) | ||||
|       is_random = random.choice(seeds) | ||||
|     if dataset == 'cifar10-valid': | ||||
|       train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random) | ||||
|       valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random) | ||||
| @@ -202,7 +220,7 @@ class NASBench201API(object): | ||||
|         else: | ||||
|           test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) | ||||
|       except: | ||||
|         valid_info = None | ||||
|         test__info = None | ||||
|       try: | ||||
|         valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) | ||||
|       except: | ||||
| @@ -213,7 +231,7 @@ class NASBench201API(object): | ||||
|         est_valid_info = None | ||||
|       xifo = {'train-loss'    : train_info['loss'], | ||||
|               'train-accuracy': train_info['accuracy']} | ||||
|       if valid_info is not None: | ||||
|       if test__info is not None: | ||||
|         xifo['test-loss'] = test__info['loss'], | ||||
|         xifo['test-accuracy'] = test__info['accuracy'] | ||||
|       if valid_info is not None: | ||||
| @@ -347,14 +365,20 @@ class ArchResults(object): | ||||
|         info = result.get_eval(setname, iepoch) | ||||
|       for key, value in info.items(): infos[key].append( value ) | ||||
|     return_info = dict() | ||||
|     if is_random: | ||||
|     if isinstance(is_random, bool) and is_random: # randomly select one | ||||
|       index = random.randint(0, len(results)-1) | ||||
|       for key, value in infos.items(): return_info[key] = value[index] | ||||
|     else: | ||||
|     elif isinstance(is_random, bool) and not is_random: # average | ||||
|       for key, value in infos.items(): | ||||
|         if len(value) > 0 and value[0] is not None: | ||||
|           return_info[key] = np.mean(value) | ||||
|         else: return_info[key] = None | ||||
|     elif isinstance(is_random, int): # specify the seed | ||||
|       if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds)) | ||||
|       index = x_seeds.index(is_random) | ||||
|       for key, value in infos.items(): return_info[key] = value[index] | ||||
|     else: | ||||
|       raise ValueError('invalid value for is_random: {:}'.format(is_random)) | ||||
|     return return_info | ||||
|  | ||||
|   def show(self, is_print=False): | ||||
| @@ -363,6 +387,9 @@ class ArchResults(object): | ||||
|   def get_dataset_names(self): | ||||
|     return list(self.dataset_seed.keys()) | ||||
|  | ||||
|   def get_dataset_seeds(self, dataset): | ||||
|     return copy.deepcopy( self.dataset_seed[dataset] ) | ||||
|  | ||||
|   def get_net_param(self, dataset, seed=None): | ||||
|     if seed is None: | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user