update vis
This commit is contained in:
parent
9ec25663f1
commit
28e4b8406f
@ -1,3 +1,6 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
##################################################
|
||||||
import sys, time, torch, random, argparse
|
import sys, time, torch, random, argparse
|
||||||
from PIL import ImageFile
|
from PIL import ImageFile
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -142,15 +143,17 @@ def check_unique_arch(meta_file):
|
|||||||
print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
|
print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
|
||||||
|
|
||||||
|
|
||||||
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
|
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
|
||||||
if isinstance(meta_file, API):
|
if isinstance(meta_file, API):
|
||||||
api = meta_file
|
api = meta_file
|
||||||
else:
|
else:
|
||||||
api = API(str(meta_file))
|
api = API(str(meta_file))
|
||||||
cifar10_valid = []
|
cifar10_valid = []
|
||||||
cifar10_test = []
|
cifar10_test = []
|
||||||
|
cifar100_valid = []
|
||||||
cifar100_test = []
|
cifar100_test = []
|
||||||
imagenet_test = []
|
imagenet_test = []
|
||||||
|
imagenet_valid = []
|
||||||
for idx, arch in enumerate(api):
|
for idx, arch in enumerate(api):
|
||||||
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
|
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
|
||||||
cifar10_valid.append( results['valid-accuracy'] )
|
cifar10_valid.append( results['valid-accuracy'] )
|
||||||
@ -158,14 +161,16 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
|
|||||||
cifar10_test.append( results['test-accuracy'] )
|
cifar10_test.append( results['test-accuracy'] )
|
||||||
results = api.get_more_info(idx, 'cifar100' , None, False, is_rand)
|
results = api.get_more_info(idx, 'cifar100' , None, False, is_rand)
|
||||||
cifar100_test.append( results['test-accuracy'] )
|
cifar100_test.append( results['test-accuracy'] )
|
||||||
|
cifar100_valid.append( results['valid-accuracy'] )
|
||||||
results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
|
results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
|
||||||
imagenet_test.append( results['test-accuracy'] )
|
imagenet_test.append( results['test-accuracy'] )
|
||||||
|
imagenet_valid.append( results['valid-accuracy'] )
|
||||||
def get_cor(A, B):
|
def get_cor(A, B):
|
||||||
return float(np.corrcoef(A, B)[0,1])
|
return float(np.corrcoef(A, B)[0,1])
|
||||||
cors = []
|
cors = []
|
||||||
for basestr, xlist in zip(['CIFAR-010', 'CIFAR-100', 'ImageNet16'], [cifar10_test,cifar100_test, imagenet_test]):
|
for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]):
|
||||||
correlation = get_cor(cifar10_valid, xlist)
|
correlation = get_cor(cifar10_valid, xlist)
|
||||||
print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(less_epoch, '012' if use_less_or_not else '200', basestr, correlation))
|
if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
|
||||||
cors.append( correlation )
|
cors.append( correlation )
|
||||||
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
|
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
|
||||||
#print('-'*200)
|
#print('-'*200)
|
||||||
@ -173,6 +178,19 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
|
|||||||
return cors
|
return cors
|
||||||
|
|
||||||
|
|
||||||
|
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
|
||||||
|
corrs = []
|
||||||
|
for i in tqdm(range(100)):
|
||||||
|
x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
|
||||||
|
corrs.append( x )
|
||||||
|
xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
|
||||||
|
correlations = np.array(corrs)
|
||||||
|
print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200'))
|
||||||
|
for idx, xstr in enumerate(xstrs):
|
||||||
|
print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, correlations[:,idx].mean(), correlations[:,idx].std(), correlations[:,idx].mean(), correlations[:,idx].std()))
|
||||||
|
print('')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser("Analysis of NAS-Bench-102")
|
parser = argparse.ArgumentParser("Analysis of NAS-Bench-102")
|
||||||
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
|
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||||
@ -189,5 +207,11 @@ if __name__ == '__main__':
|
|||||||
#for iepoch in [11, 25, 50, 100, 150, 175, 200]:
|
#for iepoch in [11, 25, 50, 100, 150, 175, 200]:
|
||||||
# check_cor_for_bandit(api, 6, iepoch)
|
# check_cor_for_bandit(api, 6, iepoch)
|
||||||
# check_cor_for_bandit(api, 12, iepoch)
|
# check_cor_for_bandit(api, 12, iepoch)
|
||||||
correlations = check_cor_for_bandit(api, 6, True, True)
|
check_cor_for_bandit_v2(api, 6, True, True)
|
||||||
import pdb; pdb.set_trace()
|
check_cor_for_bandit_v2(api, 12, True, True)
|
||||||
|
check_cor_for_bandit_v2(api, 12, False, True)
|
||||||
|
check_cor_for_bandit_v2(api, 24, False, True)
|
||||||
|
check_cor_for_bandit_v2(api, 100, False, True)
|
||||||
|
check_cor_for_bandit_v2(api, 150, False, True)
|
||||||
|
check_cor_for_bandit_v2(api, 200, False, True)
|
||||||
|
print('----')
|
||||||
|
@ -383,4 +383,4 @@ if __name__ == '__main__':
|
|||||||
#visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
|
#visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
|
||||||
#visualize_info(str(meta_file), 'cifar100', vis_save_dir)
|
#visualize_info(str(meta_file), 'cifar100', vis_save_dir)
|
||||||
#visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
|
#visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
|
||||||
visualize_relative_ranking(vis_save_dir)
|
#visualize_relative_ranking(vis_save_dir)
|
||||||
|
@ -174,7 +174,7 @@ def main(xargs, nas_bench):
|
|||||||
|
|
||||||
id2config = results.get_id2config_mapping()
|
id2config = results.get_id2config_mapping()
|
||||||
incumbent = results.get_incumbent_id()
|
incumbent = results.get_incumbent_id()
|
||||||
logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config']))
|
logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time))
|
||||||
best_arch = config2structure( id2config[incumbent]['config'] )
|
best_arch = config2structure( id2config[incumbent]['config'] )
|
||||||
|
|
||||||
info = nas_bench.query_by_arch( best_arch )
|
info = nas_bench.query_by_arch( best_arch )
|
||||||
|
@ -56,6 +56,7 @@ def main(xargs, nas_bench):
|
|||||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
||||||
#x =random_arch() ; y = mutate_arch(x)
|
#x =random_arch() ; y = mutate_arch(x)
|
||||||
|
x_start_time = time.time()
|
||||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||||
best_arch, best_acc, total_time_cost, history = None, -1, 0, []
|
best_arch, best_acc, total_time_cost, history = None, -1, 0, []
|
||||||
#for idx in range(xargs.random_num):
|
#for idx in range(xargs.random_num):
|
||||||
@ -68,7 +69,7 @@ def main(xargs, nas_bench):
|
|||||||
if best_arch is None or best_acc < accuracy:
|
if best_arch is None or best_acc < accuracy:
|
||||||
best_acc, best_arch = accuracy, arch
|
best_acc, best_arch = accuracy, arch
|
||||||
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
|
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
|
||||||
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost))
|
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))
|
||||||
|
|
||||||
info = nas_bench.query_by_arch( best_arch )
|
info = nas_bench.query_by_arch( best_arch )
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||||
|
@ -201,10 +201,11 @@ def main(xargs, nas_bench):
|
|||||||
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
random_arch = random_architecture_func(xargs.max_nodes, search_space)
|
||||||
mutate_arch = mutate_arch_func(search_space)
|
mutate_arch = mutate_arch_func(search_space)
|
||||||
#x =random_arch() ; y = mutate_arch(x)
|
#x =random_arch() ; y = mutate_arch(x)
|
||||||
|
x_start_time = time.time()
|
||||||
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
|
||||||
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
|
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
|
||||||
history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
|
history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
|
||||||
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s.'.format(time_string(), len(history), total_cost))
|
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time()-x_start_time))
|
||||||
best_arch = max(history, key=lambda i: i.accuracy)
|
best_arch = max(history, key=lambda i: i.accuracy)
|
||||||
best_arch = best_arch.arch
|
best_arch = best_arch.arch
|
||||||
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
|
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
|
||||||
|
@ -139,6 +139,7 @@ def main(xargs, nas_bench):
|
|||||||
|
|
||||||
# REINFORCE
|
# REINFORCE
|
||||||
# attempts = 0
|
# attempts = 0
|
||||||
|
x_start_time = time.time()
|
||||||
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
|
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
|
||||||
total_steps, total_costs = 0, 0
|
total_steps, total_costs = 0, 0
|
||||||
#for istep in range(xargs.RL_steps):
|
#for istep in range(xargs.RL_steps):
|
||||||
@ -166,7 +167,7 @@ def main(xargs, nas_bench):
|
|||||||
#logger.log('')
|
#logger.log('')
|
||||||
|
|
||||||
best_arch = policy.genotype()
|
best_arch = policy.genotype()
|
||||||
logger.log('REINFORCE finish with {:} steps and {:.1f} s.'.format(total_steps, total_costs))
|
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
|
||||||
info = nas_bench.query_by_arch( best_arch )
|
info = nas_bench.query_by_arch( best_arch )
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||||
else : logger.log('{:}'.format(info))
|
else : logger.log('{:}'.format(info))
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
##################################################
|
||||||
import os, sys, time, torch, random, argparse
|
import os, sys, time, torch, random, argparse
|
||||||
from PIL import ImageFile
|
from PIL import ImageFile
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
##################################################
|
||||||
import sys, time, torch, random, argparse
|
import sys, time, torch, random, argparse
|
||||||
from PIL import ImageFile
|
from PIL import ImageFile
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
@ -1 +0,0 @@
|
|||||||
from graphviz import Digraph
|
|
@ -1,27 +1,28 @@
|
|||||||
# python ./vis-exps/show-results.py
|
##################################################
|
||||||
import os, sys
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
##################################################
|
||||||
|
# python ./exps/vis/show-results.py --api_path ${HOME}/.torch/NAS-Bench-102-v1_0-e61699.pth
|
||||||
|
##################################################
|
||||||
|
import os, sys, argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
|
||||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
|
||||||
|
|
||||||
from aa_nas_api import AANASBenchAPI
|
|
||||||
|
|
||||||
api = AANASBenchAPI('./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth')
|
|
||||||
|
|
||||||
def plot_results_nas(dataset, xset, file_name, y_lims):
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('agg')
|
matplotlib.use('agg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
root = Path('./output/cell-search-tiny-vis').resolve()
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
print ('root path : {:}'.format( root ))
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
root.mkdir(parents=True, exist_ok=True)
|
|
||||||
checkpoints = ['./output/cell-search-tiny/R-EA-cifar10/results.pth',
|
from nas_102_api import NASBench102API as API
|
||||||
'./output/cell-search-tiny/REINFORCE-cifar10/results.pth',
|
|
||||||
'./output/cell-search-tiny/RAND-cifar10/results.pth',
|
|
||||||
'./output/cell-search-tiny/BOHB-cifar10/results.pth'
|
def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
|
||||||
|
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
|
||||||
|
checkpoints = ['./output/search-cell-nas-bench-102/R-EA-cifar10/results.pth',
|
||||||
|
'./output/search-cell-nas-bench-102/REINFORCE-cifar10/results.pth',
|
||||||
|
'./output/search-cell-nas-bench-102/RAND-cifar10/results.pth',
|
||||||
|
'./output/search-cell-nas-bench-102/BOHB-cifar10/results.pth'
|
||||||
]
|
]
|
||||||
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
|
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
|
||||||
All_Accs = OrderedDict()
|
All_Accs = OrderedDict()
|
||||||
@ -29,15 +30,15 @@ def plot_results_nas(dataset, xset, file_name, y_lims):
|
|||||||
all_indexes = torch.load(checkpoint, map_location='cpu')
|
all_indexes = torch.load(checkpoint, map_location='cpu')
|
||||||
accuracies = []
|
accuracies = []
|
||||||
for x in all_indexes:
|
for x in all_indexes:
|
||||||
info = api.arch2infos[ x ]
|
info = api.arch2infos_full[ x ]
|
||||||
_, accy = info.get_metrics(dataset, xset, None, False)
|
metrics = info.get_metrics(dataset, xset, None, False)
|
||||||
accuracies.append( accy )
|
accuracies.append( metrics['accuracy'] )
|
||||||
if indexes is None: indexes = list(range(len(all_indexes)))
|
if indexes is None: indexes = list(range(len(all_indexes)))
|
||||||
All_Accs[legend] = sorted(accuracies)
|
All_Accs[legend] = sorted(accuracies)
|
||||||
|
|
||||||
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||||
dpi, width, height = 300, 3400, 2600
|
dpi, width, height = 300, 3400, 2600
|
||||||
LabelSize, LegendFontsize = 26, 26
|
LabelSize, LegendFontsize = 28, 28
|
||||||
figsize = width / float(dpi), height / float(dpi)
|
figsize = width / float(dpi), height / float(dpi)
|
||||||
fig = plt.figure(figsize=figsize)
|
fig = plt.figure(figsize=figsize)
|
||||||
x_axis = np.arange(0, 600)
|
x_axis = np.arange(0, 600)
|
||||||
@ -52,16 +53,83 @@ def plot_results_nas(dataset, xset, file_name, y_lims):
|
|||||||
|
|
||||||
for idx, legend in enumerate(legends):
|
for idx, legend in enumerate(legends):
|
||||||
plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2)
|
plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2)
|
||||||
print ('{:} : mean = {:}, std = {:}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
|
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
|
||||||
plt.legend(loc=4, fontsize=LegendFontsize)
|
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||||
save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name)
|
save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name)
|
||||||
print('save figure into {:}\n'.format(save_path))
|
print('save figure into {:}\n'.format(save_path))
|
||||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||||
|
|
||||||
|
|
||||||
|
def just_show(api):
|
||||||
|
xtimes = {'RSPS': [8082.5, 7794.2, 8144.7],
|
||||||
|
'DARTS-V1': [11582.1, 11347.0, 11948.2],
|
||||||
|
'DARTS-V2': [35694.7, 36132.7, 35518.0],
|
||||||
|
'GDAS' : [31334.1, 31478.6, 32016.7],
|
||||||
|
'SETN' : [33528.8, 33831.5, 35058.3],
|
||||||
|
'ENAS' : [14340.2, 13817.3, 14018.9]}
|
||||||
|
for xkey, xlist in xtimes.items():
|
||||||
|
xlist = np.array(xlist)
|
||||||
|
print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean()))
|
||||||
|
|
||||||
|
xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/',
|
||||||
|
'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/',
|
||||||
|
'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/',
|
||||||
|
'GDAS' : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/',
|
||||||
|
'SETN' : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/',
|
||||||
|
'ENAS' : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/',
|
||||||
|
}
|
||||||
|
xseeds = {'RSPS' : [5349, 59613, 5983],
|
||||||
|
'DARTS-V1': [11416, 72873, 81184],
|
||||||
|
'DARTS-V2': [43330, 79405, 79423],
|
||||||
|
'GDAS' : [19677, 884, 95950],
|
||||||
|
'SETN' : [20518, 61817, 89144],
|
||||||
|
'ENAS' : [30801, 75610, 97745],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_accs(xdata, index=-1):
|
||||||
|
if index == -1:
|
||||||
|
epochs = xdata['epoch']
|
||||||
|
genotype = xdata['genotypes'][epochs-1]
|
||||||
|
index = api.query_index_by_arch(genotype)
|
||||||
|
pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')]
|
||||||
|
xresults = []
|
||||||
|
for dataset, xset in pairs:
|
||||||
|
metrics = api.arch2infos_full[index].get_metrics(dataset, xset, None, False)
|
||||||
|
xresults.append( metrics['accuracy'] )
|
||||||
|
return xresults
|
||||||
|
|
||||||
|
for xkey in xpaths.keys():
|
||||||
|
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
|
||||||
|
all_datas = [torch.load(xpath) for xpath in all_paths]
|
||||||
|
accyss = [get_accs(xdatas) for xdatas in all_datas]
|
||||||
|
accyss = np.array( accyss )
|
||||||
|
print('\nxkey = {:}'.format(xkey))
|
||||||
|
for i in range(accyss.shape[1]): print('---->>>> {:.2f}$\\pm${:.2f}'.format(accyss[:,i].mean(), accyss[:,i].std()))
|
||||||
|
|
||||||
|
print('\n{:}'.format(get_accs(None, 11472))) # resnet
|
||||||
|
pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')]
|
||||||
|
for dataset, metric_on_set in pairs:
|
||||||
|
arch_index, highest_acc = api.find_best(dataset, metric_on_set)
|
||||||
|
print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
plot_results_nas('cifar10', 'ori-test', 'nas-com.pdf', (85,95, 1))
|
parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
plot_results_nas('cifar100', 'x-valid', 'nas-com.pdf', (55,75, 3))
|
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||||
plot_results_nas('cifar100', 'x-test' , 'nas-com.pdf', (55,75, 3))
|
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
|
||||||
plot_results_nas('ImageNet16-120', 'x-valid', 'nas-com.pdf', (35,50, 3))
|
args = parser.parse_args()
|
||||||
plot_results_nas('ImageNet16-120', 'x-test' , 'nas-com.pdf', (35,50, 3))
|
|
||||||
|
api = API(args.api_path)
|
||||||
|
|
||||||
|
root = Path(args.save_dir).resolve()
|
||||||
|
root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
just_show(api)
|
||||||
|
"""
|
||||||
|
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , root, 'nas-com.pdf', (85,95, 1))
|
||||||
|
plot_results_nas(api, 'cifar10' , 'ori-test', root, 'nas-com.pdf', (85,95, 1))
|
||||||
|
plot_results_nas(api, 'cifar100' , 'x-valid' , root, 'nas-com.pdf', (55,75, 3))
|
||||||
|
plot_results_nas(api, 'cifar100' , 'x-test' , root, 'nas-com.pdf', (55,75, 3))
|
||||||
|
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , root, 'nas-com.pdf', (35,50, 3))
|
||||||
|
plot_results_nas(api, 'ImageNet16-120', 'x-test' , root, 'nas-com.pdf', (35,50, 3))
|
||||||
|
"""
|
||||||
|
@ -131,15 +131,17 @@ class NASBench102API(object):
|
|||||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||||
best_index, highest_accuracy = -1, None
|
best_index, highest_accuracy = -1, None
|
||||||
for i, idx in enumerate(self.evaluated_indexes):
|
for i, idx in enumerate(self.evaluated_indexes):
|
||||||
flop, param, latency = arch2infos[idx].get_comput_costs(dataset)
|
info = arch2infos[idx].get_comput_costs(dataset)
|
||||||
|
flop, param, latency = info['flops'], info['params'], info['latency']
|
||||||
if FLOP_max is not None and flop > FLOP_max : continue
|
if FLOP_max is not None and flop > FLOP_max : continue
|
||||||
if Param_max is not None and param > Param_max: continue
|
if Param_max is not None and param > Param_max: continue
|
||||||
loss, accuracy = arch2infos[idx].get_metrics(dataset, metric_on_set)
|
xinfo = arch2infos[idx].get_metrics(dataset, metric_on_set)
|
||||||
|
loss, accuracy = xinfo['loss'], xinfo['accuracy']
|
||||||
if best_index == -1:
|
if best_index == -1:
|
||||||
best_index, highest_accuracy = idx, accuracy
|
best_index, highest_accuracy = idx, accuracy
|
||||||
elif highest_accuracy < accuracy:
|
elif highest_accuracy < accuracy:
|
||||||
best_index, highest_accuracy = idx, accuracy
|
best_index, highest_accuracy = idx, accuracy
|
||||||
return best_index
|
return best_index, highest_accuracy
|
||||||
|
|
||||||
# return the topology structure of the `index`-th architecture
|
# return the topology structure of the `index`-th architecture
|
||||||
def arch(self, index):
|
def arch(self, index):
|
||||||
@ -183,10 +185,18 @@ class NASBench102API(object):
|
|||||||
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||||
else:
|
else:
|
||||||
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
|
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
|
||||||
return {'train-loss' : train_info['loss'],
|
try:
|
||||||
|
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
|
||||||
|
except:
|
||||||
|
valid_info = None
|
||||||
|
xifo = {'train-loss' : train_info['loss'],
|
||||||
'train-accuracy': train_info['accuracy'],
|
'train-accuracy': train_info['accuracy'],
|
||||||
'test-loss' : test__info['loss'],
|
'test-loss' : test__info['loss'],
|
||||||
'test-accuracy' : test__info['accuracy']}
|
'test-accuracy' : test__info['accuracy']}
|
||||||
|
if valid_info is not None:
|
||||||
|
xifo['valid-loss'] = valid_info['loss']
|
||||||
|
xifo['valid-accuracy'] = valid_info['accuracy']
|
||||||
|
return xifo
|
||||||
|
|
||||||
def show(self, index=-1):
|
def show(self, index=-1):
|
||||||
if index < 0: # show all architectures
|
if index < 0: # show all architectures
|
||||||
|
Loading…
Reference in New Issue
Block a user