diff --git a/exps/algos-v2/R_EA.py b/exps/algos-v2/REA.py
similarity index 90%
rename from exps/algos-v2/R_EA.py
rename to exps/algos-v2/REA.py
index f2eb7bb..7410fa3 100644
--- a/exps/algos-v2/R_EA.py
+++ b/exps/algos-v2/REA.py
@@ -3,13 +3,13 @@
 ##################################################################
 # Regularized Evolution for Image Classifier Architecture Search #
 ##################################################################
-# python ./exps/algos-v2/R_EA.py --dataset cifar10 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
-# python ./exps/algos-v2/R_EA.py --dataset cifar100 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
-# python ./exps/algos-v2/R_EA.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
-# python ./exps/algos-v2/R_EA.py --dataset cifar10 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
-#
-#
-#
+# python ./exps/algos-v2/REA.py --dataset cifar10 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset cifar100 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset cifar10 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset cifar100 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset ImageNet16-120 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+##################################################################
 import os, sys, time, glob, random, argparse
 import numpy as np, collections
 from copy import deepcopy
@@ -236,12 +236,12 @@ if __name__ == '__main__':
   parser.add_argument('--ea_population', type=int, help='The population size in EA.')
   parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
   parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
+  parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
   # log
   parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
   parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
   parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
   args = parser.parse_args()
-  #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
 
   if args.search_space == 'tss':
     api = NASBench201API(verbose=False)
@@ -250,17 +250,19 @@ if __name__ == '__main__':
   else:
     raise ValueError('Invalid search space : {:}'.format(args.search_space))
 
-  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), 'R-EA-SS{:}'.format(args.ea_sample_size))
+  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'R-EA-SS{:}'.format(args.ea_sample_size))
   print('save-dir : {:}'.format(args.save_dir))
 
   if args.rand_seed < 0:
-    save_dir, all_info, num = None, {}, 500
-    for i in range(num):
-      print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
+    save_dir, all_info = None, {}
+    for i in range(args.loops_if_rand):
+      print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
       args.rand_seed = random.randint(1, 100000)
      save_dir, all_archs, all_total_times = main(args, api)
       all_info[i] = {'all_archs': all_archs,
                      'all_total_times': all_total_times}
-    torch.save(all_info, save_dir / 'results.pth')
+    save_path = save_dir / 'results.pth'
+    print('save into {:}'.format(save_path))
+    torch.save(all_info, save_path)
   else:
     main(args, api)
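For readers unfamiliar with the algorithm behind `REA.py`: regularized (aging) evolution keeps a fixed-size population, scores a small random tournament, mutates the tournament winner, and always retires the oldest member. The snippet below is a minimal, self-contained sketch of that loop over a toy bit-string objective, using the same defaults as the commands above (200 cycles, population 10, sample size 3); `score_fn`, `mutate_fn`, and `random_arch_fn` are illustrative placeholders, not functions from the repository, which instead scores candidates through the benchmark's `simulate_train_eval`.

```python
import collections, random

def regularized_evolution(score_fn, mutate_fn, random_arch_fn,
                          cycles=200, population_size=10, sample_size=3):
  """Aging-evolution loop (Real et al., 2019); a sketch, not the REA.py implementation."""
  population, history = collections.deque(), []
  while len(population) < population_size:         # seed with random architectures
    arch = random_arch_fn()
    population.append((arch, score_fn(arch)))
    history.append(population[-1])
  for _ in range(cycles):                          # tournament + mutation + aging
    samples = random.sample(list(population), sample_size)
    parent = max(samples, key=lambda pair: pair[1])
    child = mutate_fn(parent[0])
    population.append((child, score_fn(child)))
    history.append(population[-1])
    population.popleft()                           # always drop the oldest member
  return max(history, key=lambda pair: pair[1])

if __name__ == '__main__':
  best_arch, best_score = regularized_evolution(
    score_fn=sum,                                                  # toy objective
    mutate_fn=lambda a: [b ^ (random.random() < 0.1) for b in a],  # flip ~10% of the bits
    random_arch_fn=lambda: [random.randint(0, 1) for _ in range(8)])
  print('best arch={:} score={:}'.format(best_arch, best_score))
```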
diff --git a/exps/algos-v2/README.md b/exps/algos-v2/README.md
new file mode 100644
index 0000000..f1de907
--- /dev/null
+++ b/exps/algos-v2/README.md
@@ -0,0 +1 @@
+# Benchmarking NAS Algorithms
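Both `REA.py` above and the new `reinforce.py` below share the same evaluation protocol: with a negative `--rand_seed`, the search is repeated `--loops_if_rand` times and a dictionary keyed by run index, holding the visited architectures (`all_archs`) and the accumulated simulated costs (`all_total_times`), is saved to `results.pth`. A minimal sketch of reading such a file back; the path is a hypothetical example built from the default `--save_dir`, not a guaranteed location:

```python
import torch

# Hypothetical example path; the real location is built from --save_dir,
# the search space ('tss'/'sss'), the dataset, and an algorithm-specific suffix.
results = torch.load('./output/search-tss/cifar10/R-EA-SS3/results.pth')

for run_index, info in sorted(results.items()):
  archs = info['all_archs']        # architectures visited during this run
  times = info['all_total_times']  # accumulated simulated search cost (seconds)
  print('run {:03d}: {:} candidates, {:.1f}s simulated budget used'.format(
      run_index, len(archs), times[-1]))
```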
diff --git a/exps/algos-v2/reinforce.py b/exps/algos-v2/reinforce.py
new file mode 100644
index 0000000..c81708c
--- /dev/null
+++ b/exps/algos-v2/reinforce.py
@@ -0,0 +1,222 @@
+##################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
+#####################################################################################################
+# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
+#####################################################################################################
+# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space tss --time_budget 12000 --learning_rate 0.001
+# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space tss --time_budget 12000 --learning_rate 0.001
+# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --learning_rate 0.001
+# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space sss --time_budget 12000 --learning_rate 0.001
+# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space sss --time_budget 12000 --learning_rate 0.001
+# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space sss --time_budget 12000 --learning_rate 0.001
+#####################################################################################################
+import os, sys, time, glob, random, argparse
+import numpy as np, collections
+from copy import deepcopy
+from pathlib import Path
+import torch
+import torch.nn as nn
+from torch.distributions import Categorical
+lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
+if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
+from config_utils import load_config, dict2config, configure2str
+from datasets import get_datasets, SearchDataset
+from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
+from utils import get_model_infos, obtain_accuracy
+from log_utils import AverageMeter, time_string, convert_secs2time
+from nas_201_api import NASBench201API, NASBench301API
+from models import CellStructure, get_search_spaces
+
+
+class PolicyTopology(nn.Module):
+
+  def __init__(self, search_space, max_nodes=4):
+    super(PolicyTopology, self).__init__()
+    self.max_nodes = max_nodes
+    self.search_space = deepcopy(search_space)
+    self.edge2index = {}
+    for i in range(1, max_nodes):
+      for j in range(i):
+        node_str = '{:}<-{:}'.format(i, j)
+        self.edge2index[ node_str ] = len(self.edge2index)
+    self.arch_parameters = nn.Parameter(1e-3*torch.randn(len(self.edge2index), len(search_space)))
+
+  def generate_arch(self, actions):
+    genotypes = []
+    for i in range(1, self.max_nodes):
+      xlist = []
+      for j in range(i):
+        node_str = '{:}<-{:}'.format(i, j)
+        op_name = self.search_space[ actions[ self.edge2index[ node_str ] ] ]
+        xlist.append((op_name, j))
+      genotypes.append( tuple(xlist) )
+    return CellStructure( genotypes )
+
+  def genotype(self):
+    genotypes = []
+    for i in range(1, self.max_nodes):
+      xlist = []
+      for j in range(i):
+        node_str = '{:}<-{:}'.format(i, j)
+        with torch.no_grad():
+          weights = self.arch_parameters[ self.edge2index[node_str] ]
+          op_name = self.search_space[ weights.argmax().item() ]
+        xlist.append((op_name, j))
+      genotypes.append( tuple(xlist) )
+    return CellStructure( genotypes )
+
+  def forward(self):
+    alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
+    return alphas
+
+
+class PolicySize(nn.Module):
+
+  def __init__(self, search_space):
+    super(PolicySize, self).__init__()
+    self.candidates = search_space['candidates']
+    self.numbers = search_space['numbers']
+    self.arch_parameters = nn.Parameter(1e-3*torch.randn(self.numbers, len(self.candidates)))
+
+  def generate_arch(self, actions):
+    channels = [str(self.candidates[i]) for i in actions]
+    return ':'.join(channels)
+
+  def genotype(self):
+    channels = []
+    for i in range(self.numbers):
+      index = self.arch_parameters[i].argmax().item()
+      channels.append(str(self.candidates[index]))
+    return ':'.join(channels)
+
+  def forward(self):
+    alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
+    return alphas
+
+
+class ExponentialMovingAverage(object):
+  """Class that maintains an exponential moving average."""
+
+  def __init__(self, momentum):
+    self._numerator = 0
+    self._denominator = 0
+    self._momentum = momentum
+
+  def update(self, value):
+    self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
+    self._denominator = self._momentum * self._denominator + (1 - self._momentum)
+
+  def value(self):
+    """Return the current value of the moving average"""
+    return self._numerator / self._denominator
+
+
+def select_action(policy):
+  probs = policy()
+  m = Categorical(probs)
+  action = m.sample()
+  # policy.saved_log_probs.append(m.log_prob(action))
+  return m.log_prob(action), action.cpu().tolist()
+
+
+def main(xargs, api):
+  assert torch.cuda.is_available(), 'CUDA is not available.'
+  torch.backends.cudnn.enabled = True
+  torch.backends.cudnn.benchmark = False
+  torch.backends.cudnn.deterministic = True
+  torch.set_num_threads(xargs.workers)
+  prepare_seed(xargs.rand_seed)
+  logger = prepare_logger(args)
+
+
+  search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
+  if xargs.search_space == 'tss':
+    policy = PolicyTopology(search_space)
+  else:
+    policy = PolicySize(search_space)
+  optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
+  #optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
+  eps = np.finfo(np.float32).eps.item()
+  baseline = ExponentialMovingAverage(xargs.EMA_momentum)
+  logger.log('policy : {:}'.format(policy))
+  logger.log('optimizer : {:}'.format(optimizer))
+  logger.log('eps : {:}'.format(eps))
+
+  # nas dataset load
+  logger.log('{:} use api : {:}'.format(time_string(), api))
+
+  # REINFORCE
+  x_start_time = time.time()
+  logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
+  total_steps, total_costs, trace = 0, [], []
+  while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget:
+    start_time = time.time()
+    log_prob, action = select_action( policy )
+    arch = policy.generate_arch( action )
+    reward, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, '12')
+    trace.append((reward, arch))
+    total_costs.append(current_total_cost)
+
+    baseline.update(reward)
+    # calculate loss
+    policy_loss = ( -log_prob * (reward - baseline.value()) ).sum()
+    optimizer.zero_grad()
+    policy_loss.backward()
+    optimizer.step()
+    # accumulate time
+    total_steps += 1
+    logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
+    #logger.log('----> {:}'.format(policy.arch_parameters))
+    #logger.log('')
+
+  # best_arch = policy.genotype() # first version
+  best_arch = max(trace, key=lambda x: x[0])[1]
+  logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs[-1], time.time()-x_start_time))
+  info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
+  logger.log('{:}'.format(info))
+  logger.log('-'*100)
+  logger.close()
+
+  return logger.log_dir, [api.query_index_by_arch(x[1]) for x in trace], total_costs
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser("The REINFORCE Algorithm")
+  parser.add_argument('--data_path', type=str, help='Path to dataset')
+  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
+  parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
+  parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.')
+  parser.add_argument('--EMA_momentum', type=float, default=0.9, help='The momentum value for EMA.')
+  parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).')
+  parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
+  # log
+  parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
+  parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
+  parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
+  parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
+  parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
+  args = parser.parse_args()
+
+  if args.search_space == 'tss':
+    api = NASBench201API(verbose=False)
+  elif args.search_space == 'sss':
+    api = NASBench301API(verbose=False)
+  else:
+    raise ValueError('Invalid search space : {:}'.format(args.search_space))
+
+  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'REINFORCE-{:}'.format(args.learning_rate))
+  print('save-dir : {:}'.format(args.save_dir))
+
+  if args.rand_seed < 0:
+    save_dir, all_info = None, {}
+    for i in range(args.loops_if_rand):
+      print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
+      args.rand_seed = random.randint(1, 100000)
+      save_dir, all_archs, all_total_times = main(args, api)
+      all_info[i] = {'all_archs': all_archs,
+                     'all_total_times': all_total_times}
+    save_path = save_dir / 'results.pth'
+    print('save into {:}'.format(save_path))
+    torch.save(all_info, save_path)
+  else:
+    main(args, api)
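To summarize the update performed in `main` above: each step samples one operation per edge (or one channel candidate per layer in the size space) from a softmax over `arch_parameters`, uses the simulated accuracy as the reward, and subtracts an exponential-moving-average baseline before ascending the log-probability. The snippet below is a stripped-down, self-contained sketch of that policy-gradient step with a fabricated `toy_reward`; it is not the script itself, does not touch the benchmark API, and uses a plain EMA rather than the debiased `ExponentialMovingAverage` class above.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)
num_edges, num_ops = 6, 5                        # sizes picked only for illustration
arch_parameters = nn.Parameter(1e-3 * torch.randn(num_edges, num_ops))
optimizer = torch.optim.Adam([arch_parameters], lr=0.001)
baseline, momentum = 0.0, 0.9                    # plain EMA baseline (no debiasing)

def toy_reward(actions):
  # Stand-in for api.simulate_train_eval: pretend operation 0 is best on every edge.
  return float((actions == 0).float().mean())

for step in range(200):
  probs = nn.functional.softmax(arch_parameters, dim=-1)
  dist = Categorical(probs)
  actions = dist.sample()                        # one operation index per edge
  log_prob = dist.log_prob(actions).sum()
  reward = toy_reward(actions)
  baseline = momentum * baseline + (1 - momentum) * reward
  loss = -log_prob * (reward - baseline)         # REINFORCE with a moving-average baseline
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

final_probs = nn.functional.softmax(arch_parameters, dim=-1).detach()
print('probability of operation 0 on each edge:', final_probs[:, 0])
```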
diff --git a/exps/algos/reinforce.py b/exps/algos/reinforce.py
index ddb1d57..8dbed40 100644
--- a/exps/algos/reinforce.py
+++ b/exps/algos/reinforce.py
@@ -184,7 +184,7 @@ def main(xargs, nas_bench):
 
 
 if __name__ == '__main__':
-  parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
+  parser = argparse.ArgumentParser("The REINFORCE Algorithm")
   parser.add_argument('--data_path', type=str, help='Path to dataset')
   parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
   # channels and number-of-cells
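Finally, the piece both new scripts have in common is the budget accounting: `simulate_train_eval` returns the simulated accuracy after 12 epochs together with the accumulated simulated cost, and the search stops once that cost exceeds `--time_budget`. The skeleton below sketches that shared loop; `propose_arch` is a hypothetical stand-in for either the evolution step or the policy sample, and the call simply mirrors the usage shown in `reinforce.py` above.

```python
def run_search(api, dataset, time_budget, propose_arch):
  """Budget-accounting skeleton shared by REA.py and reinforce.py (a sketch, not either script)."""
  trace, total_costs = [], []
  while len(total_costs) == 0 or total_costs[-1] < time_budget:
    arch = propose_arch(trace)                   # evolution step or policy sample
    # The API returns the simulated accuracy after 12 epochs ('12') and the
    # accumulated simulated cost; the loop stops once that cost exceeds the budget.
    reward, _, current_total_cost = api.simulate_train_eval(arch, dataset, '12')
    trace.append((reward, arch))
    total_costs.append(current_total_cost)
  best_arch = max(trace, key=lambda pair: pair[0])[1]
  return best_arch, trace, total_costs
```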