Update REA, REINFORCE, and RANDOM
This commit is contained in:
		| @@ -141,9 +141,12 @@ class NASBench201API(NASBenchMetaAPI): | ||||
|   # `is_random` | ||||
|   #   When is_random=True, the performance of a random architecture will be returned | ||||
|   #   When is_random=False, the performanceo of all trials will be averaged. | ||||
|   def get_more_info(self, index: int, dataset, iepoch=None, hp='12', is_random=True): | ||||
|   def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True): | ||||
|     if self.verbose: | ||||
|       print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random)) | ||||
|     index = self.query_index_by_arch(index)  # To avoid the input is a string or an instance of a arch object | ||||
|     if index not in self.arch2infos_dict: | ||||
|       raise ValueError('Did not find {:} from arch2infos_dict.'.format(index)) | ||||
|     archresult = self.arch2infos_dict[index][str(hp)] | ||||
|     # if randomly select one trial, select the seed at first | ||||
|     if isinstance(is_random, bool) and is_random: | ||||
|   | ||||
| @@ -131,7 +131,7 @@ class NASBench301API(NASBenchMetaAPI): | ||||
|       print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp)) | ||||
|     return self._query_info_str_by_arch(arch, hp, print_information) | ||||
|  | ||||
|   def get_more_info(self, index: int, dataset: Text, iepoch=None, hp='12', is_random=True): | ||||
|   def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True): | ||||
|     """This function will return the metric for the `index`-th architecture | ||||
|        `dataset` indicates the dataset: | ||||
|           'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set | ||||
| @@ -151,6 +151,9 @@ class NASBench301API(NASBenchMetaAPI): | ||||
|     """ | ||||
|     if self.verbose: | ||||
|       print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random)) | ||||
|     index = self.query_index_by_arch(index)  # To avoid the input is a string or an instance of a arch object | ||||
|     if index not in self.arch2infos_dict: | ||||
|       raise ValueError('Did not find {:} from arch2infos_dict.'.format(index)) | ||||
|     archresult = self.arch2infos_dict[index][str(hp)] | ||||
|     # if randomly select one trial, select the seed at first | ||||
|     if isinstance(is_random, bool) and is_random: | ||||
|   | ||||
| @@ -68,7 +68,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|   def reset_time(self): | ||||
|     self._used_time = 0 | ||||
|  | ||||
|   def simulate_train_eval(self, arch, dataset, hp='12'): | ||||
|   def simulate_train_eval(self, arch, dataset, hp='12', account_time=True): | ||||
|     index = self.query_index_by_arch(arch) | ||||
|     all_names = ('cifar10', 'cifar100', 'ImageNet16-120') | ||||
|     assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names) | ||||
| @@ -77,8 +77,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | ||||
|     else: | ||||
|       info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True) | ||||
|     valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] | ||||
|     self._used_time += time_cost | ||||
|     return valid_acc, time_cost, self._used_time | ||||
|     latency = self.get_latency(index, dataset) | ||||
|     if account_time: | ||||
|       self._used_time += time_cost | ||||
|     return valid_acc, latency, time_cost, self._used_time | ||||
|  | ||||
|   def random(self): | ||||
|     """Return a random index of all architectures.""" | ||||
|   | ||||
| @@ -8,7 +8,9 @@ import torch.nn as nn | ||||
| from models import CellStructure | ||||
| from log_utils import time_string | ||||
|  | ||||
|  | ||||
| def evaluate_one_shot(model, xloader, api, cal_mode, seed=111): | ||||
|   print ('This is an old version of codes to use NAS-Bench-API, and should be modified to align with the new version. Please contact me for more details if you use this function.') | ||||
|   weights = deepcopy(model.state_dict()) | ||||
|   model.train(cal_mode) | ||||
|   with torch.no_grad(): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user