update README
This commit is contained in:
		| @@ -147,14 +147,14 @@ class NASBench102API(object): | ||||
|     archresult = arch2infos[index] | ||||
|     return archresult.get_net_param(dataset, seed) | ||||
|  | ||||
|   def get_more_info(self, index, dataset, use_12epochs_result=False): | ||||
|   def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False): | ||||
|     if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less | ||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||
|     archresult = arch2infos[index] | ||||
|     if dataset == 'cifar10-valid': | ||||
|       train_info = archresult.get_metrics(dataset, 'train', is_random=True) | ||||
|       valid_info = archresult.get_metrics(dataset, 'x-valid', is_random=True) | ||||
|       test__info = archresult.get_metrics(dataset, 'ori-test', is_random=True) | ||||
|       train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=True) | ||||
|       valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=True) | ||||
|       test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=True) | ||||
|       total      = train_info['iepoch'] + 1 | ||||
|       return {'train-loss'    : train_info['loss'], | ||||
|               'train-accuracy': train_info['accuracy'], | ||||
|   | ||||
		Reference in New Issue
	
	Block a user