Update the query_by_arch function in API to be compatiable with the submission version of NAS-Bench-201
This commit is contained in:
parent
4892692622
commit
233a829bd7
@ -3,6 +3,8 @@
|
|||||||
###################################################################
|
###################################################################
|
||||||
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
|
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
|
||||||
# required to install hpbandster ##################################
|
# required to install hpbandster ##################################
|
||||||
|
# pip install hpbandster ##################################
|
||||||
|
###################################################################
|
||||||
# bash ./scripts-search/algos/BOHB.sh -1 ##################
|
# bash ./scripts-search/algos/BOHB.sh -1 ##################
|
||||||
###################################################################
|
###################################################################
|
||||||
import os, sys, time, random, argparse
|
import os, sys, time, random, argparse
|
||||||
@ -178,7 +180,7 @@ def main(xargs, nas_bench):
|
|||||||
logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time))
|
logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time))
|
||||||
best_arch = config2structure( id2config[incumbent]['config'] )
|
best_arch = config2structure( id2config[incumbent]['config'] )
|
||||||
|
|
||||||
info = nas_bench.query_by_arch( best_arch )
|
info = nas_bench.query_by_arch(best_arch, '200')
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||||
else : logger.log('{:}'.format(info))
|
else : logger.log('{:}'.format(info))
|
||||||
logger.log('-'*100)
|
logger.log('-'*100)
|
||||||
|
@ -199,14 +199,14 @@ def main(xargs):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
#logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
#logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
logger.log('{:}'.format(search_model.show_alphas()))
|
logger.log('{:}'.format(search_model.show_alphas()))
|
||||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
logger.log('\n' + '-'*100)
|
logger.log('\n' + '-'*100)
|
||||||
logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
|
logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
|
||||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
|
||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -260,7 +260,7 @@ def main(xargs):
|
|||||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@ -268,7 +268,7 @@ def main(xargs):
|
|||||||
logger.log('\n' + '-'*100)
|
logger.log('\n' + '-'*100)
|
||||||
# check the performance from the architecture dataset
|
# check the performance from the architecture dataset
|
||||||
logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
|
logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
|
||||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1]), '200'))
|
||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -295,7 +295,7 @@ def main(xargs):
|
|||||||
if find_best:
|
if find_best:
|
||||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc))
|
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc))
|
||||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -176,7 +176,7 @@ def main(xargs):
|
|||||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logger.log('{:}'.format(search_model.show_alphas()))
|
logger.log('{:}'.format(search_model.show_alphas()))
|
||||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@ -184,7 +184,7 @@ def main(xargs):
|
|||||||
logger.log('\n' + '-'*100)
|
logger.log('\n' + '-'*100)
|
||||||
# check the performance from the architecture dataset
|
# check the performance from the architecture dataset
|
||||||
logger.log('GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
|
logger.log('GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
|
||||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
|
||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -199,7 +199,7 @@ def main(xargs):
|
|||||||
if find_best:
|
if find_best:
|
||||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@ -210,7 +210,7 @@ def main(xargs):
|
|||||||
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
|
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
|
||||||
search_time.update(time.time() - start_time)
|
search_time.update(time.time() - start_time)
|
||||||
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
|
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
|
||||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(best_arch, '200')))
|
||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ def main(xargs, nas_bench):
|
|||||||
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
|
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
|
||||||
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))
|
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))
|
||||||
|
|
||||||
info = nas_bench.query_by_arch( best_arch )
|
info = nas_bench.query_by_arch(best_arch, '200')
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||||
else : logger.log('{:}'.format(info))
|
else : logger.log('{:}'.format(info))
|
||||||
logger.log('-'*100)
|
logger.log('-'*100)
|
||||||
|
5
exps/algos/README.md
Normal file
5
exps/algos/README.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# NAS Algorithms evaluated in NAS-Bench-201
|
||||||
|
|
||||||
|
The Python files in this folder are used to re-produce the results in our NAS-Bench-201 paper.
|
||||||
|
|
||||||
|
We will upgrade the codes to be more general and extendable. The new codes are at [coming soon].
|
@ -53,7 +53,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
|
|||||||
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
|
||||||
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12')
|
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12')
|
||||||
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')
|
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')
|
||||||
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
|
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', is_random=True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
|
||||||
cost = nas_bench.get_cost_info(arch_index, dataname, hp='200')
|
cost = nas_bench.get_cost_info(arch_index, dataname, hp='200')
|
||||||
# The following codes are used to estimate the time cost.
|
# The following codes are used to estimate the time cost.
|
||||||
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
|
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
|
||||||
@ -218,7 +218,7 @@ def main(xargs, nas_bench):
|
|||||||
best_arch = best_arch.arch
|
best_arch = best_arch.arch
|
||||||
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
|
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
|
||||||
|
|
||||||
info = nas_bench.query_by_arch( best_arch )
|
info = nas_bench.query_by_arch(best_arch, '200')
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||||
else : logger.log('{:}'.format(info))
|
else : logger.log('{:}'.format(info))
|
||||||
logger.log('-'*100)
|
logger.log('-'*100)
|
||||||
|
@ -235,7 +235,7 @@ def main(xargs):
|
|||||||
}, logger.path('info'), logger)
|
}, logger.path('info'), logger)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logger.log('{:}'.format(search_model.show_alphas()))
|
logger.log('{:}'.format(search_model.show_alphas()))
|
||||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@ -251,7 +251,7 @@ def main(xargs):
|
|||||||
logger.log('\n' + '-'*100)
|
logger.log('\n' + '-'*100)
|
||||||
# check the performance from the architecture dataset
|
# check the performance from the architecture dataset
|
||||||
logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype))
|
logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype))
|
||||||
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotype) ))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200') ))
|
||||||
logger.close()
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,7 +174,7 @@ def main(xargs, nas_bench):
|
|||||||
# best_arch = policy.genotype() # first version
|
# best_arch = policy.genotype() # first version
|
||||||
best_arch = max(trace, key=lambda x: x[0])[1]
|
best_arch = max(trace, key=lambda x: x[0])[1]
|
||||||
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
|
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
|
||||||
info = nas_bench.query_by_arch( best_arch )
|
info = nas_bench.query_by_arch(best_arch, '200')
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||||
else : logger.log('{:}'.format(info))
|
else : logger.log('{:}'.format(info))
|
||||||
logger.log('-'*100)
|
logger.log('-'*100)
|
||||||
|
Loading…
Reference in New Issue
Block a user