Update for Rebuttal

This commit is contained in:
D-X-Y 2020-12-01 12:34:00 +08:00
parent 29428bf5a3
commit 8afb62ad2e
5 changed files with 96 additions and 6 deletions

View File

@ -26,6 +26,27 @@ from nats_bench import create
from log_utils import time_string from log_utils import time_string
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
    """Load the saved search results of every compared algorithm.

    Args:
        root_dir: prefix of the results directory; the search-space name is
            appended to it as ``<root_dir>-<search_space>``.
        search_space: name of the search space, used for the directory name
            and in error messages.
        dataset: name of the sub-directory that holds the results, such as
            ``'cifar10-T20000'``. It has no usable default and must be given.

    Returns:
        An OrderedDict mapping each algorithm key (``'REA'``, ``'REINFORCE'``,
        ``'RANDOM'``, ``'BOHB'``) to the object loaded from its
        ``results.pth``. Every record in that object gets an additional
        ``'time_w_arch'`` entry: the list of ``(total_time, arch)`` pairs
        built from ``'all_total_times'`` and ``'all_archs'``.

    Raises:
        TypeError: if ``dataset`` is None.
        AssertionError: if a ``results.pth`` file is missing or a record
            contains the placeholder architecture ``-1``.
    """
    if dataset is None:
        # os.path.join would raise an opaque TypeError on the None component;
        # fail early with the same exception type and a readable message.
        raise TypeError('dataset must be a directory name such as cifar10-T20000, but got None')
    ss_dir = '{:}-{:}'.format(root_dir, search_space)
    # Algorithm key -> name of the directory its results were saved in.
    alg2name = OrderedDict([
        ('REA', 'R-EA-SS3'),
        ('REINFORCE', 'REINFORCE-0.01'),
        ('RANDOM', 'RANDOM'),
        ('BOHB', 'BOHB'),
    ])
    # Check that every file exists before loading any of them.
    alg2path = OrderedDict()
    for alg, name in alg2name.items():
        path = os.path.join(ss_dir, dataset, name, 'results.pth')
        assert os.path.isfile(path), 'invalid path : {:}'.format(path)
        alg2path[alg] = path
    alg2data = OrderedDict()
    for alg, path in alg2path.items():
        # NOTE(review): torch.load unpickles the file; only load trusted results.
        data = torch.load(path)
        for index, info in data.items():
            archs = info['all_archs']
            info['time_w_arch'] = list(zip(info['all_total_times'], archs))
            for j, arch in enumerate(archs):
                assert arch != -1, 'invalid arch from {:} {:} {:} ({:}, {:})'.format(alg, search_space, dataset, index, j)
        alg2data[alg] = data
    return alg2data
def get_valid_test_acc(api, arch, dataset): def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == 'size' is_size_space = api.search_space_name == 'size'
if dataset == 'cifar10': if dataset == 'cifar10':
@ -52,7 +73,6 @@ def show_valid_test(api, arch):
def find_best_valid(api, dataset): def find_best_valid(api, dataset):
all_valid_accs, all_test_accs = [], [] all_valid_accs, all_test_accs = [], []
for index, arch in enumerate(api): for index, arch in enumerate(api):
# import pdb; pdb.set_trace()
valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset) valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset)
all_valid_accs.append((index, valid_acc)) all_valid_accs.append((index, valid_acc))
all_test_accs.append((index, test_acc)) all_test_accs.append((index, test_acc))
@ -68,8 +88,62 @@ def find_best_valid(api, dataset):
print('using test ::: {:}'.format(perf_str)) print('using test ::: {:}'.format(perf_str))
def interplate_fn(xpair1, xpair2, x):
    """Linearly interpolate the y-value at ``x`` from two ``(x, y)`` points.

    The result lies on the straight line through both points, so an ``x``
    outside ``[x1, x2]`` is extrapolated along that line.

    Args:
        xpair1: first anchor point as ``(x1, y1)``.
        xpair2: second anchor point as ``(x2, y2)``.
        x: position at which the line is evaluated.

    Returns:
        The interpolated y-value. When both anchors share the same
        x-coordinate no line is defined, so the mean of the two y-values is
        returned instead of dividing by zero.
    """
    (x1, y1), (x2, y2) = xpair1, xpair2
    if x1 == x2:
        # Degenerate segment: avoid the ZeroDivisionError of the formula below.
        return (y1 + y2) / 2
    span = x2 - x1
    return (x2 - x) / span * y1 + (x - x1) / span * y2
def query_performance(api, info, dataset, ticket):
    """Estimate the validation and test accuracy of one trial at time ``ticket``.

    The two ``(time, arch)`` records whose time is closest to ``ticket`` are
    looked up with ``get_valid_test_acc`` and their accuracies are combined
    with ``interplate_fn``. Both records may lie on the same side of
    ``ticket``, in which case the value is extrapolated.

    Args:
        api: benchmark API object forwarded to ``get_valid_test_acc``.
        info: record of a single trial; only ``info['time_w_arch']`` is read.
        dataset: dataset name forwarded to ``get_valid_test_acc``.
        ticket: time at which the accuracy is queried.

    Returns:
        A ``(validation_accuracy, test_accuracy)`` tuple.

    Raises:
        IndexError: if ``info['time_w_arch']`` has fewer than two records.
    """
    # sorted() builds a new list and nothing below mutates `info`, so the
    # input does not need to be copied first.
    nearest = sorted(info['time_w_arch'], key=lambda pair: abs(pair[0] - ticket))
    (time_a, arch_a), (time_b, arch_b) = nearest[0], nearest[1]
    v_acc_a, t_acc_a, _ = get_valid_test_acc(api, arch_a, dataset)
    v_acc_b, t_acc_b, _ = get_valid_test_acc(api, arch_b, dataset)
    v_acc = interplate_fn((time_a, v_acc_a), (time_b, v_acc_b), ticket)
    t_acc = interplate_fn((time_a, t_acc_a), (time_b, t_acc_b), ticket)
    return v_acc, t_acc
def show_multi_trial(search_space):
    """Print the mean and standard deviation of every algorithm's accuracy.

    For each dataset/time-budget pair of the given search space, the saved
    results of all algorithms are fetched, the accuracy of every trial is
    queried at the time budget, and one summary line per algorithm is
    printed with the statistics joined by LaTeX plus-minus markup.

    Args:
        search_space: ``'tss'`` or ``'sss'``.

    Raises:
        ValueError: if ``search_space`` is neither ``'tss'`` nor ``'sss'``.
    """
    api = create(None, search_space, fast_mode=True, verbose=False)

    def show(dataset):
        """Print the summary lines for one ``<dataset>-T<budget>`` entry."""
        # NOTE(review): this message is printed before the work starts even
        # though it says "done"; kept as-is to leave the output unchanged.
        print('show {:} on {:} done.'.format(dataset, search_space))
        xdataset, max_time = dataset.split('-T')
        ticket = float(max_time)  # the time budget is the query point
        alg2data = fetch_data(search_space=search_space, dataset=dataset)
        for alg, data in alg2data.items():
            valid_accs, test_accs = [], []
            for x in data.values():
                v_acc, t_acc = query_performance(api, x, xdataset, ticket)
                valid_accs.append(v_acc)
                test_accs.append(t_acc)
            # Raw strings keep the backslash without relying on the invalid
            # escape sequence "\p"; the runtime text is unchanged.
            valid_str = r'{:.2f}$\pm${:.2f}'.format(np.mean(valid_accs), np.std(valid_accs))
            test_str = r'{:.2f}$\pm${:.2f}'.format(np.mean(test_accs), np.std(test_accs))
            print('{:} plot alg : {:10s} | validation = {:} | test = {:}'.format(time_string(), alg, valid_str, test_str))

    if search_space == 'tss':
        datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T120000']
    elif search_space == 'sss':
        datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T60000']
    else:
        raise ValueError('Unknown search space: {:}'.format(search_space))
    for dataset in datasets:
        show(dataset)
    print('{:} complete show multi-trial results.\n'.format(time_string()))
if __name__ == '__main__': if __name__ == '__main__':
show_multi_trial('tss')
show_multi_trial('sss')
api_tss = create(None, 'tss', fast_mode=False, verbose=False) api_tss = create(None, 'tss', fast_mode=False, verbose=False)
resnet = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|' resnet = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'
resnet_index = api_tss.query_index_by_arch(resnet) resnet_index = api_tss.query_index_by_arch(resnet)

View File

@ -95,7 +95,7 @@ def mutate_size_func(info):
return mutate_size_func return mutate_size_func
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, dataset): def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset):
"""Algorithm for regularized evolution (i.e. aging evolution). """Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
@ -119,7 +119,10 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
while len(population) < population_size: while len(population) < population_size:
model = Model() model = Model()
model.arch = random_arch() model.arch = random_arch()
if use_proxy:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12') model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
else:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp=api.full_train_epochs)
# Append the info # Append the info
population.append(model) population.append(model)
history.append((model.accuracy, model.arch)) history.append((model.accuracy, model.arch))
@ -171,7 +174,11 @@ def main(xargs, api):
x_start_time = time.time() x_start_time = time.time()
logger.log('{:} use api : {:}'.format(time_string(), api)) logger.log('{:} use api : {:}'.format(time_string(), api))
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget)) logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset) history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles,
xargs.ea_population,
xargs.ea_sample_size,
xargs.time_budget,
random_arch, mutate_arch, api, xargs.use_proxy > 0, xargs.dataset)
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time)) logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time))
best_arch = max(history, key=lambda x: x[0])[1] best_arch = max(history, key=lambda x: x[0])[1]
logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
@ -187,11 +194,13 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser("Regularized Evolution Algorithm") parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.') parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
# channels and number-of-cells # hyperparameters for REA
parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.') parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
parser.add_argument('--ea_population', type=int, help='The population size in EA.') parser.add_argument('--ea_population', type=int, help='The population size in EA.')
parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.') parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).') parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--use_proxy', type=int, default=1, help='Whether to use the proxy (H0) task or not.')
#
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.') parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# log # log
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.') parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
@ -201,7 +210,8 @@ if __name__ == '__main__':
api = create(None, args.search_space, fast_mode=True, verbose=False) api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
'{:}-T{:}'.format(args.dataset, args.time_budget), 'R-EA-SS{:}'.format(args.ea_sample_size)) '{:}-T{:}{:}'.format(args.dataset, args.time_budget, '' if args.use_proxy > 0 else '-FULL'),
'R-EA-SS{:}'.format(args.ea_sample_size))
print('save-dir : {:}'.format(args.save_dir)) print('save-dir : {:}'.format(args.save_dir))
print('xargs : {:}'.format(args)) print('xargs : {:}'.format(args))

View File

@ -83,6 +83,7 @@ class NATSsize(NASBenchMetaAPI):
self._search_space_name = 'size' self._search_space_name = 'size'
self._fast_mode = fast_mode self._fast_mode = fast_mode
self._archive_dir = None self._archive_dir = None
self._full_train_epochs = 90
self.reset_time() self.reset_time()
if file_path_or_dict is None: if file_path_or_dict is None:
if self._fast_mode: if self._fast_mode:

View File

@ -83,6 +83,7 @@ class NATStopology(NASBenchMetaAPI):
self._search_space_name = 'topology' self._search_space_name = 'topology'
self._fast_mode = fast_mode self._fast_mode = fast_mode
self._archive_dir = None self._archive_dir = None
self._full_train_epochs = 200
self.reset_time() self.reset_time()
if file_path_or_dict is None: if file_path_or_dict is None:
if self._fast_mode: if self._fast_mode:

View File

@ -190,6 +190,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def archive_dir(self): def archive_dir(self):
return self._archive_dir return self._archive_dir
@property
def full_train_epochs(self):
    """Return the value stored in ``self._full_train_epochs``.

    In the hunks visible here, ``NATSsize`` sets it to 90 and
    ``NATStopology`` sets it to 200, and ``regularized_evolution`` passes it
    as ``hp`` to ``simulate_train_eval`` when the proxy task is disabled.
    """
    return self._full_train_epochs
def reset_archive_dir(self, archive_dir): def reset_archive_dir(self, archive_dir):
self._archive_dir = archive_dir self._archive_dir = archive_dir