Update for Rebuttal

This commit is contained in:
D-X-Y 2020-12-01 12:34:00 +08:00
parent 29428bf5a3
commit 8afb62ad2e
5 changed files with 96 additions and 6 deletions

View File

@ -26,6 +26,27 @@ from nats_bench import create
from log_utils import time_string
def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
    """Load the saved search results of every compared algorithm.

    The results of one algorithm are read from
    ``<root_dir>-<search_space>/<dataset>/<run name>/results.pth``.

    Args:
        root_dir: prefix of the directory holding the search outputs.
        search_space: name of the search space; appended to ``root_dir``.
        dataset: name of the sub-directory to read. It is required; the
            ``None`` default is only kept so the signature stays unchanged.

    Returns:
        An ``OrderedDict`` mapping each algorithm key ('REA', 'REINFORCE',
        'RANDOM', 'BOHB') to its loaded results. Every per-trial info dict
        gains a 'time_w_arch' entry: the list of pairs built from
        'all_total_times' and 'all_archs'.

    Raises:
        ValueError: if ``dataset`` is ``None`` or a trial contains the
            placeholder architecture ``-1``.
        FileNotFoundError: if a ``results.pth`` file is missing.
    """
    if dataset is None:
        # os.path.join would otherwise fail with an unhelpful TypeError.
        raise ValueError('dataset must be specified to locate the results')
    ss_dir = '{:}-{:}'.format(root_dir, search_space)
    alg2name = OrderedDict([
        ('REA', 'R-EA-SS3'),
        ('REINFORCE', 'REINFORCE-0.01'),
        ('RANDOM', 'RANDOM'),
        ('BOHB', 'BOHB'),
    ])
    # Resolve and check every path first so that a missing file is reported
    # before any (potentially slow) loading starts.
    alg2path = OrderedDict()
    for alg, name in alg2name.items():
        path = os.path.join(ss_dir, dataset, name, 'results.pth')
        if not os.path.isfile(path):
            raise FileNotFoundError('invalid path : {:}'.format(path))
        alg2path[alg] = path
    alg2data = OrderedDict()
    for alg, path in alg2path.items():
        # NOTE(security): torch.load unpickles the file; only load result
        # files that come from a trusted source.
        data = torch.load(path)
        for index, info in data.items():
            info['time_w_arch'] = list(zip(info['all_total_times'], info['all_archs']))
            for j, arch in enumerate(info['all_archs']):
                if arch == -1:
                    raise ValueError(
                        'invalid arch from {:} {:} {:} ({:}, {:})'.format(
                            alg, search_space, dataset, index, j))
        alg2data[alg] = data
    return alg2data
def get_valid_test_acc(api, arch, dataset):
is_size_space = api.search_space_name == 'size'
if dataset == 'cifar10':
@ -52,7 +73,6 @@ def show_valid_test(api, arch):
def find_best_valid(api, dataset):
all_valid_accs, all_test_accs = [], []
for index, arch in enumerate(api):
# import pdb; pdb.set_trace()
valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset)
all_valid_accs.append((index, valid_acc))
all_test_accs.append((index, test_acc))
@ -68,8 +88,62 @@ def find_best_valid(api, dataset):
print('using test ::: {:}'.format(perf_str))
def interplate_fn(xpair1, xpair2, x):
    """Evaluate at ``x`` the straight line through two ``(x, y)`` points.

    Args:
        xpair1: the first ``(x1, y1)`` point.
        xpair2: the second ``(x2, y2)`` point.
        x: the position at which the line is evaluated; it may lie outside
            ``[x1, x2]``, in which case the result is an extrapolation.

    Returns:
        The value of the line at ``x``. When both points share the same
        x-coordinate no line is defined, so the mean of ``y1`` and ``y2`` is
        returned instead of dividing by zero.
    """
    (x1, y1) = xpair1
    (x2, y2) = xpair2
    span = x2 - x1
    if span == 0:
        # Degenerate anchors: fall back to the average of the two values.
        return (y1 + y2) / 2
    return (x2 - x) / span * y1 + (x - x1) / span * y2
def query_performance(api, info, dataset, ticket):
    """Estimate the validation/test accuracy of one trial at time ``ticket``.

    The two records of ``info['time_w_arch']`` whose time is closest to
    ``ticket`` are looked up with ``get_valid_test_acc`` and their accuracies
    are combined with ``interplate_fn``.

    Args:
        api: the benchmark API object forwarded to ``get_valid_test_acc``.
        info: a trial record holding a 'time_w_arch' list of
            ``(time, arch)`` pairs.
        dataset: the dataset name forwarded to ``get_valid_test_acc``.
        ticket: the time at which the accuracy is queried.

    Returns:
        A ``(validation accuracy, test accuracy)`` tuple. With a single
        record its accuracies are returned as they are.

    Raises:
        ValueError: if ``info['time_w_arch']`` is empty.
    """
    # sorted() builds a new list and nothing below mutates ``info``, so no
    # defensive copy of the (possibly large) record is required.
    time_w_arch = sorted(info['time_w_arch'], key=lambda x: abs(x[0] - ticket))
    if not time_w_arch:
        raise ValueError('no (time, arch) record is available for the query')
    time_a, arch_a = time_w_arch[0]
    v_acc_a, t_acc_a, _ = get_valid_test_acc(api, arch_a, dataset)
    if len(time_w_arch) == 1:
        # Only one record: there is no second point to interpolate with.
        return v_acc_a, t_acc_a
    time_b, arch_b = time_w_arch[1]
    v_acc_b, t_acc_b, _ = get_valid_test_acc(api, arch_b, dataset)
    v_acc = interplate_fn((time_a, v_acc_a), (time_b, v_acc_b), ticket)
    t_acc = interplate_fn((time_a, t_acc_a), (time_b, t_acc_b), ticket)
    return v_acc, t_acc
def show_multi_trial(search_space):
    """Print the mean and std accuracy of every algorithm on a search space.

    For each ``<dataset>-T<time budget>`` entry of the chosen search space,
    the results returned by ``fetch_data`` are queried at the time budget
    with ``query_performance`` and summarised per algorithm in a
    LaTeX-friendly "mean plus-minus std" form.

    Args:
        search_space: either 'tss' or 'sss'.

    Raises:
        ValueError: if ``search_space`` is neither 'tss' nor 'sss'.
    """
    api = create(None, search_space, fast_mode=True, verbose=False)

    def show(dataset):
        """Print one summary line per algorithm for ``dataset``."""
        print('show {:} on {:} done.'.format(dataset, search_space))
        # ``dataset`` looks like '<name>-T<seconds>'.
        xdataset, max_time = dataset.split('-T')
        alg2data = fetch_data(search_space=search_space, dataset=dataset)
        ticket = float(max_time)
        for alg, data in alg2data.items():
            valid_accs, test_accs = [], []
            for x in data.values():
                v_acc, t_acc = query_performance(api, x, xdataset, ticket)
                valid_accs.append(v_acc)
                test_accs.append(t_acc)
            # The backslash is escaped explicitly: '\p' is not a valid escape
            # sequence and triggers a warning on recent Python versions.
            valid_str = '{:.2f}$\\pm${:.2f}'.format(np.mean(valid_accs), np.std(valid_accs))
            test_str = '{:.2f}$\\pm${:.2f}'.format(np.mean(test_accs), np.std(test_accs))
            print('{:} plot alg : {:10s} | validation = {:} | test = {:}'.format(time_string(), alg, valid_str, test_str))

    if search_space == 'tss':
        datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T120000']
    elif search_space == 'sss':
        datasets = ['cifar10-T20000', 'cifar100-T40000', 'ImageNet16-120-T60000']
    else:
        raise ValueError('Unknown search space: {:}'.format(search_space))
    for dataset in datasets:
        show(dataset)
    print('{:} complete show multi-trial results.\n'.format(time_string()))
if __name__ == '__main__':
show_multi_trial('tss')
show_multi_trial('sss')
api_tss = create(None, 'tss', fast_mode=False, verbose=False)
resnet = '|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|'
resnet_index = api_tss.query_index_by_arch(resnet)

View File

@ -95,7 +95,7 @@ def mutate_size_func(info):
return mutate_size_func
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, dataset):
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset):
"""Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
@ -119,7 +119,10 @@ def regularized_evolution(cycles, population_size, sample_size, time_budget, ran
while len(population) < population_size:
model = Model()
model.arch = random_arch()
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
if use_proxy:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
else:
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp=api.full_train_epochs)
# Append the info
population.append(model)
history.append((model.accuracy, model.arch))
@ -171,7 +174,11 @@ def main(xargs, api):
x_start_time = time.time()
logger.log('{:} use api : {:}'.format(time_string(), api))
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset)
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles,
xargs.ea_population,
xargs.ea_sample_size,
xargs.time_budget,
random_arch, mutate_arch, api, xargs.use_proxy > 0, xargs.dataset)
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time))
best_arch = max(history, key=lambda x: x[0])[1]
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
@ -187,11 +194,13 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
# channels and number-of-cells
# hyperparameters for REA
parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
parser.add_argument('--ea_population', type=int, help='The population size in EA.')
parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--use_proxy', type=int, default=1, help='Whether to use the proxy (H0) task or not.')
#
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# log
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
@ -201,7 +210,8 @@ if __name__ == '__main__':
api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
'{:}-T{:}'.format(args.dataset, args.time_budget), 'R-EA-SS{:}'.format(args.ea_sample_size))
'{:}-T{:}{:}'.format(args.dataset, args.time_budget, '' if args.use_proxy > 0 else '-FULL'),
'R-EA-SS{:}'.format(args.ea_sample_size))
print('save-dir : {:}'.format(args.save_dir))
print('xargs : {:}'.format(args))

View File

@ -83,6 +83,7 @@ class NATSsize(NASBenchMetaAPI):
self._search_space_name = 'size'
self._fast_mode = fast_mode
self._archive_dir = None
self._full_train_epochs = 90
self.reset_time()
if file_path_or_dict is None:
if self._fast_mode:

View File

@ -83,6 +83,7 @@ class NATStopology(NASBenchMetaAPI):
self._search_space_name = 'topology'
self._fast_mode = fast_mode
self._archive_dir = None
self._full_train_epochs = 200
self.reset_time()
if file_path_or_dict is None:
if self._fast_mode:

View File

@ -190,6 +190,10 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
def archive_dir(self):
return self._archive_dir
@property
def full_train_epochs(self):
    """Return the value stored in ``self._full_train_epochs``.

    In the visible hunks the attribute is set by the subclasses' constructors
    (90 for the size space, 200 for the topology space).
    """
    return self._full_train_epochs
def reset_archive_dir(self, archive_dir):
self._archive_dir = archive_dir