From 8afb62ad2ea1102db10cd3ca9d5e67069770709b Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Tue, 1 Dec 2020 12:34:00 +0800 Subject: [PATCH] Update for Rebuttal --- exps/NATS-Bench/draw-table.py | 76 ++++++++++++++++++++++++++++++- exps/NATS-algos/regularized_ea.py | 20 ++++++-- lib/nats_bench/api_size.py | 1 + lib/nats_bench/api_topology.py | 1 + lib/nats_bench/api_utils.py | 4 ++ 5 files changed, 96 insertions(+), 6 deletions(-) diff --git a/exps/NATS-Bench/draw-table.py b/exps/NATS-Bench/draw-table.py index 8105b6e..b379dc6 100644 --- a/exps/NATS-Bench/draw-table.py +++ b/exps/NATS-Bench/draw-table.py @@ -26,6 +26,27 @@ from nats_bench import create from log_utils import time_string +def fetch_data(root_dir='./output/search', search_space='tss', dataset=None): + ss_dir = '{:}-{:}'.format(root_dir, search_space) + alg2name, alg2path = OrderedDict(), OrderedDict() + alg2name['REA'] = 'R-EA-SS3' + alg2name['REINFORCE'] = 'REINFORCE-0.01' + alg2name['RANDOM'] = 'RANDOM' + alg2name['BOHB'] = 'BOHB' + for alg, name in alg2name.items(): + alg2path[alg] = os.path.join(ss_dir, dataset, name, 'results.pth') + assert os.path.isfile(alg2path[alg]), 'invalid path : {:}'.format(alg2path[alg]) + alg2data = OrderedDict() + for alg, path in alg2path.items(): + data = torch.load(path) + for index, info in data.items(): + info['time_w_arch'] = [(x, y) for x, y in zip(info['all_total_times'], info['all_archs'])] + for j, arch in enumerate(info['all_archs']): + assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j) + alg2data[alg] = data + return alg2data + + def get_valid_test_acc(api, arch, dataset): is_size_space = api.search_space_name == 'size' if dataset == 'cifar10': @@ -52,7 +73,6 @@ def show_valid_test(api, arch): def find_best_valid(api, dataset): all_valid_accs, all_test_accs = [], [] for index, arch in enumerate(api): - # import pdb; pdb.set_trace() valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset) all_valid_accs.append((index, valid_acc)) all_test_accs.append((index, test_acc)) @@ -68,8 +88,62 @@ def find_best_valid(api, dataset): print('using test ::: {:}'.format(perf_str)) +def interplate_fn(xpair1, xpair2, x): + (x1, y1) = xpair1 + (x2, y2) = xpair2 + return (x2 - x) / (x2 - x1) * y1 + \ + (x - x1) / (x2 - x1) * y2 + +def query_performance(api, info, dataset, ticket): + info = deepcopy(info) + results, is_size_space = [], api.search_space_name == 'size' + time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0]-ticket)) + time_a, arch_a = time_w_arch[0] + time_b, arch_b = time_w_arch[1] + + v_acc_a, t_acc_a, _ = get_valid_test_acc(api, arch_a, dataset) + v_acc_b, t_acc_b, _ = get_valid_test_acc(api, arch_b, dataset) + v_acc = interplate_fn((time_a, v_acc_a), (time_b, v_acc_b), ticket) + t_acc = interplate_fn((time_a, t_acc_a), (time_b, t_acc_b), ticket) + # if True: + # interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b + # results.append(interplate) + # return sum(results) / len(results) + return v_acc, t_acc + + +def show_multi_trial(search_space): + api = create(None, search_space, fast_mode=True, verbose=False) + def show(dataset): + print('show {:} on {:} done.'.format(dataset, search_space)) + xdataset, max_time = dataset.split('-T') + alg2data = fetch_data(search_space=search_space, dataset=dataset) + for idx, (alg, data) in enumerate(alg2data.items()): + + valid_accs, test_accs = [], [] + for _, x in data.items(): + v_acc, t_acc = query_performance(api, x, xdataset, float(max_time)) + valid_accs.append(v_acc) + test_accs.append(t_acc) + valid_str = '{:.2f}$\pm${:.2f}'.format(np.mean(valid_accs), np.std(valid_accs)) + test_str = '{:.2f}$\pm${:.2f}'.format(np.mean(test_accs), np.std(test_accs)) + print('{:} plot alg : {:10s} | validation = {:} | test = {:}'.format(time_string(), alg, valid_str, test_str)) + if search_space == 'tss': + datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T120000'] + elif search_space == 'sss': + datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T60000'] + else: + raise ValueError('Unknown search space: {:}'.format(search_space)) + for dataset in datasets: + show(dataset) + print('{:} complete show multi-trial results.\n'.format(time_string())) + + if __name__ == '__main__': + show_multi_trial('tss') + show_multi_trial('sss') + api_tss = create(None, 'tss', fast_mode=False, verbose=False) resnet = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|' resnet_index = api_tss.query_index_by_arch(resnet) diff --git a/exps/NATS-algos/regularized_ea.py b/exps/NATS-algos/regularized_ea.py index 2af14f2..14ba2fd 100644 --- a/exps/NATS-algos/regularized_ea.py +++ b/exps/NATS-algos/regularized_ea.py @@ -95,7 +95,7 @@ def mutate_size_func(info): return mutate_size_func -def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, dataset): +def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset): """Algorithm for regularized evolution (i.e. aging evolution). Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image @@ -119,7 +119,10 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran while len(population) < population_size: model = Model() model.arch = random_arch() - model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12') + if use_proxy: + model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12') + else: + model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp=api.full_train_epochs) # Append the info population.append(model) history.append((model.accuracy, model.arch)) @@ -171,7 +174,11 @@ def main(xargs, api): x_start_time = time.time() logger.log('{:} use api : {:}'.format(time_string(), api)) logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) - history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset) + history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, + xargs.ea_population, + xargs.ea_sample_size, + xargs.time_budget, + random_arch, mutate_arch, api, xargs.use_proxy > 0, xargs.dataset) logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time)) best_arch = max(history, key=lambda x: x[0])[1] logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) @@ -187,11 +194,13 @@ if __name__ == '__main__': parser = argparse.ArgumentParser("Regularized Evolution Algorithm") parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') - # channels and number-of-cells + # hyperparameters for REA parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.') parser.add_argument('--ea_population', type=int, help='The population size in EA.') parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.') parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') + parser.add_argument('--use_proxy', type=int, default=1, help='Whether to use the proxy (H0) task or not.') + # parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') # log parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') @@ -201,7 +210,8 @@ if __name__ == '__main__': api = create(None, args.search_space, fast_mode=True, verbose=False) args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), - '{:}-T{:}'.format(args.dataset, args.time_budget), 'R-EA-SS{:}'.format(args.ea_sample_size)) + '{:}-T{:}{:}'.format(args.dataset, args.time_budget, '' if args.use_proxy > 0 else '-FULL'), + 'R-EA-SS{:}'.format(args.ea_sample_size)) print('save-dir : {:}'.format(args.save_dir)) print('xargs : {:}'.format(args)) diff --git a/lib/nats_bench/api_size.py b/lib/nats_bench/api_size.py index e7400fa..6eab753 100644 --- a/lib/nats_bench/api_size.py +++ b/lib/nats_bench/api_size.py @@ -83,6 +83,7 @@ class NATSsize(NASBenchMetaAPI): self._search_space_name = 'size' self._fast_mode = fast_mode self._archive_dir = None + self._full_train_epochs = 90 self.reset_time() if file_path_or_dict is None: if self._fast_mode: diff --git a/lib/nats_bench/api_topology.py b/lib/nats_bench/api_topology.py index 205c44a..f4211c4 100644 --- a/lib/nats_bench/api_topology.py +++ b/lib/nats_bench/api_topology.py @@ -83,6 +83,7 @@ class NATStopology(NASBenchMetaAPI): self._search_space_name = 'topology' self._fast_mode = fast_mode self._archive_dir = None + self._full_train_epochs = 200 self.reset_time() if file_path_or_dict is None: if self._fast_mode: diff --git a/lib/nats_bench/api_utils.py b/lib/nats_bench/api_utils.py index 8e63446..86a7a02 100644 --- a/lib/nats_bench/api_utils.py +++ b/lib/nats_bench/api_utils.py @@ -190,6 +190,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta): def archive_dir(self): return self._archive_dir + @property + def full_train_epochs(self): + return self._full_train_epochs + def reset_archive_dir(self, archive_dir): self._archive_dir = archive_dir