fix NAS-Bench-201 comments and add more
This commit is contained in:
		| @@ -292,6 +292,11 @@ class NASBench201API(object): | ||||
|         xifo['est-valid-accuracy'] = est_valid_info['accuracy'] | ||||
|       return xifo | ||||
|  | ||||
|   """ | ||||
|   This function will print the information of a specific (or all) architecture(s). | ||||
|   If the index < 0: it will loop for all architectures and print their information one by one. | ||||
|   else: it will print the information of the 'index'-th archiitecture. | ||||
|   """ | ||||
|   def show(self, index=-1): | ||||
|     if index < 0: # show all architectures | ||||
|       print(self) | ||||
| @@ -299,10 +304,10 @@ class NASBench201API(object): | ||||
|         print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10) | ||||
|         print('arch : {:}'.format(self.meta_archs[idx])) | ||||
|         strings = print_information(self.arch2infos_full[idx]) | ||||
|         print('>' * 40 + ' 200 epochs ' + '>' * 40) | ||||
|         print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[idx].get_total_epoch()) + '>' * 40) | ||||
|         print('\n'.join(strings)) | ||||
|         strings = print_information(self.arch2infos_less[idx]) | ||||
|         print('>' * 40 + '  12 epochs ' + '>' * 40) | ||||
|         print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[idx].get_total_epoch()) + '>' * 40) | ||||
|         print('\n'.join(strings)) | ||||
|         print('<' * 40 + '------------' + '<' * 40) | ||||
|     else: | ||||
| @@ -310,10 +315,10 @@ class NASBench201API(object): | ||||
|         if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index)) | ||||
|         else: | ||||
|           strings = print_information(self.arch2infos_full[index]) | ||||
|           print('>' * 40 + ' 200 epochs ' + '>' * 40) | ||||
|           print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_full[index].get_total_epoch()) + '>' * 40) | ||||
|           print('\n'.join(strings)) | ||||
|           strings = print_information(self.arch2infos_less[index]) | ||||
|           print('>' * 40 + '  12 epochs ' + '>' * 40) | ||||
|           print('>' * 40 + ' {:03d} epochs '.format(self.arch2infos_less[index].get_total_epoch()) + '>' * 40) | ||||
|           print('\n'.join(strings)) | ||||
|           print('<' * 40 + '------------' + '<' * 40) | ||||
|       else: | ||||
| @@ -419,7 +424,7 @@ class ArchResults(object): | ||||
|     -- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test' | ||||
|     ------ 'train' : the metric on the training set. | ||||
|     ------ 'x-valid' : the metric on the validation set. | ||||
|     ------ 'ori-test' : the metric on the validation + test set. | ||||
|     ------ 'ori-test' : the metric on the test set. | ||||
|     -- When dataset = cifar10, you can use 'train', 'ori-test'. | ||||
|     ------ 'train' : the metric on the training + validation set. | ||||
|     ------ 'ori-test' : the metric on the test set. | ||||
| @@ -472,6 +477,11 @@ class ArchResults(object): | ||||
|   def get_dataset_seeds(self, dataset): | ||||
|     return copy.deepcopy( self.dataset_seed[dataset] ) | ||||
|  | ||||
|   """ | ||||
|   This function will return the trained network's weights on the 'dataset'. | ||||
|   When the 'seed' is None, it will return the weights for every run trial in the form of a dict. | ||||
|   When the  | ||||
|   """ | ||||
|   def get_net_param(self, dataset, seed=None): | ||||
|     if seed is None: | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
| @@ -479,6 +489,21 @@ class ArchResults(object): | ||||
|     else: | ||||
|       return self.all_results[(dataset, seed)].get_net_param() | ||||
|  | ||||
|   # get the total number of training epochs | ||||
|   def get_total_epoch(self, dataset=None): | ||||
|     if dataset is None: | ||||
|       epochss = [] | ||||
|       for xdata, x_seeds in self.dataset_seed.items(): | ||||
|         epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds] | ||||
|     elif isinstance(dataset, str): | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
|       epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds] | ||||
|     else: | ||||
|       raise ValueError('invalid dataset={:}'.format(dataset)) | ||||
|     if len(set(epochss)) > 1: raise ValueError('Each trial mush have the same number of training epochs : {:}'.format(epochss)) | ||||
|     return epochss[-1] | ||||
|  | ||||
|   # return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed' | ||||
|   def query(self, dataset, seed=None): | ||||
|     if seed is None: | ||||
|       x_seeds = self.dataset_seed[dataset] | ||||
| @@ -537,6 +562,8 @@ class ArchResults(object): | ||||
|     x.load_state_dict(state_dict) | ||||
|     return x | ||||
|  | ||||
|   # This function is used to clear the weights saved in each 'result' | ||||
|   # This can help reduce the memory footprint. | ||||
|   def clear_params(self): | ||||
|     for key, result in self.all_results.items(): | ||||
|       result.net_state_dict = None | ||||
| @@ -547,6 +574,11 @@ class ArchResults(object): | ||||
|      | ||||
|  | ||||
|  | ||||
| """ | ||||
| This class (ResultsCount) is used to save the information of one trial for a single architecture. | ||||
| I did not write much comment for this class, because it is the lowest-level class in NAS-Bench-201 API, which will be rarely called. | ||||
| If you have any question regarding this class, please open an issue or email me. | ||||
| """ | ||||
| class ResultsCount(object): | ||||
|  | ||||
|   def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency): | ||||
| @@ -604,10 +636,17 @@ class ResultsCount(object): | ||||
|     set_name = '[' + ', '.join(self.eval_names) + ']' | ||||
|     return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name)) | ||||
|  | ||||
|   # get the total number of training epochs | ||||
|   def get_total_epoch(self): | ||||
|     return copy.deepcopy(self.epochs) | ||||
|    | ||||
|   # get the latency | ||||
|   # -1 represents not avaliable ; otherwise it should be a float value | ||||
|   def get_latency(self): | ||||
|     if self.latency is None: return -1 | ||||
|     else: return sum(self.latency) / len(self.latency) | ||||
|  | ||||
|   # get the information regarding time | ||||
|   def get_times(self): | ||||
|     if self.train_times is not None and isinstance(self.train_times, dict): | ||||
|       train_times = list( self.train_times.values() ) | ||||
| @@ -626,6 +665,7 @@ class ResultsCount(object): | ||||
|   def get_eval_set(self): | ||||
|     return self.eval_names | ||||
|  | ||||
|   # get the training information | ||||
|   def get_train(self, iepoch=None): | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
| @@ -639,6 +679,7 @@ class ResultsCount(object): | ||||
|             'cur_time': xtime, | ||||
|             'all_time': atime} | ||||
|  | ||||
|   # get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument). | ||||
|   def get_eval(self, name, iepoch=None): | ||||
|     if iepoch is None: iepoch = self.epochs-1 | ||||
|     assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user