update vis
This commit is contained in:
		| @@ -6,6 +6,7 @@ | ||||
| import os, sys, time, glob, random, argparse | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| from tqdm import tqdm | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from pathlib import Path | ||||
| @@ -142,15 +143,17 @@ def check_unique_arch(meta_file): | ||||
|   print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num)) | ||||
|  | ||||
|  | ||||
| def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True): | ||||
| def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False): | ||||
|   if isinstance(meta_file, API): | ||||
|     api = meta_file | ||||
|   else: | ||||
|     api = API(str(meta_file)) | ||||
|   cifar10_valid     = [] | ||||
|   cifar10_test      = [] | ||||
|   cifar100_valid    = [] | ||||
|   cifar100_test     = [] | ||||
|   imagenet_test     = [] | ||||
|   imagenet_valid    = [] | ||||
|   for idx, arch in enumerate(api): | ||||
|     results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) | ||||
|     cifar10_valid.append( results['valid-accuracy'] ) | ||||
| @@ -158,14 +161,16 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True): | ||||
|     cifar10_test.append( results['test-accuracy'] ) | ||||
|     results = api.get_more_info(idx, 'cifar100'      , None, False, is_rand) | ||||
|     cifar100_test.append( results['test-accuracy'] ) | ||||
|     cifar100_valid.append( results['valid-accuracy'] ) | ||||
|     results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand) | ||||
|     imagenet_test.append( results['test-accuracy'] ) | ||||
|     imagenet_valid.append( results['valid-accuracy'] ) | ||||
|   def get_cor(A, B): | ||||
|     return float(np.corrcoef(A, B)[0,1]) | ||||
|   cors = [] | ||||
|   for basestr, xlist in zip(['CIFAR-010', 'CIFAR-100', 'ImageNet16'], [cifar10_test,cifar100_test, imagenet_test]): | ||||
|   for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): | ||||
|     correlation = get_cor(cifar10_valid, xlist) | ||||
|     print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(less_epoch, '012' if use_less_or_not else '200', basestr, correlation)) | ||||
|     if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation)) | ||||
|     cors.append( correlation ) | ||||
|     #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) | ||||
|     #print('-'*200) | ||||
| @@ -173,6 +178,19 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True): | ||||
|   return cors | ||||
|  | ||||
|  | ||||
| def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): | ||||
|   corrs = [] | ||||
|   for i in tqdm(range(100)): | ||||
|     x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) | ||||
|     corrs.append( x ) | ||||
|   xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||
|   correlations = np.array(corrs) | ||||
|   print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200')) | ||||
|   for idx, xstr in enumerate(xstrs): | ||||
|     print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, correlations[:,idx].mean(), correlations[:,idx].std(), correlations[:,idx].mean(), correlations[:,idx].std())) | ||||
|   print('') | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser("Analysis of NAS-Bench-102") | ||||
|   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.') | ||||
| @@ -189,5 +207,11 @@ if __name__ == '__main__': | ||||
|   #for iepoch in [11, 25, 50, 100, 150, 175, 200]: | ||||
|   #  check_cor_for_bandit(api,  6, iepoch) | ||||
|   #  check_cor_for_bandit(api, 12, iepoch) | ||||
|   correlations = check_cor_for_bandit(api, 6, True, True) | ||||
|   import pdb; pdb.set_trace() | ||||
|   check_cor_for_bandit_v2(api,   6,  True, True) | ||||
|   check_cor_for_bandit_v2(api,  12,  True, True) | ||||
|   check_cor_for_bandit_v2(api,  12, False, True) | ||||
|   check_cor_for_bandit_v2(api,  24, False, True) | ||||
|   check_cor_for_bandit_v2(api, 100, False, True) | ||||
|   check_cor_for_bandit_v2(api, 150, False, True) | ||||
|   check_cor_for_bandit_v2(api, 200, False, True) | ||||
|   print('----') | ||||
|   | ||||
| @@ -383,4 +383,4 @@ if __name__ == '__main__': | ||||
|   #visualize_info(str(meta_file), 'cifar10' , vis_save_dir) | ||||
|   #visualize_info(str(meta_file), 'cifar100', vis_save_dir) | ||||
|   #visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) | ||||
|   visualize_relative_ranking(vis_save_dir) | ||||
|   #visualize_relative_ranking(vis_save_dir) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user