132 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			132 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #####################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
 | |
| ########################################################
 | |
| # python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
 | |
| ########################################################
 | |
| import sys, argparse
 | |
| import numpy as np
 | |
| from copy import deepcopy
 | |
| from tqdm import tqdm
 | |
| import torch
 | |
| from pathlib import Path
 | |
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 | |
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
 | |
| from log_utils    import time_string
 | |
| from models       import CellStructure
 | |
| from nas_201_api  import NASBench201API as API
 | |
| 
 | |
| 
 | |
def check_unique_arch(meta_file):
  """Print how many benchmark architectures are unique under each equivalence notion.

  Loads the benchmark at ``meta_file``, parses every string in ``api.meta_archs``
  with ``CellStructure.str2structure`` and prints:
    * how many parsed architectures pass ``check_valid()``;
    * the number of distinct ``to_unique_str(consider_zero)`` strings for
      ``consider_zero`` in ``None``, ``False`` and ``True``.

  Args:
    meta_file: path (str or Path) to the NAS-Bench-201 benchmark file.
  """
  api = API(str(meta_file))
  arch_strs = deepcopy(api.meta_archs)
  xarchs = [CellStructure.str2structure(x) for x in arch_strs]

  def get_unique_ids(archs, consider_zero):
    """Group ``archs`` by ``to_unique_str(consider_zero)``.

    Returns:
      (unique_ids, unique_num): ``unique_ids[i]`` is the group index of
      ``archs[i]``, with groups numbered in order of first appearance, and
      ``unique_num`` is the number of distinct groups.
    """
    unique_strs = [arch.to_unique_str(consider_zero) for arch in archs]
    print ('{:} create unique-string ({:}/{:}) done'.format(time_string(), len(set(unique_strs)), len(unique_strs)))
    # Assign ids straight from a string -> id dict. This replaces a dense
    # len(archs) x len(archs) boolean similarity matrix, which needed
    # O(len(archs)^2) memory and time just to recover these same ids.
    # setdefault evaluates len(str2id) before inserting, so a new string gets
    # the next consecutive id.
    str2id = {}
    unique_ids = [str2id.setdefault(xstr, len(str2id)) for xstr in unique_strs]
    return unique_ids, len(str2id)

  print ('There are {:} valid-archs'.format( sum(arch.check_valid() for arch in xarchs) ))
  for consider_zero, desc in ((None, 'considering nothing'),
                              (False, 'not considering zero'),
                              (True, 'considering zero')):
    _, unique_num = get_unique_ids(xarchs, consider_zero)
    print ('{:} There are {:} unique architectures ({:}).'.format(time_string(), unique_num, desc))
 | |
| 
 | |
| 
 | |
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
  """Correlate an early CIFAR-10 validation accuracy with final accuracies on each dataset.

  For every architecture the proxy is the 'valid-accuracy' reported by
  ``api.get_more_info`` for 'cifar10-valid' at epoch index ``test_epoch - 1``,
  queried with ``use_less_or_not``. The targets are queried with epoch ``None``
  and ``False`` instead. They are the cifar10-valid valid accuracy, the cifar10
  test accuracy, and the cifar100 and ImageNet16-120 valid/test accuracies.

  Args:
    meta_file: an already-constructed ``API`` instance, or a path to the benchmark
      file, which is then loaded with ``API(str(meta_file))``.
    test_epoch: 1-based epoch at which the proxy accuracy is read.
    use_less_or_not: forwarded to the proxy query only; also selects the
      '012' / '200' label in the printed message.
    is_rand: forwarded to every ``get_more_info`` call.
    need_print: when true, print each correlation as it is computed.

  Returns:
    A list of six floats, the ``np.corrcoef`` correlation between the proxy and
    each target, ordered C-010-V, C-010-T, C-100-V, C-100-T, I16-V, I16-T.
  """
  api = meta_file if isinstance(meta_file, API) else API(str(meta_file))

  proxy_accs = []
  # Insertion order defines the order of the returned correlations.
  target_accs = {label: [] for label in ('C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T')}

  for arch_index, _ in enumerate(api):
    # Keep this exact query order. With is_rand set, get_more_info presumably
    # samples a random trial per call (TODO confirm), so reordering could
    # change results.
    info = api.get_more_info(arch_index, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
    proxy_accs.append(info['valid-accuracy'])

    info = api.get_more_info(arch_index, 'cifar10-valid' , None, False, is_rand)
    target_accs['C-010-V'].append(info['valid-accuracy'])

    info = api.get_more_info(arch_index, 'cifar10'       , None, False, is_rand)
    target_accs['C-010-T'].append(info['test-accuracy'])

    info = api.get_more_info(arch_index, 'cifar100'      , None, False, is_rand)
    target_accs['C-100-T'].append(info['test-accuracy'])
    target_accs['C-100-V'].append(info['valid-accuracy'])

    info = api.get_more_info(arch_index, 'ImageNet16-120', None, False, is_rand)
    target_accs['I16-T'].append(info['test-accuracy'])
    target_accs['I16-V'].append(info['valid-accuracy'])

  schedule_tag = '012' if use_less_or_not else '200'
  correlations = []
  for label, accs in target_accs.items():
    coef = float(np.corrcoef(proxy_accs, accs)[0, 1])
    if need_print:
      print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, schedule_tag, label, coef))
    correlations.append(coef)
  return correlations
 | |
| 
 | |
| 
 | |
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand, repeats=100):
  """Repeat ``check_cor_for_bandit`` and print the mean/std of every correlation.

  Args:
    meta_file: ``API`` instance or benchmark path, forwarded to
      ``check_cor_for_bandit``. Pass a loaded ``API``: a path is re-loaded with
      ``API(str(meta_file))`` on every repetition.
    test_epoch: forwarded to ``check_cor_for_bandit``; also shown in the header.
    use_less_or_not: forwarded; selects the '012' / '200' label in the header.
    is_rand: forwarded to ``check_cor_for_bandit``.
    repeats: number of repetitions to aggregate (previously hard-coded to 100).

  Raises:
    ValueError: if ``repeats`` is less than 1, since per-column statistics over
      zero runs cannot be computed.
  """
  if repeats < 1:
    raise ValueError('repeats must be >= 1, got {:}'.format(repeats))
  corrs = [check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
           for _ in tqdm(range(repeats))]
  # Must match the order of the list returned by check_cor_for_bandit.
  xstrs = ['C-010-V', 'C-010-T', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
  correlations = np.array(corrs)  # shape: (repeats, len(xstrs))
  print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200'))
  for idx, xstr in enumerate(xstrs):
    mean, std = correlations[:, idx].mean(), correlations[:, idx].std()
    # The trailing "mean\pm std" pair is a LaTeX-ready copy of the same numbers.
    print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, mean, std, mean, std))
  print('')
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
  parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
  parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
  parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-201 benchmark file.')
  args = parser.parse_args()

  # Validate before touching the filesystem. Path(None) would raise a bare
  # TypeError, and an `assert` is stripped under `python -O`, so report problems
  # through parser.error, which prints usage and exits.
  if args.api_path is None:
    parser.error('--api_path is required')
  meta_file = Path(args.api_path)
  if not meta_file.exists():
    parser.error('invalid path for api : {:}'.format(meta_file))

  vis_save_dir = Path(args.save_dir)
  vis_save_dir.mkdir(parents=True, exist_ok=True)

  # Optional diagnostic: check_unique_arch(meta_file) prints unique-architecture counts.
  api = API(str(meta_file))  # load once and share across all runs
  # (test_epoch, use_less_or_not) settings, run in this order, all with is_rand=True.
  settings = [(6, True), (12, True), (12, False), (24, False),
              (100, False), (150, False), (175, False), (200, False)]
  for test_epoch, use_less_or_not in settings:
    check_cor_for_bandit_v2(api, test_epoch, use_less_or_not, True)
  print('----')
 |