update vis

This commit is contained in:
D-X-Y 2020-01-01 22:18:42 +11:00
parent 9ec25663f1
commit 28e4b8406f
12 changed files with 153 additions and 40 deletions

View File

@ -1,3 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import sys, time, torch, random, argparse import sys, time, torch, random, argparse
from PIL import ImageFile from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True

View File

@ -6,6 +6,7 @@
import os, sys, time, glob, random, argparse import os, sys, time, glob, random, argparse
import numpy as np import numpy as np
from copy import deepcopy from copy import deepcopy
from tqdm import tqdm
import torch import torch
import torch.nn as nn import torch.nn as nn
from pathlib import Path from pathlib import Path
@ -142,15 +143,17 @@ def check_unique_arch(meta_file):
print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num)) print ('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True): def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
if isinstance(meta_file, API): if isinstance(meta_file, API):
api = meta_file api = meta_file
else: else:
api = API(str(meta_file)) api = API(str(meta_file))
cifar10_valid = [] cifar10_valid = []
cifar10_test = [] cifar10_test = []
cifar100_valid = []
cifar100_test = [] cifar100_test = []
imagenet_test = [] imagenet_test = []
imagenet_valid = []
for idx, arch in enumerate(api): for idx, arch in enumerate(api):
results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand) results = api.get_more_info(idx, 'cifar10-valid' , test_epoch-1, use_less_or_not, is_rand)
cifar10_valid.append( results['valid-accuracy'] ) cifar10_valid.append( results['valid-accuracy'] )
@ -158,14 +161,16 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
cifar10_test.append( results['test-accuracy'] ) cifar10_test.append( results['test-accuracy'] )
results = api.get_more_info(idx, 'cifar100' , None, False, is_rand) results = api.get_more_info(idx, 'cifar100' , None, False, is_rand)
cifar100_test.append( results['test-accuracy'] ) cifar100_test.append( results['test-accuracy'] )
cifar100_valid.append( results['valid-accuracy'] )
results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand) results = api.get_more_info(idx, 'ImageNet16-120', None, False, is_rand)
imagenet_test.append( results['test-accuracy'] ) imagenet_test.append( results['test-accuracy'] )
imagenet_valid.append( results['valid-accuracy'] )
def get_cor(A, B): def get_cor(A, B):
return float(np.corrcoef(A, B)[0,1]) return float(np.corrcoef(A, B)[0,1])
cors = [] cors = []
for basestr, xlist in zip(['CIFAR-010', 'CIFAR-100', 'ImageNet16'], [cifar10_test,cifar100_test, imagenet_test]): for basestr, xlist in zip(['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T'], [cifar10_test, cifar100_valid, cifar100_test, imagenet_valid, imagenet_test]):
correlation = get_cor(cifar10_valid, xlist) correlation = get_cor(cifar10_valid, xlist)
print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(less_epoch, '012' if use_less_or_not else '200', basestr, correlation)) if need_print: print ('With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, '012' if use_less_or_not else '200', basestr, correlation))
cors.append( correlation ) cors.append( correlation )
#print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist))) #print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
#print('-'*200) #print('-'*200)
@ -173,6 +178,19 @@ def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True):
return cors return cors
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
corrs = []
for i in tqdm(range(100)):
x = check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
corrs.append( x )
xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
correlations = np.array(corrs)
print('------>>>>>>>> {:03d}/{:} >>>>>>>> ------'.format(test_epoch, '012' if use_less_or_not else '200'))
for idx, xstr in enumerate(xstrs):
print ('{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}'.format(xstr, correlations[:,idx].mean(), correlations[:,idx].std(), correlations[:,idx].mean(), correlations[:,idx].std()))
print('')
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser("Analysis of NAS-Bench-102") parser = argparse.ArgumentParser("Analysis of NAS-Bench-102")
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.') parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
@ -189,5 +207,11 @@ if __name__ == '__main__':
#for iepoch in [11, 25, 50, 100, 150, 175, 200]: #for iepoch in [11, 25, 50, 100, 150, 175, 200]:
# check_cor_for_bandit(api, 6, iepoch) # check_cor_for_bandit(api, 6, iepoch)
# check_cor_for_bandit(api, 12, iepoch) # check_cor_for_bandit(api, 12, iepoch)
correlations = check_cor_for_bandit(api, 6, True, True) check_cor_for_bandit_v2(api, 6, True, True)
import pdb; pdb.set_trace() check_cor_for_bandit_v2(api, 12, True, True)
check_cor_for_bandit_v2(api, 12, False, True)
check_cor_for_bandit_v2(api, 24, False, True)
check_cor_for_bandit_v2(api, 100, False, True)
check_cor_for_bandit_v2(api, 150, False, True)
check_cor_for_bandit_v2(api, 200, False, True)
print('----')

View File

@ -383,4 +383,4 @@ if __name__ == '__main__':
#visualize_info(str(meta_file), 'cifar10' , vis_save_dir) #visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
#visualize_info(str(meta_file), 'cifar100', vis_save_dir) #visualize_info(str(meta_file), 'cifar100', vis_save_dir)
#visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir) #visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
visualize_relative_ranking(vis_save_dir) #visualize_relative_ranking(vis_save_dir)

View File

@ -174,7 +174,7 @@ def main(xargs, nas_bench):
id2config = results.get_id2config_mapping() id2config = results.get_id2config_mapping()
incumbent = results.get_incumbent_id() incumbent = results.get_incumbent_id()
logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config'])) logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time))
best_arch = config2structure( id2config[incumbent]['config'] ) best_arch = config2structure( id2config[incumbent]['config'] )
info = nas_bench.query_by_arch( best_arch ) info = nas_bench.query_by_arch( best_arch )

View File

@ -56,6 +56,7 @@ def main(xargs, nas_bench):
search_space = get_search_spaces('cell', xargs.search_space_name) search_space = get_search_spaces('cell', xargs.search_space_name)
random_arch = random_architecture_func(xargs.max_nodes, search_space) random_arch = random_architecture_func(xargs.max_nodes, search_space)
#x =random_arch() ; y = mutate_arch(x) #x =random_arch() ; y = mutate_arch(x)
x_start_time = time.time()
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
best_arch, best_acc, total_time_cost, history = None, -1, 0, [] best_arch, best_acc, total_time_cost, history = None, -1, 0, []
#for idx in range(xargs.random_num): #for idx in range(xargs.random_num):
@ -68,7 +69,7 @@ def main(xargs, nas_bench):
if best_arch is None or best_acc < accuracy: if best_arch is None or best_acc < accuracy:
best_acc, best_arch = accuracy, arch best_acc, best_arch = accuracy, arch
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy)) logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost)) logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))
info = nas_bench.query_by_arch( best_arch ) info = nas_bench.query_by_arch( best_arch )
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))

View File

@ -201,10 +201,11 @@ def main(xargs, nas_bench):
random_arch = random_architecture_func(xargs.max_nodes, search_space) random_arch = random_architecture_func(xargs.max_nodes, search_space)
mutate_arch = mutate_arch_func(search_space) mutate_arch = mutate_arch_func(search_space)
#x =random_arch() ; y = mutate_arch(x) #x =random_arch() ; y = mutate_arch(x)
x_start_time = time.time()
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info) history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info)
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s.'.format(time_string(), len(history), total_cost)) logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time()-x_start_time))
best_arch = max(history, key=lambda i: i.accuracy) best_arch = max(history, key=lambda i: i.accuracy)
best_arch = best_arch.arch best_arch = best_arch.arch
logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) logger.log('{:} best arch is {:}'.format(time_string(), best_arch))

View File

@ -139,6 +139,7 @@ def main(xargs, nas_bench):
# REINFORCE # REINFORCE
# attempts = 0 # attempts = 0
x_start_time = time.time()
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget)) logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
total_steps, total_costs = 0, 0 total_steps, total_costs = 0, 0
#for istep in range(xargs.RL_steps): #for istep in range(xargs.RL_steps):
@ -166,7 +167,7 @@ def main(xargs, nas_bench):
#logger.log('') #logger.log('')
best_arch = policy.genotype() best_arch = policy.genotype()
logger.log('REINFORCE finish with {:} steps and {:.1f} s.'.format(total_steps, total_costs)) logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
info = nas_bench.query_by_arch( best_arch ) info = nas_bench.query_by_arch( best_arch )
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info)) else : logger.log('{:}'.format(info))

View File

@ -1,3 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch, random, argparse import os, sys, time, torch, random, argparse
from PIL import ImageFile from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True

View File

@ -1,3 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import sys, time, torch, random, argparse import sys, time, torch, random, argparse
from PIL import ImageFile from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True ImageFile.LOAD_TRUNCATED_IMAGES = True

View File

@ -1 +0,0 @@
from graphviz import Digraph

View File

@ -1,27 +1,28 @@
# python ./vis-exps/show-results.py ##################################################
import os, sys # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# python ./exps/vis/show-results.py --api_path ${HOME}/.torch/NAS-Bench-102-v1_0-e61699.pth
##################################################
import os, sys, argparse
from pathlib import Path from pathlib import Path
import torch import torch
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from aa_nas_api import AANASBenchAPI
api = AANASBenchAPI('./output/AA-NAS-BENCH-4/simplifies/C16-N5-final-infos.pth')
def plot_results_nas(dataset, xset, file_name, y_lims):
import matplotlib import matplotlib
matplotlib.use('agg') matplotlib.use('agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
root = Path('./output/cell-search-tiny-vis').resolve() lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
print ('root path : {:}'.format( root )) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
root.mkdir(parents=True, exist_ok=True)
checkpoints = ['./output/cell-search-tiny/R-EA-cifar10/results.pth', from nas_102_api import NASBench102API as API
'./output/cell-search-tiny/REINFORCE-cifar10/results.pth',
'./output/cell-search-tiny/RAND-cifar10/results.pth',
'./output/cell-search-tiny/BOHB-cifar10/results.pth' def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
checkpoints = ['./output/search-cell-nas-bench-102/R-EA-cifar10/results.pth',
'./output/search-cell-nas-bench-102/REINFORCE-cifar10/results.pth',
'./output/search-cell-nas-bench-102/RAND-cifar10/results.pth',
'./output/search-cell-nas-bench-102/BOHB-cifar10/results.pth'
] ]
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
All_Accs = OrderedDict() All_Accs = OrderedDict()
@ -29,15 +30,15 @@ def plot_results_nas(dataset, xset, file_name, y_lims):
all_indexes = torch.load(checkpoint, map_location='cpu') all_indexes = torch.load(checkpoint, map_location='cpu')
accuracies = [] accuracies = []
for x in all_indexes: for x in all_indexes:
info = api.arch2infos[ x ] info = api.arch2infos_full[ x ]
_, accy = info.get_metrics(dataset, xset, None, False) metrics = info.get_metrics(dataset, xset, None, False)
accuracies.append( accy ) accuracies.append( metrics['accuracy'] )
if indexes is None: indexes = list(range(len(all_indexes))) if indexes is None: indexes = list(range(len(all_indexes)))
All_Accs[legend] = sorted(accuracies) All_Accs[legend] = sorted(accuracies)
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k'] color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
dpi, width, height = 300, 3400, 2600 dpi, width, height = 300, 3400, 2600
LabelSize, LegendFontsize = 26, 26 LabelSize, LegendFontsize = 28, 28
figsize = width / float(dpi), height / float(dpi) figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize) fig = plt.figure(figsize=figsize)
x_axis = np.arange(0, 600) x_axis = np.arange(0, 600)
@ -52,16 +53,83 @@ def plot_results_nas(dataset, xset, file_name, y_lims):
for idx, legend in enumerate(legends): for idx, legend in enumerate(legends):
plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2) plt.plot(indexes, All_Accs[legend], color=color_set[idx], linestyle='-', label='{:}'.format(legend), lw=2)
print ('{:} : mean = {:}, std = {:}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]))) print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
plt.legend(loc=4, fontsize=LegendFontsize) plt.legend(loc=4, fontsize=LegendFontsize)
save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name) save_path = root / '{:}-{:}-{:}'.format(dataset, xset, file_name)
print('save figure into {:}\n'.format(save_path)) print('save figure into {:}\n'.format(save_path))
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf') fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
def just_show(api):
xtimes = {'RSPS': [8082.5, 7794.2, 8144.7],
'DARTS-V1': [11582.1, 11347.0, 11948.2],
'DARTS-V2': [35694.7, 36132.7, 35518.0],
'GDAS' : [31334.1, 31478.6, 32016.7],
'SETN' : [33528.8, 33831.5, 35058.3],
'ENAS' : [14340.2, 13817.3, 14018.9]}
for xkey, xlist in xtimes.items():
xlist = np.array(xlist)
print ('{:4s} : mean-time={:.1f} s'.format(xkey, xlist.mean()))
xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/',
'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/',
'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/',
'GDAS' : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/',
'SETN' : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/',
'ENAS' : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/',
}
xseeds = {'RSPS' : [5349, 59613, 5983],
'DARTS-V1': [11416, 72873, 81184],
'DARTS-V2': [43330, 79405, 79423],
'GDAS' : [19677, 884, 95950],
'SETN' : [20518, 61817, 89144],
'ENAS' : [30801, 75610, 97745],
}
def get_accs(xdata, index=-1):
if index == -1:
epochs = xdata['epoch']
genotype = xdata['genotypes'][epochs-1]
index = api.query_index_by_arch(genotype)
pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')]
xresults = []
for dataset, xset in pairs:
metrics = api.arch2infos_full[index].get_metrics(dataset, xset, None, False)
xresults.append( metrics['accuracy'] )
return xresults
for xkey in xpaths.keys():
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
all_datas = [torch.load(xpath) for xpath in all_paths]
accyss = [get_accs(xdatas) for xdatas in all_datas]
accyss = np.array( accyss )
print('\nxkey = {:}'.format(xkey))
for i in range(accyss.shape[1]): print('---->>>> {:.2f}$\\pm${:.2f}'.format(accyss[:,i].mean(), accyss[:,i].std()))
print('\n{:}'.format(get_accs(None, 11472))) # resnet
pairs = [('cifar10-valid', 'x-valid'), ('cifar10', 'ori-test'), ('cifar100', 'x-valid'), ('cifar100', 'x-test'), ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test')]
for dataset, metric_on_set in pairs:
arch_index, highest_acc = api.find_best(dataset, metric_on_set)
print ('[{:10s}-{:10s} ::: index={:5d}, accuracy={:.2f}'.format(dataset, metric_on_set, arch_index, highest_acc))
if __name__ == '__main__': if __name__ == '__main__':
plot_results_nas('cifar10', 'ori-test', 'nas-com.pdf', (85,95, 1)) parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
plot_results_nas('cifar100', 'x-valid', 'nas-com.pdf', (55,75, 3)) parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
plot_results_nas('cifar100', 'x-test' , 'nas-com.pdf', (55,75, 3)) parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
plot_results_nas('ImageNet16-120', 'x-valid', 'nas-com.pdf', (35,50, 3)) args = parser.parse_args()
plot_results_nas('ImageNet16-120', 'x-test' , 'nas-com.pdf', (35,50, 3))
api = API(args.api_path)
root = Path(args.save_dir).resolve()
root.mkdir(parents=True, exist_ok=True)
just_show(api)
"""
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , root, 'nas-com.pdf', (85,95, 1))
plot_results_nas(api, 'cifar10' , 'ori-test', root, 'nas-com.pdf', (85,95, 1))
plot_results_nas(api, 'cifar100' , 'x-valid' , root, 'nas-com.pdf', (55,75, 3))
plot_results_nas(api, 'cifar100' , 'x-test' , root, 'nas-com.pdf', (55,75, 3))
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , root, 'nas-com.pdf', (35,50, 3))
plot_results_nas(api, 'ImageNet16-120', 'x-test' , root, 'nas-com.pdf', (35,50, 3))
"""

View File

@ -131,15 +131,17 @@ class NASBench102API(object):
else : basestr, arch2infos = '200epochs', self.arch2infos_full else : basestr, arch2infos = '200epochs', self.arch2infos_full
best_index, highest_accuracy = -1, None best_index, highest_accuracy = -1, None
for i, idx in enumerate(self.evaluated_indexes): for i, idx in enumerate(self.evaluated_indexes):
flop, param, latency = arch2infos[idx].get_comput_costs(dataset) info = arch2infos[idx].get_comput_costs(dataset)
flop, param, latency = info['flops'], info['params'], info['latency']
if FLOP_max is not None and flop > FLOP_max : continue if FLOP_max is not None and flop > FLOP_max : continue
if Param_max is not None and param > Param_max: continue if Param_max is not None and param > Param_max: continue
loss, accuracy = arch2infos[idx].get_metrics(dataset, metric_on_set) xinfo = arch2infos[idx].get_metrics(dataset, metric_on_set)
loss, accuracy = xinfo['loss'], xinfo['accuracy']
if best_index == -1: if best_index == -1:
best_index, highest_accuracy = idx, accuracy best_index, highest_accuracy = idx, accuracy
elif highest_accuracy < accuracy: elif highest_accuracy < accuracy:
best_index, highest_accuracy = idx, accuracy best_index, highest_accuracy = idx, accuracy
return best_index return best_index, highest_accuracy
# return the topology structure of the `index`-th architecture # return the topology structure of the `index`-th architecture
def arch(self, index): def arch(self, index):
@ -183,10 +185,18 @@ class NASBench102API(object):
test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random) test__info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
else: else:
test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random) test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
return {'train-loss' : train_info['loss'], try:
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
except:
valid_info = None
xifo = {'train-loss' : train_info['loss'],
'train-accuracy': train_info['accuracy'], 'train-accuracy': train_info['accuracy'],
'test-loss' : test__info['loss'], 'test-loss' : test__info['loss'],
'test-accuracy' : test__info['accuracy']} 'test-accuracy' : test__info['accuracy']}
if valid_info is not None:
xifo['valid-loss'] = valid_info['loss']
xifo['valid-accuracy'] = valid_info['accuracy']
return xifo
def show(self, index=-1): def show(self, index=-1):
if index < 0: # show all architectures if index < 0: # show all architectures