diff --git a/exps/NAS-Bench-102/visualize.py b/exps/NAS-Bench-102/visualize.py index f73adc0..992bb61 100644 --- a/exps/NAS-Bench-102/visualize.py +++ b/exps/NAS-Bench-102/visualize.py @@ -367,6 +367,160 @@ def write_video(save_dir): + +def plot_results_nas(api, dataset, xset, root, file_name, y_lims): + print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) + checkpoints = ['./output/search-cell-nas-bench-102/R-EA-cifar10/results.pth', + './output/search-cell-nas-bench-102/REINFORCE-cifar10/results.pth', + './output/search-cell-nas-bench-102/RAND-cifar10/results.pth', + './output/search-cell-nas-bench-102/BOHB-cifar10/results.pth' + ] + legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None + All_Accs = OrderedDict() + for legend, checkpoint in zip(legends, checkpoints): + all_indexes = torch.load(checkpoint, map_location='cpu') + accuracies = [] + for x in all_indexes: + info = api.arch2infos_full[ x ] + metrics = info.get_metrics(dataset, xset, None, False) + accuracies.append( metrics['accuracy'] ) + if indexes is None: indexes = list(range(len(all_indexes))) + All_Accs[legend] = sorted(accuracies) + + color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] + dpi, width, height = 300, 3400, 2600 + LabelSize, LegendFontsize = 28, 28 + figsize = width / float(dpi), height / float(dpi) + fig = plt.figure(figsize=figsize) + x_axis = np.arange(0, 600) + plt.xlim(0, max(indexes)) + plt.ylim(y_lims[0], y_lims[1]) + interval_x, interval_y = 100, y_lims[2] + plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) + plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) + plt.grid() + plt.xlabel('The index of runs', fontsize=LabelSize) + plt.ylabel('The accuracy (%)', fontsize=LabelSize) + + for idx, legend in enumerate(legends): + plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2) + print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) + plt.legend(loc=4, fontsize=LegendFontsize) + save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name) + print('save figure into {:}\n'.format(save_path)) + fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') + + +def just_show(api): + xtimes = {'RSPS': [8082.5, 7794.2, 8144.7], + 'DARTS-V1': [11582.1, 11347.0, 11948.2], + 'DARTS-V2': [35694.7, 36132.7, 35518.0], + 'GDAS' : [31334.1, 31478.6, 32016.7], + 'SETN' : [33528.8, 33831.5, 35058.3], + 'ENAS' : [14340.2, 13817.3, 14018.9]} + for xkey, xlist in xtimes.items(): + xlist = np.array(xlist) + print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean())) + + xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', + 'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', + 'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/', + 'GDAS' : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/', + 'SETN' : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/', + 'ENAS' : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/', + } + xseeds = {'RSPS' : [5349, 59613, 5983], + 'DARTS-V1': [11416, 72873, 81184], + 'DARTS-V2': [43330, 79405, 79423], + 'GDAS' : [19677, 884, 95950], + 'SETN' : [20518, 61817, 89144], + 'ENAS' : [30801, 75610, 97745], + } + + def get_accs(xdata, index=-1): + if index == -1: + epochs = xdata['epoch'] + genotype = xdata['genotypes'][epochs-1] + index = api.query_index_by_arch(genotype) + pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')] + xresults = [] + for dataset, xset in pairs: + metrics = api.arch2infos_full[index].get_metrics(dataset, xset, None, False) + xresults.append( metrics['accuracy'] ) + return xresults + + for xkey in xpaths.keys(): + all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ] + all_datas = [torch.load(xpath) for xpath in all_paths] + accyss = [get_accs(xdatas) for xdatas in all_datas] + accyss = np.array( accyss ) + print('\nxkey = {:}'.format(xkey)) + for i in range(accyss.shape[1]): print('---->>>> {:.2f}$\\pm${:.2f}'.format(accyss[:,i].mean(), accyss[:,i].std())) + + print('\n{:}'.format(get_accs(None, 11472))) # resnet + pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')] + for dataset, metric_on_set in pairs: + arch_index, highest_acc = api.find_best(dataset, metric_on_set) + print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc)) + + +def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims): + color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] + dpi, width, height = 300, 3400, 2600 + LabelSize, LegendFontsize = 28, 28 + figsize = width / float(dpi), height / float(dpi) + fig = plt.figure(figsize=figsize) + x_maxs = 250 + x_axis = np.arange(0, x_maxs) + plt.xlim(0, x_maxs) + plt.ylim(y_lims[0], y_lims[1]) + interval_x, interval_y = x_maxs // 5, y_lims[2] + plt.xticks(np.arange(0, x_maxs, interval_x), fontsize=LegendFontsize) + plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) + plt.grid() + plt.xlabel('The searching epoch', fontsize=LabelSize) + plt.ylabel('The accuracy (%)', fontsize=LabelSize) + + xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', + 'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', + 'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/', + 'GDAS' : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/', + 'SETN' : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/', + 'ENAS' : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/', + } + xseeds = {'RSPS' : [5349, 59613, 5983], + 'DARTS-V1': [11416, 72873, 81184], + 'DARTS-V2': [43330, 79405, 79423], + 'GDAS' : [19677, 884, 95950], + 'SETN' : [20518, 61817, 89144], + 'ENAS' : [30801, 75610, 97745], + } + + def get_accs(xdata): + epochs, xresults = xdata['epoch'], [] + for iepoch in range(epochs): + genotype = xdata['genotypes'][iepoch] + index = api.query_index_by_arch(genotype) + metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False) + xresults.append( metrics['accuracy'] ) + return xresults + + for idx, method in enumerate(['RSPS', 'GDAS', 'SETN', 'ENAS']): + xkey = method + all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ] + all_datas = [torch.load(xpath) for xpath in all_paths] + accyss = [get_accs(xdatas) for xdatas in all_datas] + accyss = np.array( accyss ) + epochs = list(range(accyss.shape[1])) + plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx], linestyle='-', label='{:}'.format(method), lw=2) + plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx]) + plt.legend(loc=4, fontsize=LegendFontsize) + save_path = vis_save_dir / '{:}-{:}-{:}'.format(dataset, subset, file_name) + print('save figure into {:}\n'.format(save_path)) + fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') + + + if __name__ == '__main__': parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -384,3 +538,15 @@ if __name__ == '__main__': #visualize_info(str(meta_file), 'cifar100', vis_save_dir) #visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) #visualize_relative_ranking(vis_save_dir) + + api = API(args.api_path) + show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (5,95,10)) + """ + just_show(api) + plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1)) + plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1)) + plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) + plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3)) + plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) + plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3)) + """ diff --git a/exps/vis/show-results.py b/exps/vis/show-results.py deleted file mode 100644 index 55545c9..0000000 --- a/exps/vis/show-results.py +++ /dev/null @@ -1,135 +0,0 @@ -################################################## -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # -################################################## -# python ./exps/vis/show-results.py --api_path ${HOME}/.torch/NAS-Bench-102-v1_0-e61699.pth -################################################## -import os, sys, argparse -from pathlib import Path -import torch -import numpy as np -from collections import OrderedDict -import matplotlib -matplotlib.use('agg') -import matplotlib.pyplot as plt -lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() -if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) - -from nas_102_api import NASBench102API as API - - -def plot_results_nas(api, dataset, xset, root, file_name, y_lims): - print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset)) - checkpoints = ['./output/search-cell-nas-bench-102/R-EA-cifar10/results.pth', - './output/search-cell-nas-bench-102/REINFORCE-cifar10/results.pth', - './output/search-cell-nas-bench-102/RAND-cifar10/results.pth', - './output/search-cell-nas-bench-102/BOHB-cifar10/results.pth' - ] - legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None - All_Accs = OrderedDict() - for legend, checkpoint in zip(legends, checkpoints): - all_indexes = torch.load(checkpoint, map_location='cpu') - accuracies = [] - for x in all_indexes: - info = api.arch2infos_full[ x ] - metrics = info.get_metrics(dataset, xset, None, False) - accuracies.append( metrics['accuracy'] ) - if indexes is None: indexes = list(range(len(all_indexes))) - All_Accs[legend] = sorted(accuracies) - - color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] - dpi, width, height = 300, 3400, 2600 - LabelSize, LegendFontsize = 28, 28 - figsize = width / float(dpi), height / float(dpi) - fig = plt.figure(figsize=figsize) - x_axis = np.arange(0, 600) - plt.xlim(0, max(indexes)) - plt.ylim(y_lims[0], y_lims[1]) - interval_x, interval_y = 100, y_lims[2] - plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize) - plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize) - plt.grid() - plt.xlabel('The index of runs', fontsize=LabelSize) - plt.ylabel('The accuracy (%)', fontsize=LabelSize) - - for idx, legend in enumerate(legends): - plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2) - print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) - plt.legend(loc=4, fontsize=LegendFontsize) - save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name) - print('save figure into {:}\n'.format(save_path)) - fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') - - -def just_show(api): - xtimes = {'RSPS': [8082.5, 7794.2, 8144.7], - 'DARTS-V1': [11582.1, 11347.0, 11948.2], - 'DARTS-V2': [35694.7, 36132.7, 35518.0], - 'GDAS' : [31334.1, 31478.6, 32016.7], - 'SETN' : [33528.8, 33831.5, 35058.3], - 'ENAS' : [14340.2, 13817.3, 14018.9]} - for xkey, xlist in xtimes.items(): - xlist = np.array(xlist) - print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean())) - - xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/', - 'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/', - 'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/', - 'GDAS' : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/', - 'SETN' : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/', - 'ENAS' : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/', - } - xseeds = {'RSPS' : [5349, 59613, 5983], - 'DARTS-V1': [11416, 72873, 81184], - 'DARTS-V2': [43330, 79405, 79423], - 'GDAS' : [19677, 884, 95950], - 'SETN' : [20518, 61817, 89144], - 'ENAS' : [30801, 75610, 97745], - } - - def get_accs(xdata, index=-1): - if index == -1: - epochs = xdata['epoch'] - genotype = xdata['genotypes'][epochs-1] - index = api.query_index_by_arch(genotype) - pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')] - xresults = [] - for dataset, xset in pairs: - metrics = api.arch2infos_full[index].get_metrics(dataset, xset, None, False) - xresults.append( metrics['accuracy'] ) - return xresults - - for xkey in xpaths.keys(): - all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ] - all_datas = [torch.load(xpath) for xpath in all_paths] - accyss = [get_accs(xdatas) for xdatas in all_datas] - accyss = np.array( accyss ) - print('\nxkey = {:}'.format(xkey)) - for i in range(accyss.shape[1]): print('---->>>> {:.2f}$\\pm${:.2f}'.format(accyss[:,i].mean(), accyss[:,i].std())) - - print('\n{:}'.format(get_accs(None, 11472))) # resnet - pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')] - for dataset, metric_on_set in pairs: - arch_index, highest_acc = api.find_best(dataset, metric_on_set) - print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc)) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.') - parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.') - args = parser.parse_args() - - api = API(args.api_path) - - root = Path(args.save_dir).resolve() - root.mkdir(parents=True, exist_ok=True) - - just_show(api) - """ - plot_results_nas(api, 'cifar10-valid' , 'x-valid' , root, 'nas-com.pdf', (85,95, 1)) - plot_results_nas(api, 'cifar10' , 'ori-test', root, 'nas-com.pdf', (85,95, 1)) - plot_results_nas(api, 'cifar100' , 'x-valid' , root, 'nas-com.pdf', (55,75, 3)) - plot_results_nas(api, 'cifar100' , 'x-test' , root, 'nas-com.pdf', (55,75, 3)) - plot_results_nas(api, 'ImageNet16-120', 'x-valid' , root, 'nas-com.pdf', (35,50, 3)) - plot_results_nas(api, 'ImageNet16-120', 'x-test' , root, 'nas-com.pdf', (35,50, 3)) - """