Update REA and REINFORCE

Author: D-X-Y
Date:   2020-07-13 03:43:10 +00:00
parent 88a5be1368
commit 041a9aa4a3
4 changed files with 239 additions and 14 deletions

exps/algos-v2/REA.py

@@ -3,13 +3,13 @@
 ##################################################################
 # Regularized Evolution for Image Classifier Architecture Search #
 ##################################################################
-# python ./exps/algos-v2/R_EA.py --dataset cifar10 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
-# python ./exps/algos-v2/R_EA.py --dataset cifar100 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
-# python ./exps/algos-v2/R_EA.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
-# python ./exps/algos-v2/R_EA.py --dataset cifar10 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
-#
-#
-#
+# python ./exps/algos-v2/REA.py --dataset cifar10 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset cifar100 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset cifar10 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset cifar100 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+# python ./exps/algos-v2/REA.py --dataset ImageNet16-120 --search_space sss --time_budget 12000 --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
+##################################################################
 import os, sys, time, glob, random, argparse
 import numpy as np, collections
 from copy import deepcopy
@@ -236,12 +236,12 @@ if __name__ == '__main__':
   parser.add_argument('--ea_population', type=int, help='The population size in EA.')
   parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
   parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
+  parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
   # log
   parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
   parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
   parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
   args = parser.parse_args()
-  #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
   if args.search_space == 'tss':
     api = NASBench201API(verbose=False)
@@ -250,17 +250,19 @@ if __name__ == '__main__':
   else:
     raise ValueError('Invalid search space : {:}'.format(args.search_space))
-  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), 'R-EA-SS{:}'.format(args.ea_sample_size))
+  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'R-EA-SS{:}'.format(args.ea_sample_size))
   print('save-dir : {:}'.format(args.save_dir))
   if args.rand_seed < 0:
-    save_dir, all_info, num = None, {}, 500
-    for i in range(num):
-      print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
+    save_dir, all_info = None, {}
+    for i in range(args.loops_if_rand):
+      print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
       args.rand_seed = random.randint(1, 100000)
       save_dir, all_archs, all_total_times = main(args, api)
       all_info[i] = {'all_archs': all_archs,
                      'all_total_times': all_total_times}
-    torch.save(all_info, save_dir / 'results.pth')
+    save_path = save_dir / 'results.pth'
+    print('save into {:}'.format(save_path))
+    torch.save(all_info, save_path)
   else:
     main(args, api)

exps/algos-v2/README.md (new file, 1 line)

@@ -0,0 +1 @@
# Benchmarking NAS Algorithms

exps/algos-v2/reinforce.py (new file, 222 lines)

@@ -0,0 +1,222 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
#####################################################################################################
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
#####################################################################################################
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space tss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space tss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space tss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar10 --search_space sss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset cifar100 --search_space sss --time_budget 12000 --learning_rate 0.001
# python ./exps/algos-v2/reinforce.py --dataset ImageNet16-120 --search_space sss --time_budget 12000 --learning_rate 0.001
#####################################################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
from pathlib import Path
import torch
import torch.nn as nn
from torch.distributions import Categorical
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_201_api import NASBench201API, NASBench301API
from models import CellStructure, get_search_spaces

class PolicyTopology(nn.Module):

  def __init__(self, search_space, max_nodes=4):
    super(PolicyTopology, self).__init__()
    self.max_nodes = max_nodes
    self.search_space = deepcopy(search_space)
    self.edge2index = {}
    for i in range(1, max_nodes):
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        self.edge2index[ node_str ] = len(self.edge2index)
    self.arch_parameters = nn.Parameter(1e-3*torch.randn(len(self.edge2index), len(search_space)))

  def generate_arch(self, actions):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        op_name = self.search_space[ actions[ self.edge2index[ node_str ] ] ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )

  def genotype(self):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        with torch.no_grad():
          weights = self.arch_parameters[ self.edge2index[node_str] ]
          op_name = self.search_space[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )

  def forward(self):
    alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
    return alphas
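
# Illustrative note (added for exposition, not part of the original file): assuming the
# 'tss' search space is the usual NAS-Bench-201 operation list
#   ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'],
# the policy samples one action per edge (6 edges when max_nodes=4), so an action vector
# such as [1, 0, 2, 0, 0, 3] would be turned by generate_arch into the cell
#   |skip_connect~0|+|none~0|nor_conv_1x1~1|+|none~0|none~1|nor_conv_3x3~2|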

class PolicySize(nn.Module):

  def __init__(self, search_space):
    super(PolicySize, self).__init__()
    self.candidates = search_space['candidates']
    self.numbers = search_space['numbers']
    self.arch_parameters = nn.Parameter(1e-3*torch.randn(self.numbers, len(self.candidates)))

  def generate_arch(self, actions):
    channels = [str(self.candidates[i]) for i in actions]
    return ':'.join(channels)

  def genotype(self):
    channels = []
    for i in range(self.numbers):
      index = self.arch_parameters[i].argmax().item()
      channels.append(str(self.candidates[index]))
    return ':'.join(channels)

  def forward(self):
    alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
    return alphas

class ExponentialMovingAverage(object):
  """Class that maintains an exponential moving average."""

  def __init__(self, momentum):
    self._numerator = 0
    self._denominator = 0
    self._momentum = momentum

  def update(self, value):
    self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
    self._denominator = self._momentum * self._denominator + (1 - self._momentum)

  def value(self):
    """Return the current value of the moving average"""
    return self._numerator / self._denominator
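
# Illustrative note (added for exposition, not part of the original file): the denominator
# is the standard bias correction; it equals 1 - momentum**n after n updates, so early
# rewards are not under-weighted. A quick worked example with momentum=0.9 and
# hypothetical rewards:
#   ema = ExponentialMovingAverage(0.9)
#   ema.update(1.0)   # numerator=0.10, denominator=0.10 -> value() == 1.0
#   ema.update(0.0)   # numerator=0.09, denominator=0.19 -> value() ~= 0.474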

def select_action(policy):
  probs = policy()
  m = Categorical(probs)
  action = m.sample()
  # policy.saved_log_probs.append(m.log_prob(action))
  return m.log_prob(action), action.cpu().tolist()

def main(xargs, api):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads(xargs.workers)
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(args)

  search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
  if xargs.search_space == 'tss':
    policy = PolicyTopology(search_space)
  else:
    policy = PolicySize(search_space)
  optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
  #optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
  eps = np.finfo(np.float32).eps.item()
  baseline = ExponentialMovingAverage(xargs.EMA_momentum)
  logger.log('policy : {:}'.format(policy))
  logger.log('optimizer : {:}'.format(optimizer))
  logger.log('eps : {:}'.format(eps))

  # nas dataset load
  logger.log('{:} use api : {:}'.format(time_string(), api))

  # REINFORCE
  x_start_time = time.time()
  logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
  total_steps, total_costs, trace = 0, [], []
  while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget:
    start_time = time.time()
    log_prob, action = select_action( policy )
    arch = policy.generate_arch( action )
    reward, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, '12')
    trace.append((reward, arch))
    total_costs.append(current_total_cost)

    baseline.update(reward)
    # calculate loss: REINFORCE policy gradient, log-probability weighted by (reward - EMA baseline)
    policy_loss = ( -log_prob * (reward - baseline.value()) ).sum()
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()
    # accumulate time
    total_steps += 1
    logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
    #logger.log('----> {:}'.format(policy.arch_parameters))
    #logger.log('')

  # best_arch = policy.genotype() # first version
  best_arch = max(trace, key=lambda x: x[0])[1]
  logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs[-1], time.time()-x_start_time))
  info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
  logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()

  return logger.log_dir, [api.query_index_by_arch(x[1]) for x in trace], total_costs

if __name__ == '__main__':
  parser = argparse.ArgumentParser("The REINFORCE Algorithm")
  parser.add_argument('--data_path', type=str, help='Path to dataset')
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
  parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
  parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.')
  parser.add_argument('--EMA_momentum', type=float, default=0.9, help='The momentum value for EMA.')
  parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
  parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
  # log
  parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
  parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
  parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
  args = parser.parse_args()

  if args.search_space == 'tss':
    api = NASBench201API(verbose=False)
  elif args.search_space == 'sss':
    api = NASBench301API(verbose=False)
  else:
    raise ValueError('Invalid search space : {:}'.format(args.search_space))

  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'REINFORCE-{:}'.format(args.learning_rate))
  print('save-dir : {:}'.format(args.save_dir))

  if args.rand_seed < 0:
    save_dir, all_info = None, {}
    for i in range(args.loops_if_rand):
      print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
      args.rand_seed = random.randint(1, 100000)
      save_dir, all_archs, all_total_times = main(args, api)
      all_info[i] = {'all_archs': all_archs,
                     'all_total_times': all_total_times}
    save_path = save_dir / 'results.pth'
    print('save into {:}'.format(save_path))
    torch.save(all_info, save_path)
  else:
    main(args, api)
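
When --rand_seed is negative, REA.py and reinforce.py both loop over --loops_if_rand random seeds and dump every run into a single results.pth. A minimal sketch for reading that file back, assuming only the all_info layout shown in the __main__ block above (the path is illustrative; the real location is printed by the script as "save into ..."):

import torch

results = torch.load('./output/search-tss/cifar10/REINFORCE-0.001/results.pth')
for run_index, record in results.items():
  archs = record['all_archs']        # for reinforce.py, benchmark indices of the sampled architectures
  costs = record['all_total_times']  # cumulative simulated search cost in seconds
  print('run {:03d}: {:} architectures, simulated cost {:.1f}s'.format(run_index, len(archs), costs[-1]))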

@@ -184,7 +184,7 @@ def main(xargs, nas_bench):
 if __name__ == '__main__':
-  parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
+  parser = argparse.ArgumentParser("The REINFORCE Algorithm")
   parser.add_argument('--data_path', type=str, help='Path to dataset')
   parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
   # channels and number-of-cells