Update the query_by_arch function in the API to be compatible with the submission version of NAS-Bench-201

D-X-Y 2020-07-08 04:46:25 +00:00
parent 4892692622
commit 233a829bd7
11 changed files with 23 additions and 16 deletions
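For readers unfamiliar with the change: query_by_arch now takes the hyper-parameter setting ('200', i.e. the 200-epoch training results) as a second argument. A minimal usage sketch follows; the benchmark file name and the architecture string are illustrative assumptions, and only the two-argument call pattern comes from this commit.

# Minimal sketch of the updated query_by_arch call.
# The benchmark file name and the architecture string are assumptions for illustration.
from nas_201_api import NASBench201API as API

api = API('NAS-Bench-201-v1_0-e61699.pth')  # path to the downloaded benchmark file (assumed)
arch = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|avg_pool_3x3~2|'

# old call: info = api.query_by_arch(arch)
info = api.query_by_arch(arch, '200')  # '200' selects the 200-epoch training results
if info is None:
    print('Did not find this architecture : {:}.'.format(arch))
else:
    print(info)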


@@ -3,6 +3,8 @@
###################################################################
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
# requires installing hpbandster ##################################
# pip install hpbandster ##################################
###################################################################
# bash ./scripts-search/algos/BOHB.sh -1 ##################
###################################################################
import os, sys, time, random, argparse
@@ -178,7 +180,7 @@ def main(xargs, nas_bench):
logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time))
best_arch = config2structure( id2config[incumbent]['config'] )
info = nas_bench.query_by_arch( best_arch )
info = nas_bench.query_by_arch(best_arch, '200')
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
logger.log('-'*100)


@@ -199,14 +199,14 @@ def main(xargs):
with torch.no_grad():
#logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
logger.log('{:}'.format(search_model.show_alphas()))
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log('\n' + '-'*100)
logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
logger.close()


@@ -260,7 +260,7 @@ def main(xargs):
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
@@ -268,7 +268,7 @@ def main(xargs):
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
logger.close()


@@ -295,7 +295,7 @@ def main(xargs):
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc))
copy_checkpoint(model_base_path, model_best_path, logger)
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()


@@ -176,7 +176,7 @@ def main(xargs):
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log('{:}'.format(search_model.show_alphas()))
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
@@ -184,7 +184,7 @@ def main(xargs):
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
logger.log('GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
logger.close()


@@ -199,7 +199,7 @@ def main(xargs):
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
copy_checkpoint(model_base_path, model_best_path, logger)
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
@@ -210,7 +210,7 @@ def main(xargs):
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
search_time.update(time.time() - start_time)
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
if api is not None: logger.log('{:}'.format(api.query_by_arch(best_arch, '200')))
logger.close()


@@ -74,7 +74,7 @@ def main(xargs, nas_bench):
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))
info = nas_bench.query_by_arch( best_arch )
info = nas_bench.query_by_arch(best_arch, '200')
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
logger.log('-'*100)

exps/algos/README.md (new file, +5 lines)

@@ -0,0 +1,5 @@
# NAS Algorithms evaluated in NAS-Bench-201
The Python files in this folder are used to reproduce the results in our NAS-Bench-201 paper.
We will upgrade the code to be more general and extensible. The new code is at [coming soon].


@@ -53,7 +53,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12')
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', is_random=True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
cost = nas_bench.get_cost_info(arch_index, dataname, hp='200')
# The following codes are used to estimate the time cost.
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
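The hunk above switches get_more_info to the keyword form is_random=True, which returns results from one randomly chosen trial rather than the mean over seeds, so repeated queries behave like noisy, truncated training runs. A hedged sketch of how these two calls are typically combined is below; the dictionary key 'valid-accuracy' is an assumption about the returned structure, and only the call signatures appear in the diff.

# Hedged sketch of querying a simulated partial-training result, as the hunk above does.
# Assumes nas_bench is a loaded NASBench201API instance and arch is an architecture string
# (see the sketch near the top of this page); 'valid-accuracy' is an assumed key name.
arch_index = nas_bench.query_index_by_arch(arch)
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)

# is_random=True samples a single trial instead of averaging across seeds,
# which mimics the noise of actually training this architecture.
info = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=25, hp='200', is_random=True)
cost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')

valid_acc = info['valid-accuracy']  # assumed key name for the validation accuracy
print('validation accuracy at roughly the 25-epoch point : {:.2f}%'.format(valid_acc))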
@@ -218,7 +218,7 @@ def main(xargs, nas_bench):
best_arch = best_arch.arch
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
info = nas_bench.query_by_arch( best_arch )
info = nas_bench.query_by_arch(best_arch, '200')
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
logger.log('-'*100)


@@ -235,7 +235,7 @@ def main(xargs):
}, logger.path('info'), logger)
with torch.no_grad():
logger.log('{:}'.format(search_model.show_alphas()))
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
@@ -251,7 +251,7 @@ def main(xargs):
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype))
if api is not None: logger.log('{:}'.format( api.query_by_arch(genotype) ))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200') ))
logger.close()


@@ -174,7 +174,7 @@ def main(xargs, nas_bench):
# best_arch = policy.genotype() # first version
best_arch = max(trace, key=lambda x: x[0])[1]
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
info = nas_bench.query_by_arch( best_arch )
info = nas_bench.query_by_arch(best_arch, '200')
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
logger.log('-'*100)