diff --git a/exps/NAS-Bench-201/test-nas-api.py b/exps/NAS-Bench-201/test-nas-api.py index 9a79f28..62d2bc3 100644 --- a/exps/NAS-Bench-201/test-nas-api.py +++ b/exps/NAS-Bench-201/test-nas-api.py @@ -72,6 +72,14 @@ def test_api(api, is_301=True): print('{:}\n'.format(info)) print('{:} finish testing the api : {:}'.format(time_string(), api)) + if not is_301: + arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|' + matrix = api.str2matrix(arch_str) + print('Compute the adjacency matrix of {:}'.format(arch_str)) + print(matrix) + info = api.simulate_train_eval(123, 'cifar10') + print('simulate_train_eval : {:}'.format(info)) + def test_issue_81_82(api): results = api.query_by_index(0, 'cifar10-valid', hp='12') diff --git a/exps/algos-v2/random_wo_share.py b/exps/algos-v2/random_wo_share.py new file mode 100644 index 0000000..774dfd4 --- /dev/null +++ b/exps/algos-v2/random_wo_share.py @@ -0,0 +1,91 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # +############################################################################## +# Random Search for Hyper-Parameter Optimization, JMLR 2012 ################## +############################################################################## +# python ./exps/algos-v2/random_wo_share.py --dataset cifar10 --search_space tss +############################################################################## +import os, sys, time, glob, random, argparse +import numpy as np, collections +from copy import deepcopy +import torch +import torch.nn as nn +from pathlib import Path +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import load_config, dict2config, configure2str +from datasets import get_datasets, SearchDataset +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler +from utils import get_model_infos, obtain_accuracy +from log_utils import AverageMeter, time_string, convert_secs2time +from models import get_search_spaces +from nas_201_api import NASBench201API, NASBench301API +from .regularized_ea import random_topology_func, random_size_func + + +def main(xargs, api): + torch.set_num_threads(4) + prepare_seed(xargs.rand_seed) + logger = prepare_logger(args) + + search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') + if xargs.search_space == 'tss': + random_arch = random_topology_func(search_space) + else: + random_arch = random_size_func(search_space) + + x_start_time = time.time() + logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) + best_arch, best_acc, total_time_cost, history = None, -1, [], [] + while total_time_cost[-1] < xargs.time_budget: + arch = random_arch() + accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, '12') + total_time_cost.append(total_cost) + history.append(arch) + if best_arch is None or best_acc < accuracy: + best_acc, best_arch = accuracy, arch + logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy)) + logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time)) + + info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90') + logger.log('{:}'.format(info)) + logger.log('-'*100) + logger.close() + return logger.log_dir, total_time_cost, history + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Random NAS") + parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') + parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') + + parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') + parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') + # log + parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') + parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') + args = parser.parse_args() + + if args.search_space == 'tss': + api = NASBench201API(verbose=False) + elif args.search_space == 'sss': + api = NASBench301API(verbose=False) + else: + raise ValueError('Invalid search space : {:}'.format(args.search_space)) + + args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'RANDOM') + print('save-dir : {:}'.format(args.save_dir)) + + if args.rand_seed < 0: + save_dir, all_info = None, {} + for i in range(args.loops_if_rand): + print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand)) + args.rand_seed = random.randint(1, 100000) + save_dir, all_archs, all_total_times = main(args, api) + all_info[i] = {'all_archs': all_archs, + 'all_total_times': all_total_times} + save_path = save_dir / 'results.pth' + print('save into {:}'.format(save_path)) + torch.save(all_info, save_path) + else: + main(args, api) diff --git a/exps/algos-v2/REA.py b/exps/algos-v2/regularized_ea.py similarity index 87% rename from exps/algos-v2/REA.py rename to exps/algos-v2/regularized_ea.py index 7410fa3..4e0a3bd 100644 --- a/exps/algos-v2/REA.py +++ b/exps/algos-v2/regularized_ea.py @@ -3,12 +3,12 @@ ################################################################## # Regularized Evolution for Image Classifier Architecture Search # ################################################################## -# python ./exps/algos-v2/REA.py --dataset cifar10 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 -# python ./exps/algos-v2/REA.py --dataset cifar100 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 -# python ./exps/algos-v2/REA.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 -# python ./exps/algos-v2/REA.py --dataset cifar10 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 -# python ./exps/algos-v2/REA.py --dataset cifar100 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 -# python ./exps/algos-v2/REA.py --dataset ImageNet16-120 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 +# python ./exps/algos-v2/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 +# python ./exps/algos-v2/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 +# python ./exps/algos-v2/regularized_ea.py --dataset ImageNet16-120 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 +# python ./exps/algos-v2/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 +# python ./exps/algos-v2/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 +# python ./exps/algos-v2/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1 ################################################################## import os, sys, time, glob, random, argparse import numpy as np, collections @@ -160,7 +160,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran while len(population) < population_size: model = Model() model.arch = random_arch() - model.accuracy, time_cost, total_cost = api.simulate_train_eval(model.arch, dataset, '12') + model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12') # Append the info population.append(model) history.append(model) @@ -183,7 +183,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran # Create the child model and store it. child = Model() child.arch = mutate_arch(parent.arch) - child.accuracy, time_cost, total_cost = api.simulate_train_eval(model.arch, dataset, '12') + child.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, '12') # Append the info population.append(child) history.append(child) @@ -195,11 +195,7 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran def main(xargs, api): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads(xargs.workers) + torch.set_num_threads(4) prepare_seed(xargs.rand_seed) logger = prepare_logger(args) @@ -235,12 +231,11 @@ if __name__ == '__main__': parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.') parser.add_argument('--ea_population', type=int, help='The population size in EA.') parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.') - parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).') - parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') + parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') + parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') - parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') + parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') args = parser.parse_args() if args.search_space == 'tss': diff --git a/exps/algos-v2/reinforce.py b/exps/algos-v2/reinforce.py index c81708c..400f1ef 100644 --- a/exps/algos-v2/reinforce.py +++ b/exps/algos-v2/reinforce.py @@ -3,12 +3,12 @@ ##################################################################################################### # modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py # ##################################################################################################### -# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space tss --time_budget 12000 --learning_rate 0.001 -# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space tss --time_budget 12000 --learning_rate 0.001 -# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --learning_rate 0.001 -# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space sss --time_budget 12000 --learning_rate 0.001 -# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space sss --time_budget 12000 --learning_rate 0.001 -# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space sss --time_budget 12000 --learning_rate 0.001 +# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.001 +# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.001 +# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.001 +# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.001 +# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.001 +# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.001 ##################################################################################################### import os, sys, time, glob, random, argparse import numpy as np, collections @@ -120,15 +120,10 @@ def select_action(policy): def main(xargs, api): - assert torch.cuda.is_available(), 'CUDA is not available.' - torch.backends.cudnn.enabled = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - torch.set_num_threads(xargs.workers) + torch.set_num_threads(4) prepare_seed(xargs.rand_seed) logger = prepare_logger(args) - search_space = get_search_spaces(xargs.search_space, 'nas-bench-301') if xargs.search_space == 'tss': policy = PolicyTopology(search_space) @@ -144,6 +139,7 @@ def main(xargs, api): # nas dataset load logger.log('{:} use api : {:}'.format(time_string(), api)) + api.reset_time() # REINFORCE x_start_time = time.time() @@ -153,7 +149,7 @@ def main(xargs, api): start_time = time.time() log_prob, action = select_action( policy ) arch = policy.generate_arch( action ) - reward, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, '12') + reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, '12') trace.append((reward, arch)) total_costs.append(current_total_cost) @@ -177,7 +173,7 @@ def main(xargs, api): logger.log('-'*100) logger.close() - return logger.log_dir, [api.query_index_by_arch(x[0]) for x in trace], total_costs + return logger.log_dir, [api.query_index_by_arch(x[1]) for x in trace], total_costs if __name__ == '__main__': @@ -186,15 +182,14 @@ if __name__ == '__main__': parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.') - parser.add_argument('--EMA_momentum', type=float, default=0.9, help='The momentum value for EMA.') - parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).') - parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') + parser.add_argument('--EMA_momentum', type=float, default=0.9, help='The momentum value for EMA.') + parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') + parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') # log - parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') + parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') args = parser.parse_args() if args.search_space == 'tss': diff --git a/exps/algos-v2/run-all.sh b/exps/algos-v2/run-all.sh new file mode 100644 index 0000000..3f2f01d --- /dev/null +++ b/exps/algos-v2/run-all.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# bash ./exps/algos-v2/run-all.sh +echo script name: $0 +echo $# arguments + +datasets="cifar10 cifar100 ImageNet16-120" +search_spaces="tss sss" + + +for dataset in ${datasets} +do + for search_space in ${search_spaces} + do + python ./exps/algos-v2/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.001 + python ./exps/algos-v2/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 + done +done diff --git a/exps/algos/RANDOM.py b/exps/algos/RANDOM.py index 58af886..e38bf60 100644 --- a/exps/algos/RANDOM.py +++ b/exps/algos/RANDOM.py @@ -84,7 +84,7 @@ def main(xargs, nas_bench): if __name__ == '__main__': - parser = argparse.ArgumentParser("Regularized Evolution Algorithm") + parser = argparse.ArgumentParser("Random NAS") parser.add_argument('--data_path', type=str, help='Path to dataset') parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') # channels and number-of-cells diff --git a/exps/experimental/vis-bench-algos.py b/exps/experimental/vis-bench-algos.py new file mode 100644 index 0000000..f0a4b1b --- /dev/null +++ b/exps/experimental/vis-bench-algos.py @@ -0,0 +1,107 @@ +############################################################### +# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) # +############################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 # +############################################################### +# Usage: python exps/experimental/vis-bench-algos.py +############################################################### +import os, sys, time, torch, argparse +import numpy as np +from typing import List, Text, Dict, Any +from shutil import copyfile +from collections import defaultdict, OrderedDict +from copy import deepcopy +from pathlib import Path +import matplotlib +import seaborn as sns +matplotlib.use('agg') +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker + +lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) +from config_utils import dict2config, load_config +from nas_201_api import NASBench201API, NASBench301API +from log_utils import time_string + + +def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): + ss_dir = '{:}-{:}'.format(root_dir, search_space) + alg2name, alg2path = OrderedDict(), OrderedDict() + alg2name['REA'] = 'R-EA-SS3' + alg2name['REINFORCE'] = 'REINFORCE-0.001' + for alg, name in alg2name.items(): + alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth') + assert os.path.isfile(alg2path[alg]) + alg2data = OrderedDict() + for alg, path in alg2path.items(): + data = torch.load(path) + for index, info in data.items(): + info['time_w_arch'] = [(x, y) for x, y in zip(info['all_total_times'], info['all_archs'])] + for j, arch in enumerate(info['all_archs']): + assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j) + alg2data[alg] = data + return alg2data + + +def query_performance(api, data, dataset, ticket): + results, is_301 = [], isinstance(api, NASBench301API) + for i, info in data.items(): + time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0]-ticket)) + time_a, arch_a = time_w_arch[0] + time_b, arch_b = time_w_arch[1] + info_a = api.get_more_info(arch_a, dataset=dataset, hp=90 if is_301 else 200, is_random=False) + info_b = api.get_more_info(arch_b, dataset=dataset, hp=90 if is_301 else 200, is_random=False) + accuracy_a, accuracy_b = info_a['test-accuracy'], info_b['test-accuracy'] + interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b + results.append(interplate) + return sum(results) / len(results) + + +def visualize_curve(api, vis_save_dir, search_space, max_time): + vis_save_dir = vis_save_dir.resolve() + vis_save_dir.mkdir(parents=True, exist_ok=True) + + dpi, width, height = 250, 4700, 1500 + figsize = width / float(dpi), height / float(dpi) + LabelSize, LegendFontsize = 14, 14 + + def sub_plot_fn(ax, dataset): + alg2data = fetch_data(search_space=search_space, dataset=dataset) + alg2accuracies = OrderedDict() + time_tickets = [float(i) / 100 * max_time for i in range(100)] + colors = ['b', 'g', 'c', 'm', 'y'] + for idx, (alg, data) in enumerate(alg2data.items()): + print('plot alg : {:}'.format(alg)) + accuracies = [] + for ticket in time_tickets: + accuracy = query_performance(api, data, dataset, ticket) + accuracies.append(accuracy) + alg2accuracies[alg] = accuracies + ax.plot(time_tickets, accuracies, c=colors[idx], label='{:}'.format(alg)) + ax.legend(loc=4, fontsize=LegendFontsize) + + fig, axs = plt.subplots(1, 3, figsize=figsize) + datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] + for dataset, ax in zip(datasets, axs): + sub_plot_fn(ax, dataset) + print('sub-plot {:} on {:} done.'.format(dataset, search_space)) + save_path = (vis_save_dir / '{:}-curve.png'.format(search_space)).resolve() + fig.savefig(save_path, dpi=dpi, bbox_inches='tight', format='png') + print ('{:} save into {:}'.format(time_string(), save_path)) + plt.close('all') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench/nas-algos', help='Folder to save checkpoints and log.') + parser.add_argument('--max_time', type=float, default=20000, help='The maximum time budget.') + args = parser.parse_args() + + save_dir = Path(args.save_dir) + + api201 = NASBench201API(verbose=False) + visualize_curve(api201, save_dir, 'tss', args.max_time) + api301 = NASBench301API(verbose=False) + visualize_curve(api301, save_dir, 'sss', args.max_time) + diff --git a/exps/NAS-Bench-201/test-nas-api-vis.py b/exps/experimental/visualize-nas-bench-x.py similarity index 96% rename from exps/NAS-Bench-201/test-nas-api-vis.py rename to exps/experimental/visualize-nas-bench-x.py index 02668fe..e3714a7 100644 --- a/exps/NAS-Bench-201/test-nas-api-vis.py +++ b/exps/experimental/visualize-nas-bench-x.py @@ -3,7 +3,7 @@ ############################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 # ############################################################### -# Usage: python exps/NAS-Bench-201/test-nas-api-vis.py +# Usage: python exps/experimental/visualize-nas-bench-x.py ############################################################### import os, sys, time, torch, argparse import numpy as np @@ -384,24 +384,25 @@ def visualize_all_rank_info(api, vis_save_dir, indicator): if __name__ == '__main__': parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.') - parser.add_argument('--check_N', type=int, default=32768, help='For safety.') + parser.add_argument('--save_dir', type=str, default='output/vis-nas-bench', help='Folder to save checkpoints and log.') # use for train the model args = parser.parse_args() + to_save_dir = Path(args.save_dir) + datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] api201 = NASBench201API(None, verbose=True) for xdata in datasets: - visualize_tss_info(api201, xdata, Path('output/vis-nas-bench')) + visualize_tss_info(api201, xdata, to_save_dir) api301 = NASBench301API(None, verbose=True) for xdata in datasets: - visualize_sss_info(api301, xdata, Path('output/vis-nas-bench')) + visualize_sss_info(api301, xdata, to_save_dir) - visualize_info(None, Path('output/vis-nas-bench/'), 'tss') - visualize_info(None, Path('output/vis-nas-bench/'), 'sss') - visualize_rank_info(None, Path('output/vis-nas-bench/'), 'tss') - visualize_rank_info(None, Path('output/vis-nas-bench/'), 'sss') + visualize_info(None, to_save_dir, 'tss') + visualize_info(None, to_save_dir, 'sss') + visualize_rank_info(None, to_save_dir, 'tss') + visualize_rank_info(None, to_save_dir, 'sss') - visualize_all_rank_info(None, Path('output/vis-nas-bench/'), 'tss') - visualize_all_rank_info(None, Path('output/vis-nas-bench/'), 'sss') + visualize_all_rank_info(None, to_save_dir, 'tss') + visualize_all_rank_info(None, to_save_dir, 'sss') diff --git a/lib/nas_201_api/api_201.py b/lib/nas_201_api/api_201.py index 454c49a..49d9a68 100644 --- a/lib/nas_201_api/api_201.py +++ b/lib/nas_201_api/api_201.py @@ -141,9 +141,12 @@ class NASBench201API(NASBenchMetaAPI): # `is_random` # When is_random=True, the performance of a random architecture will be returned # When is_random=False, the performanceo of all trials will be averaged. - def get_more_info(self, index: int, dataset, iepoch=None, hp='12', is_random=True): + def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True): if self.verbose: print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random)) + index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object + if index not in self.arch2infos_dict: + raise ValueError('Did not find {:} from arch2infos_dict.'.format(index)) archresult = self.arch2infos_dict[index][str(hp)] # if randomly select one trial, select the seed at first if isinstance(is_random, bool) and is_random: diff --git a/lib/nas_201_api/api_301.py b/lib/nas_201_api/api_301.py index a349056..8ac77f8 100644 --- a/lib/nas_201_api/api_301.py +++ b/lib/nas_201_api/api_301.py @@ -131,7 +131,7 @@ class NASBench301API(NASBenchMetaAPI): print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp)) return self._query_info_str_by_arch(arch, hp, print_information) - def get_more_info(self, index: int, dataset: Text, iepoch=None, hp='12', is_random=True): + def get_more_info(self, index, dataset: Text, iepoch=None, hp='12', is_random=True): """This function will return the metric for the `index`-th architecture `dataset` indicates the dataset: 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set @@ -151,6 +151,9 @@ class NASBench301API(NASBenchMetaAPI): """ if self.verbose: print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random)) + index = self.query_index_by_arch(index) # To avoid the input is a string or an instance of a arch object + if index not in self.arch2infos_dict: + raise ValueError('Did not find {:} from arch2infos_dict.'.format(index)) archresult = self.arch2infos_dict[index][str(hp)] # if randomly select one trial, select the seed at first if isinstance(is_random, bool) and is_random: diff --git a/lib/nas_201_api/api_utils.py b/lib/nas_201_api/api_utils.py index 53199ae..a8383d2 100644 --- a/lib/nas_201_api/api_utils.py +++ b/lib/nas_201_api/api_utils.py @@ -68,7 +68,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): def reset_time(self): self._used_time = 0 - def simulate_train_eval(self, arch, dataset, hp='12'): + def simulate_train_eval(self, arch, dataset, hp='12', account_time=True): index = self.query_index_by_arch(arch) all_names = ('cifar10', 'cifar100', 'ImageNet16-120') assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names) @@ -77,8 +77,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): else: info = self.get_more_info(index, dataset, iepoch=None, hp=hp, is_random=True) valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] - self._used_time += time_cost - return valid_acc, time_cost, self._used_time + latency = self.get_latency(index, dataset) + if account_time: + self._used_time += time_cost + return valid_acc, latency, time_cost, self._used_time def random(self): """Return a random index of all architectures.""" diff --git a/lib/utils/nas_utils.py b/lib/utils/nas_utils.py index c701935..1b1a44d 100644 --- a/lib/utils/nas_utils.py +++ b/lib/utils/nas_utils.py @@ -8,7 +8,9 @@ import torch.nn as nn from models import CellStructure from log_utils import time_string + def evaluate_one_shot(model, xloader, api, cal_mode, seed=111): + print ('This is an old version of codes to use NAS-Bench-API, and should be modified to align with the new version. Please contact me for more details if you use this function.') weights = deepcopy(model.state_dict()) model.train(cal_mode) with torch.no_grad():