diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py
index 18e9c5d..f4c6e50 100644
--- a/exps/algos/BOHB.py
+++ b/exps/algos/BOHB.py
@@ -3,6 +3,8 @@
 ###################################################################
 # BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
 # required to install hpbandster ##################################
+# pip install hpbandster ##################################
+###################################################################
 # bash ./scripts-search/algos/BOHB.sh -1 ##################
 ###################################################################
 import os, sys, time, random, argparse
@@ -178,7 +180,7 @@ def main(xargs, nas_bench):
   logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time))
   best_arch = config2structure( id2config[incumbent]['config'] )
-  info = nas_bench.query_by_arch( best_arch )
+  info = nas_bench.query_by_arch(best_arch, '200')
   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
   else : logger.log('{:}'.format(info))
   logger.log('-'*100)
diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py
index 32f4b6b..705bc0d 100644
--- a/exps/algos/DARTS-V1.py
+++ b/exps/algos/DARTS-V1.py
@@ -199,14 +199,14 @@ def main(xargs):
     with torch.no_grad():
       #logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
       logger.log('{:}'.format(search_model.show_alphas()))
-    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
+    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()

   logger.log('\n' + '-'*100)
   logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
-  if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
+  if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
   logger.close()
diff --git a/exps/algos/DARTS-V2.py b/exps/algos/DARTS-V2.py
index beec424..798ad85 100644
--- a/exps/algos/DARTS-V2.py
+++ b/exps/algos/DARTS-V2.py
@@ -260,7 +260,7 @@ def main(xargs):
       copy_checkpoint(model_base_path, model_best_path, logger)
     with torch.no_grad():
       logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
-    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
+    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()
@@ -268,7 +268,7 @@ def main(xargs):
   logger.log('\n' + '-'*100)
   # check the performance from the architecture dataset
   logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
-  if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
+  if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
   logger.close()
diff --git a/exps/algos/ENAS.py b/exps/algos/ENAS.py
index 8487dfa..0af87ce 100644
--- a/exps/algos/ENAS.py
+++ b/exps/algos/ENAS.py
@@ -295,7 +295,7 @@ def main(xargs):
     if find_best:
       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc))
       copy_checkpoint(model_base_path, model_best_path, logger)
-    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
+    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()
diff --git a/exps/algos/GDAS.py b/exps/algos/GDAS.py
index bd82c34..329073c 100644
--- a/exps/algos/GDAS.py
+++ b/exps/algos/GDAS.py
@@ -176,7 +176,7 @@ def main(xargs):
       copy_checkpoint(model_base_path, model_best_path, logger)
     with torch.no_grad():
       logger.log('{:}'.format(search_model.show_alphas()))
-    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
+    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()
@@ -184,7 +184,7 @@ def main(xargs):
   logger.log('\n' + '-'*100)
   # check the performance from the architecture dataset
   logger.log('GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
-  if api is not None: logger.log('{:}'.format( api.query_by_arch(genotypes[total_epoch-1]) ))
+  if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
   logger.close()
diff --git a/exps/algos/RANDOM-NAS.py b/exps/algos/RANDOM-NAS.py
index cd865a6..78eddda 100644
--- a/exps/algos/RANDOM-NAS.py
+++ b/exps/algos/RANDOM-NAS.py
@@ -199,7 +199,7 @@ def main(xargs):
     if find_best:
       logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
       copy_checkpoint(model_base_path, model_best_path, logger)
-    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
+    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()
@@ -210,7 +210,7 @@ def main(xargs):
   best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
   search_time.update(time.time() - start_time)
   logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
-  if api is not None: logger.log('{:}'.format( api.query_by_arch(best_arch) ))
+  if api is not None: logger.log('{:}'.format(api.query_by_arch(best_arch, '200')))
   logger.close()
diff --git a/exps/algos/RANDOM.py b/exps/algos/RANDOM.py
index e486911..58af886 100644
--- a/exps/algos/RANDOM.py
+++ b/exps/algos/RANDOM.py
@@ -74,7 +74,7 @@ def main(xargs, nas_bench):
     logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))

   logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))
-  info = nas_bench.query_by_arch( best_arch )
+  info = nas_bench.query_by_arch(best_arch, '200')
   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
   else : logger.log('{:}'.format(info))
   logger.log('-'*100)
diff --git a/exps/algos/README.md b/exps/algos/README.md
new file mode 100644
index 0000000..e929b83
--- /dev/null
+++ b/exps/algos/README.md
@@ -0,0 +1,5 @@
+# NAS Algorithms evaluated in NAS-Bench-201
+
+The Python files in this folder are used to reproduce the results in our NAS-Bench-201 paper.
+
+We will upgrade the code to be more general and extensible. The new code will be available at [coming soon].
diff --git a/exps/algos/R_EA.py b/exps/algos/R_EA.py
index b507bba..ddfcde8 100644
--- a/exps/algos/R_EA.py
+++ b/exps/algos/R_EA.py
@@ -53,7 +53,7 @@ def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_01
     assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
     xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12')
     xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')
-    info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
+    info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', is_random=True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
     cost = nas_bench.get_cost_info(arch_index, dataname, hp='200')
     # The following codes are used to estimate the time cost.
     # When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
@@ -218,7 +218,7 @@ def main(xargs, nas_bench):
   best_arch = best_arch.arch
   logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
-  info = nas_bench.query_by_arch( best_arch )
+  info = nas_bench.query_by_arch(best_arch, '200')
   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
   else : logger.log('{:}'.format(info))
   logger.log('-'*100)
diff --git a/exps/algos/SETN.py b/exps/algos/SETN.py
index 0766cb7..038a65d 100644
--- a/exps/algos/SETN.py
+++ b/exps/algos/SETN.py
@@ -235,7 +235,7 @@ def main(xargs):
       }, logger.path('info'), logger)
     with torch.no_grad():
       logger.log('{:}'.format(search_model.show_alphas()))
-    if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
+    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
     # measure elapsed time
     epoch_time.update(time.time() - start_time)
     start_time = time.time()
@@ -251,7 +251,7 @@ def main(xargs):
   logger.log('\n' + '-'*100)
   # check the performance from the architecture dataset
   logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype))
-  if api is not None: logger.log('{:}'.format( api.query_by_arch(genotype) ))
+  if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
   logger.close()
diff --git a/exps/algos/reinforce.py b/exps/algos/reinforce.py
index e836aea..ddb1d57 100644
--- a/exps/algos/reinforce.py
+++ b/exps/algos/reinforce.py
@@ -174,7 +174,7 @@ def main(xargs, nas_bench):
   # best_arch = policy.genotype() # first version
   best_arch = max(trace, key=lambda x: x[0])[1]
   logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
-  info = nas_bench.query_by_arch( best_arch )
+  info = nas_bench.query_by_arch(best_arch, '200')
   if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
   else : logger.log('{:}'.format(info))
   logger.log('-'*100)
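Note on the recurring change: every query_by_arch call above now passes '200' explicitly, so the benchmark returns the statistics from the full 200-epoch training schedule rather than relying on the API default. A minimal usage sketch of that call pattern follows; the benchmark file name and the architecture string are illustrative assumptions, not part of this change.

    # Sketch of the query pattern these scripts now use (assumed benchmark file name).
    from nas_201_api import NASBench201API as API

    api = API('NAS-Bench-201-v1_1-096897.pth')   # load the benchmark file (path is an assumption)
    # Example cell encoding in the NAS-Bench-201 string format (chosen arbitrarily for illustration).
    arch = '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|skip_connect~1|nor_conv_3x3~2|'
    info = api.query_by_arch(arch, '200')        # '200' selects the 200-epoch training results
    print(info)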