Update find_best API
This commit is contained in:
		
							
								
								
									
										47
									
								
								exps/show-dataset.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								exps/show-dataset.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | |||||||
|  | ############################################################################## | ||||||
|  | # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # | ||||||
|  | ############################################################################## | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          # | ||||||
|  | ############################################################################## | ||||||
|  | # python ./exps/NATS-Bench/main-tss.py --mode meta                           # | ||||||
|  | ############################################################################## | ||||||
|  | import os, sys, time, torch, random, argparse | ||||||
|  | from typing import List, Text, Dict, Any | ||||||
|  | from PIL     import ImageFile | ||||||
|  | ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||||
|  | from copy    import deepcopy | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() | ||||||
|  | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
|  | from config_utils import dict2config, load_config | ||||||
|  | from datasets import get_datasets | ||||||
|  | from nats_bench import create | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def show_imagenet_16_120(dataset_dir=None): | ||||||
|  |   if dataset_dir is None: | ||||||
|  |     torch_home_dir = os.environ['TORCH_HOME'] if 'TORCH_HOME' in os.environ else os.path.join(os.environ['HOME'], '.torch') | ||||||
|  |     dataset_dir = os.path.join(torch_home_dir, 'cifar.python', 'ImageNet16') | ||||||
|  |   train_data, valid_data, xshape, class_num = get_datasets('ImageNet16-120', dataset_dir, -1) | ||||||
|  |   split_info  = load_config('configs/nas-benchmark/ImageNet16-120-split.txt', None, None) | ||||||
|  |   print('=' * 10 + ' ImageNet-16-120 ' + '=' * 10) | ||||||
|  |   print('Training Data: {:}'.format(train_data)) | ||||||
|  |   print('Evaluation Data: {:}'.format(valid_data)) | ||||||
|  |   print('Hold-out training: {:} images.'.format(len(split_info.train))) | ||||||
|  |   print('Hold-out valid   : {:} images.'.format(len(split_info.valid))) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |   # show_imagenet_16_120() | ||||||
|  |   api_nats_tss = create(None, 'tss', fast_mode=True, verbose=True) | ||||||
|  |  | ||||||
|  |   valid_acc_12e = [] | ||||||
|  |   test_acc_12e = [] | ||||||
|  |   test_acc_200e = [] | ||||||
|  |   for index in range(10000): | ||||||
|  |     info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='12') | ||||||
|  |     valid_acc_12e.append(info['valid-accuracy']) | ||||||
|  |     test_acc_12e.append(info['test-accuracy']) | ||||||
|  |     info = api_nats_tss.get_more_info(index, 'ImageNet16-120', hp='200') | ||||||
|  |     test_acc_200e.append(info['test-accuracy'])  # which I reported. | ||||||
| @@ -92,6 +92,10 @@ class ImageNet16(data.Dataset): | |||||||
|     #std_data  = np.mean(np.mean(std_data, axis=0), axis=0) |     #std_data  = np.mean(np.mean(std_data, axis=0), axis=0) | ||||||
|     #print ('Std  : {:}'.format(std_data)) |     #print ('Std  : {:}'.format(std_data)) | ||||||
|  |  | ||||||
|  |   def __repr__(self): | ||||||
|  |     return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets)))) | ||||||
|  |  | ||||||
|  |  | ||||||
|   def __getitem__(self, index): |   def __getitem__(self, index): | ||||||
|     img, target = self.data[index], self.targets[index] - 1 |     img, target = self.data[index], self.targets[index] - 1 | ||||||
|  |  | ||||||
| @@ -114,16 +118,16 @@ class ImageNet16(data.Dataset): | |||||||
|         return False |         return False | ||||||
|     return True |     return True | ||||||
|  |  | ||||||
| # | """ | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)  |   train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None)  | ||||||
|   valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)  |   valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None)  | ||||||
|  |  | ||||||
|   print ( len(train) ) |   print ( len(train) ) | ||||||
|   print ( len(valid) ) |   print ( len(valid) ) | ||||||
|   image, label = train[111] |   image, label = train[111] | ||||||
|   trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200) |   trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200) | ||||||
|   validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200) |   validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200) | ||||||
|   print ( len(trainX) ) |   print ( len(trainX) ) | ||||||
|   print ( len(validX) ) |   print ( len(validX) ) | ||||||
|   #import pdb; pdb.set_trace() | """ | ||||||
|   | |||||||
| @@ -482,6 +482,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | |||||||
|     best_index, highest_accuracy = -1, None |     best_index, highest_accuracy = -1, None | ||||||
|     evaluated_indexes = sorted(list(self.evaluated_indexes)) |     evaluated_indexes = sorted(list(self.evaluated_indexes)) | ||||||
|     for arch_index in evaluated_indexes: |     for arch_index in evaluated_indexes: | ||||||
|  |       self._prepare_info(arch_index) | ||||||
|       arch_info = self.arch2infos_dict[arch_index][hp] |       arch_info = self.arch2infos_dict[arch_index][hp] | ||||||
|       info = arch_info.get_compute_costs(dataset)  # the information of costs |       info = arch_info.get_compute_costs(dataset)  # the information of costs | ||||||
|       flop, param, latency = info['flops'], info['params'], info['latency'] |       flop, param, latency = info['flops'], info['params'], info['latency'] | ||||||
| @@ -622,6 +623,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): | |||||||
|         print('<' * 40 + '------------' + '<' * 40) |         print('<' * 40 + '------------' + '<' * 40) | ||||||
|     else: |     else: | ||||||
|       if 0 <= index < len(self.meta_archs): |       if 0 <= index < len(self.meta_archs): | ||||||
|  |         if index not in self.evaluated_indexes: | ||||||
|  |           self._prepare_info(index) | ||||||
|         if index not in self.evaluated_indexes: |         if index not in self.evaluated_indexes: | ||||||
|           print('The {:}-th architecture has not been evaluated ' |           print('The {:}-th architecture has not been evaluated ' | ||||||
|                 'or not saved.'.format(index)) |                 'or not saved.'.format(index)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user