update vis
This commit is contained in:
		| @@ -1,3 +1,6 @@ | |||||||
|  | ################################################## | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
|  | ################################################## | ||||||
| import sys, time, torch, random, argparse | import sys, time, torch, random, argparse | ||||||
| from PIL     import ImageFile | from PIL     import ImageFile | ||||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ | |||||||
| import os, sys, time, glob, random, argparse | import os, sys, time, glob, random, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
|  | from tqdm import tqdm | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @@ -142,15 +143,17 @@ def check_unique_arch(meta_file): | |||||||
|   print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num)) |   print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True): | def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False): | ||||||
|   if isinstance(meta_file, API): |   if isinstance(meta_file, API): | ||||||
|     api = meta_file |     api = meta_file | ||||||
|   else: |   else: | ||||||
|     api = API(str(meta_file)) |     api = API(str(meta_file)) | ||||||
|   cifar10_valid     = [] |   cifar10_valid     = [] | ||||||
|   cifar10_test      = [] |   cifar10_test      = [] | ||||||
|  |   cifar100_valid    = [] | ||||||
|   cifar100_test     = [] |   cifar100_test     = [] | ||||||
|   imagenet_test     = [] |   imagenet_test     = [] | ||||||
|  |   imagenet_valid    = [] | ||||||
|   for idx, arch in enumerate(api): |   for idx, arch in enumerate(api): | ||||||
|     results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) |     results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) | ||||||
|     cifar10_valid.append( results['valid-accuracy'] ) |     cifar10_valid.append( results['valid-accuracy'] ) | ||||||
| @@ -158,14 +161,16 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True): | |||||||
|     cifar10_test.append( results['test-accuracy'] ) |     cifar10_test.append( results['test-accuracy'] ) | ||||||
|     results = api.get_more_info(idx, 'cifar100'      , None, False, is_rand) |     results = api.get_more_info(idx, 'cifar100'      , None, False, is_rand) | ||||||
|     cifar100_test.append( results['test-accuracy'] ) |     cifar100_test.append( results['test-accuracy'] ) | ||||||
|  |     cifar100_valid.append( results['valid-accuracy'] ) | ||||||
|     results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand) |     results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand) | ||||||
|     imagenet_test.append( results['test-accuracy'] ) |     imagenet_test.append( results['test-accuracy'] ) | ||||||
|  |     imagenet_valid.append( results['valid-accuracy'] ) | ||||||
|   def get_cor(A, B): |   def get_cor(A, B): | ||||||
|     return float(np.corrcoef(A, B)[0,1]) |     return float(np.corrcoef(A, B)[0,1]) | ||||||
|   cors = [] |   cors = [] | ||||||
|   for basestr, xlist in zip(['CIFAR-010', 'CIFAR-100', 'ImageNet16'], [cifar10_test,cifar100_test, imagenet_test]): |   for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]): | ||||||
|     correlation = get_cor(cifar10_valid, xlist) |     correlation = get_cor(cifar10_valid, xlist) | ||||||
|     print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(less_epoch, '012' if use_less_or_not else '200', basestr, correlation)) |     if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation)) | ||||||
|     cors.append( correlation ) |     cors.append( correlation ) | ||||||
|     #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) |     #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) | ||||||
|     #print('-'*200) |     #print('-'*200) | ||||||
| @@ -173,6 +178,19 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True): | |||||||
|   return cors |   return cors | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand): | ||||||
|  |   corrs = [] | ||||||
|  |   for i in tqdm(range(100)): | ||||||
|  |     x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False) | ||||||
|  |     corrs.append( x ) | ||||||
|  |   xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'] | ||||||
|  |   correlations = np.array(corrs) | ||||||
|  |   print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200')) | ||||||
|  |   for idx, xstr in enumerate(xstrs): | ||||||
|  |     print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, correlations[:,idx].mean(), correlations[:,idx].std(), correlations[:,idx].mean(), correlations[:,idx].std())) | ||||||
|  |   print('') | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   parser = argparse.ArgumentParser("Analysis of NAS-Bench-102") |   parser = argparse.ArgumentParser("Analysis of NAS-Bench-102") | ||||||
|   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.') |   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.') | ||||||
| @@ -189,5 +207,11 @@ if __name__ == '__main__': | |||||||
|   #for iepoch in [11, 25, 50, 100, 150, 175, 200]: |   #for iepoch in [11, 25, 50, 100, 150, 175, 200]: | ||||||
|   #  check_cor_for_bandit(api,  6, iepoch) |   #  check_cor_for_bandit(api,  6, iepoch) | ||||||
|   #  check_cor_for_bandit(api, 12, iepoch) |   #  check_cor_for_bandit(api, 12, iepoch) | ||||||
|   correlations = check_cor_for_bandit(api, 6, True, True) |   check_cor_for_bandit_v2(api,   6,  True, True) | ||||||
|   import pdb; pdb.set_trace() |   check_cor_for_bandit_v2(api,  12,  True, True) | ||||||
|  |   check_cor_for_bandit_v2(api,  12, False, True) | ||||||
|  |   check_cor_for_bandit_v2(api,  24, False, True) | ||||||
|  |   check_cor_for_bandit_v2(api, 100, False, True) | ||||||
|  |   check_cor_for_bandit_v2(api, 150, False, True) | ||||||
|  |   check_cor_for_bandit_v2(api, 200, False, True) | ||||||
|  |   print('----') | ||||||
|   | |||||||
| @@ -383,4 +383,4 @@ if __name__ == '__main__': | |||||||
|   #visualize_info(str(meta_file), 'cifar10' , vis_save_dir) |   #visualize_info(str(meta_file), 'cifar10' , vis_save_dir) | ||||||
|   #visualize_info(str(meta_file), 'cifar100', vis_save_dir) |   #visualize_info(str(meta_file), 'cifar100', vis_save_dir) | ||||||
|   #visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) |   #visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) | ||||||
|   visualize_relative_ranking(vis_save_dir) |   #visualize_relative_ranking(vis_save_dir) | ||||||
|   | |||||||
| @@ -174,7 +174,7 @@ def main(xargs, nas_bench): | |||||||
|  |  | ||||||
|   id2config = results.get_id2config_mapping() |   id2config = results.get_id2config_mapping() | ||||||
|   incumbent = results.get_incumbent_id() |   incumbent = results.get_incumbent_id() | ||||||
|   logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config'])) |   logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time)) | ||||||
|   best_arch = config2structure( id2config[incumbent]['config'] ) |   best_arch = config2structure( id2config[incumbent]['config'] ) | ||||||
|  |  | ||||||
|   info = nas_bench.query_by_arch( best_arch ) |   info = nas_bench.query_by_arch( best_arch ) | ||||||
|   | |||||||
| @@ -56,6 +56,7 @@ def main(xargs, nas_bench): | |||||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) |   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||||
|   random_arch = random_architecture_func(xargs.max_nodes, search_space) |   random_arch = random_architecture_func(xargs.max_nodes, search_space) | ||||||
|   #x =random_arch() ; y = mutate_arch(x) |   #x =random_arch() ; y = mutate_arch(x) | ||||||
|  |   x_start_time = time.time() | ||||||
|   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) |   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) | ||||||
|   best_arch, best_acc, total_time_cost, history = None, -1, 0, [] |   best_arch, best_acc, total_time_cost, history = None, -1, 0, [] | ||||||
|   #for idx in range(xargs.random_num): |   #for idx in range(xargs.random_num): | ||||||
| @@ -68,7 +69,7 @@ def main(xargs, nas_bench): | |||||||
|     if best_arch is None or best_acc < accuracy: |     if best_arch is None or best_acc < accuracy: | ||||||
|       best_acc, best_arch = accuracy, arch |       best_acc, best_arch = accuracy, arch | ||||||
|     logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy)) |     logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy)) | ||||||
|   logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost)) |   logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time)) | ||||||
|    |    | ||||||
|   info = nas_bench.query_by_arch( best_arch ) |   info = nas_bench.query_by_arch( best_arch ) | ||||||
|   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) |   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) | ||||||
|   | |||||||
| @@ -201,10 +201,11 @@ def main(xargs, nas_bench): | |||||||
|   random_arch = random_architecture_func(xargs.max_nodes, search_space) |   random_arch = random_architecture_func(xargs.max_nodes, search_space) | ||||||
|   mutate_arch = mutate_arch_func(search_space) |   mutate_arch = mutate_arch_func(search_space) | ||||||
|   #x =random_arch() ; y = mutate_arch(x) |   #x =random_arch() ; y = mutate_arch(x) | ||||||
|  |   x_start_time = time.time() | ||||||
|   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) |   logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) | ||||||
|   logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) |   logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) | ||||||
|   history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info) |   history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info) | ||||||
|   logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s.'.format(time_string(), len(history), total_cost)) |   logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time()-x_start_time)) | ||||||
|   best_arch = max(history, key=lambda i: i.accuracy) |   best_arch = max(history, key=lambda i: i.accuracy) | ||||||
|   best_arch = best_arch.arch |   best_arch = best_arch.arch | ||||||
|   logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) |   logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) | ||||||
|   | |||||||
| @@ -139,6 +139,7 @@ def main(xargs, nas_bench): | |||||||
|  |  | ||||||
|   # REINFORCE |   # REINFORCE | ||||||
|   # attempts = 0 |   # attempts = 0 | ||||||
|  |   x_start_time = time.time() | ||||||
|   logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget)) |   logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget)) | ||||||
|   total_steps, total_costs = 0, 0 |   total_steps, total_costs = 0, 0 | ||||||
|   #for istep in range(xargs.RL_steps): |   #for istep in range(xargs.RL_steps): | ||||||
| @@ -166,7 +167,7 @@ def main(xargs, nas_bench): | |||||||
|     #logger.log('') |     #logger.log('') | ||||||
|  |  | ||||||
|   best_arch = policy.genotype() |   best_arch = policy.genotype() | ||||||
|   logger.log('REINFORCE finish with {:} steps and {:.1f} s.'.format(total_steps, total_costs)) |   logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time)) | ||||||
|   info = nas_bench.query_by_arch( best_arch ) |   info = nas_bench.query_by_arch( best_arch ) | ||||||
|   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) |   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) | ||||||
|   else           : logger.log('{:}'.format(info)) |   else           : logger.log('{:}'.format(info)) | ||||||
|   | |||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | ################################################## | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
|  | ################################################## | ||||||
| import os, sys, time, torch, random, argparse | import os, sys, time, torch, random, argparse | ||||||
| from PIL     import ImageFile | from PIL     import ImageFile | ||||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||||
|   | |||||||
| @@ -1,3 +1,6 @@ | |||||||
|  | ################################################## | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
|  | ################################################## | ||||||
| import sys, time, torch, random, argparse | import sys, time, torch, random, argparse | ||||||
| from PIL     import ImageFile | from PIL     import ImageFile | ||||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||||
|   | |||||||
| @@ -1 +0,0 @@ | |||||||
| from graphviz import Digraph |  | ||||||
| @@ -1,27 +1,28 @@ | |||||||
| # python ./vis-exps/show-results.py | ################################################## | ||||||
| import os, sys | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||||
|  | ################################################## | ||||||
|  | # python ./exps/vis/show-results.py --api_path ${HOME}/.torch/NAS-Bench-102-v1_0-e61699.pth | ||||||
|  | ################################################## | ||||||
|  | import os, sys, argparse | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| import torch | import torch | ||||||
| import numpy as np | import numpy as np | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() |  | ||||||
| if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) |  | ||||||
|  |  | ||||||
| from aa_nas_api   import AANASBenchAPI |  | ||||||
|  |  | ||||||
| api = AANASBenchAPI('./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth') |  | ||||||
|  |  | ||||||
| def plot_results_nas(dataset, xset, file_name, y_lims): |  | ||||||
| import matplotlib | import matplotlib | ||||||
| matplotlib.use('agg') | matplotlib.use('agg') | ||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
|   root = Path('./output/cell-search-tiny-vis').resolve() | lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() | ||||||
|   print ('root path : {:}'.format( root )) | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) | ||||||
|   root.mkdir(parents=True, exist_ok=True) |  | ||||||
|   checkpoints = ['./output/cell-search-tiny/R-EA-cifar10/results.pth', | from nas_102_api import NASBench102API as API | ||||||
|                  './output/cell-search-tiny/REINFORCE-cifar10/results.pth', |  | ||||||
|                  './output/cell-search-tiny/RAND-cifar10/results.pth', |  | ||||||
|                  './output/cell-search-tiny/BOHB-cifar10/results.pth' | def plot_results_nas(api, dataset, xset, root, file_name, y_lims): | ||||||
|  |   print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) | ||||||
|  |   checkpoints = ['./output/search-cell-nas-bench-102/R-EA-cifar10/results.pth', | ||||||
|  |                  './output/search-cell-nas-bench-102/REINFORCE-cifar10/results.pth', | ||||||
|  |                  './output/search-cell-nas-bench-102/RAND-cifar10/results.pth', | ||||||
|  |                  './output/search-cell-nas-bench-102/BOHB-cifar10/results.pth' | ||||||
|                 ] |                 ] | ||||||
|   legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None |   legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None | ||||||
|   All_Accs = OrderedDict() |   All_Accs = OrderedDict() | ||||||
| @@ -29,15 +30,15 @@ def plot_results_nas(dataset, xset, file_name, y_lims): | |||||||
|     all_indexes = torch.load(checkpoint, map_location='cpu') |     all_indexes = torch.load(checkpoint, map_location='cpu') | ||||||
|     accuracies  = [] |     accuracies  = [] | ||||||
|     for x in all_indexes: |     for x in all_indexes: | ||||||
|       info = api.arch2infos[ x ] |       info = api.arch2infos_full[ x ] | ||||||
|       _, accy = info.get_metrics(dataset, xset, None, False) |       metrics = info.get_metrics(dataset, xset, None, False) | ||||||
|       accuracies.append( accy ) |       accuracies.append( metrics['accuracy'] ) | ||||||
|     if indexes is None: indexes = list(range(len(all_indexes))) |     if indexes is None: indexes = list(range(len(all_indexes))) | ||||||
|     All_Accs[legend] = sorted(accuracies) |     All_Accs[legend] = sorted(accuracies) | ||||||
|    |    | ||||||
|   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] |   color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] | ||||||
|   dpi, width, height = 300, 3400, 2600 |   dpi, width, height = 300, 3400, 2600 | ||||||
|   LabelSize, LegendFontsize = 26, 26 |   LabelSize, LegendFontsize = 28, 28 | ||||||
|   figsize = width / float(dpi), height / float(dpi) |   figsize = width / float(dpi), height / float(dpi) | ||||||
|   fig = plt.figure(figsize=figsize) |   fig = plt.figure(figsize=figsize) | ||||||
|   x_axis = np.arange(0, 600) |   x_axis = np.arange(0, 600) | ||||||
| @@ -52,16 +53,83 @@ def plot_results_nas(dataset, xset, file_name, y_lims): | |||||||
|  |  | ||||||
|   for idx, legend in enumerate(legends): |   for idx, legend in enumerate(legends): | ||||||
|     plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2) |     plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2) | ||||||
|     print ('{:} : mean = {:}, std = {:}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) |     print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) | ||||||
|   plt.legend(loc=4, fontsize=LegendFontsize) |   plt.legend(loc=4, fontsize=LegendFontsize) | ||||||
|   save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name) |   save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name) | ||||||
|   print('save figure into {:}\n'.format(save_path)) |   print('save figure into {:}\n'.format(save_path)) | ||||||
|   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') |   fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def just_show(api): | ||||||
|  |   xtimes = {'RSPS': [8082.5, 7794.2, 8144.7], | ||||||
|  |             'DARTS-V1': [11582.1, 11347.0, 11948.2], | ||||||
|  |             'DARTS-V2': [35694.7, 36132.7, 35518.0], | ||||||
|  |             'GDAS'    : [31334.1, 31478.6, 32016.7], | ||||||
|  |             'SETN'    : [33528.8, 33831.5, 35058.3], | ||||||
|  |             'ENAS'    : [14340.2, 13817.3, 14018.9]} | ||||||
|  |   for xkey, xlist in xtimes.items(): | ||||||
|  |     xlist = np.array(xlist) | ||||||
|  |     print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean())) | ||||||
|  |  | ||||||
|  |   xpaths = {'RSPS'    : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', | ||||||
|  |             'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', | ||||||
|  |             'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/', | ||||||
|  |             'GDAS'    : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/', | ||||||
|  |             'SETN'    : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/', | ||||||
|  |             'ENAS'    : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/', | ||||||
|  |            } | ||||||
|  |   xseeds = {'RSPS'    : [5349, 59613, 5983], | ||||||
|  |             'DARTS-V1': [11416, 72873, 81184], | ||||||
|  |             'DARTS-V2': [43330, 79405, 79423], | ||||||
|  |             'GDAS'    : [19677, 884, 95950], | ||||||
|  |             'SETN'    : [20518, 61817, 89144], | ||||||
|  |             'ENAS'    : [30801, 75610, 97745], | ||||||
|  |            } | ||||||
|  |  | ||||||
|  |   def get_accs(xdata, index=-1): | ||||||
|  |     if index == -1: | ||||||
|  |       epochs = xdata['epoch'] | ||||||
|  |       genotype = xdata['genotypes'][epochs-1] | ||||||
|  |       index = api.query_index_by_arch(genotype) | ||||||
|  |     pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')] | ||||||
|  |     xresults = [] | ||||||
|  |     for dataset, xset in pairs: | ||||||
|  |       metrics = api.arch2infos_full[index].get_metrics(dataset, xset, None, False) | ||||||
|  |       xresults.append( metrics['accuracy'] ) | ||||||
|  |     return xresults | ||||||
|  |  | ||||||
|  |   for xkey in xpaths.keys(): | ||||||
|  |     all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ] | ||||||
|  |     all_datas = [torch.load(xpath) for xpath in all_paths] | ||||||
|  |     accyss = [get_accs(xdatas) for xdatas in all_datas] | ||||||
|  |     accyss = np.array( accyss ) | ||||||
|  |     print('\nxkey = {:}'.format(xkey)) | ||||||
|  |     for i in range(accyss.shape[1]): print('---->>>> {:.2f}$\\pm${:.2f}'.format(accyss[:,i].mean(), accyss[:,i].std())) | ||||||
|  |  | ||||||
|  |   print('\n{:}'.format(get_accs(None, 11472))) # resnet | ||||||
|  |   pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')] | ||||||
|  |   for dataset, metric_on_set in pairs: | ||||||
|  |     arch_index, highest_acc = api.find_best(dataset, metric_on_set) | ||||||
|  |     print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   plot_results_nas('cifar10', 'ori-test', 'nas-com.pdf', (85,95, 1)) |   parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||||
|   plot_results_nas('cifar100', 'x-valid', 'nas-com.pdf', (55,75, 3)) |   parser.add_argument('--save_dir',  type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.') | ||||||
|   plot_results_nas('cifar100', 'x-test' , 'nas-com.pdf', (55,75, 3)) |   parser.add_argument('--api_path',  type=str, default=None,                                         help='The path to the NAS-Bench-102 benchmark file.') | ||||||
|   plot_results_nas('ImageNet16-120', 'x-valid', 'nas-com.pdf', (35,50, 3)) |   args = parser.parse_args() | ||||||
|   plot_results_nas('ImageNet16-120', 'x-test' , 'nas-com.pdf', (35,50, 3)) |  | ||||||
|  |   api  = API(args.api_path) | ||||||
|  |  | ||||||
|  |   root = Path(args.save_dir).resolve() | ||||||
|  |   root.mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|  |   just_show(api) | ||||||
|  |   """ | ||||||
|  |   plot_results_nas(api, 'cifar10-valid' , 'x-valid' , root, 'nas-com.pdf', (85,95, 1)) | ||||||
|  |   plot_results_nas(api, 'cifar10'       , 'ori-test', root, 'nas-com.pdf', (85,95, 1)) | ||||||
|  |   plot_results_nas(api, 'cifar100'      , 'x-valid' , root, 'nas-com.pdf', (55,75, 3)) | ||||||
|  |   plot_results_nas(api, 'cifar100'      , 'x-test'  , root, 'nas-com.pdf', (55,75, 3)) | ||||||
|  |   plot_results_nas(api, 'ImageNet16-120', 'x-valid' , root, 'nas-com.pdf', (35,50, 3)) | ||||||
|  |   plot_results_nas(api, 'ImageNet16-120', 'x-test'  , root, 'nas-com.pdf', (35,50, 3)) | ||||||
|  |   """ | ||||||
|   | |||||||
| @@ -131,15 +131,17 @@ class NASBench102API(object): | |||||||
|     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full |     else                  : basestr, arch2infos = '200epochs', self.arch2infos_full | ||||||
|     best_index, highest_accuracy = -1, None |     best_index, highest_accuracy = -1, None | ||||||
|     for i, idx in enumerate(self.evaluated_indexes): |     for i, idx in enumerate(self.evaluated_indexes): | ||||||
|       flop, param, latency = arch2infos[idx].get_comput_costs(dataset) |       info = arch2infos[idx].get_comput_costs(dataset) | ||||||
|  |       flop, param, latency = info['flops'], info['params'], info['latency'] | ||||||
|       if FLOP_max  is not None and flop  > FLOP_max : continue |       if FLOP_max  is not None and flop  > FLOP_max : continue | ||||||
|       if Param_max is not None and param > Param_max: continue |       if Param_max is not None and param > Param_max: continue | ||||||
|       loss, accuracy = arch2infos[idx].get_metrics(dataset, metric_on_set) |       xinfo = arch2infos[idx].get_metrics(dataset, metric_on_set) | ||||||
|  |       loss, accuracy = xinfo['loss'], xinfo['accuracy'] | ||||||
|       if best_index == -1: |       if best_index == -1: | ||||||
|         best_index, highest_accuracy = idx, accuracy |         best_index, highest_accuracy = idx, accuracy | ||||||
|       elif highest_accuracy < accuracy: |       elif highest_accuracy < accuracy: | ||||||
|         best_index, highest_accuracy = idx, accuracy |         best_index, highest_accuracy = idx, accuracy | ||||||
|     return best_index |     return best_index, highest_accuracy | ||||||
|  |  | ||||||
|   # return the topology structure of the `index`-th architecture |   # return the topology structure of the `index`-th architecture | ||||||
|   def arch(self, index): |   def arch(self, index): | ||||||
| @@ -183,10 +185,18 @@ class NASBench102API(object): | |||||||
|         test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) |         test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) | ||||||
|       else: |       else: | ||||||
|         test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) |         test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) | ||||||
|       return {'train-loss'    : train_info['loss'], |       try: | ||||||
|  |         valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random) | ||||||
|  |       except: | ||||||
|  |         valid_info = None | ||||||
|  |       xifo = {'train-loss'    : train_info['loss'], | ||||||
|               'train-accuracy': train_info['accuracy'], |               'train-accuracy': train_info['accuracy'], | ||||||
|               'test-loss'     : test__info['loss'], |               'test-loss'     : test__info['loss'], | ||||||
|               'test-accuracy' : test__info['accuracy']} |               'test-accuracy' : test__info['accuracy']} | ||||||
|  |       if valid_info is not None: | ||||||
|  |         xifo['valid-loss'] = valid_info['loss'] | ||||||
|  |         xifo['valid-accuracy'] = valid_info['accuracy'] | ||||||
|  |       return xifo | ||||||
|  |  | ||||||
|   def show(self, index=-1): |   def show(self, index=-1): | ||||||
|     if index < 0: # show all architectures |     if index < 0: # show all architectures | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user