Update README
This commit is contained in:
		| @@ -70,17 +70,18 @@ api.show(2) | |||||||
| # show the mean loss and accuracy of an architecture | # show the mean loss and accuracy of an architecture | ||||||
| info = api.query_meta_info_by_index(1)  # This is an instance of `ArchResults` | info = api.query_meta_info_by_index(1)  # This is an instance of `ArchResults` | ||||||
| res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys | res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys | ||||||
| cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency | cost_metrics = info.get_compute_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency | ||||||
|  |  | ||||||
| # get the detailed information | # get the detailed information | ||||||
| results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed | results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed | ||||||
| print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) | print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) | ||||||
| print ('Latency : {:}'.format(results[0].get_latency())) | for seed, result in results.items(): | ||||||
| print ('Train Info : {:}'.format(results[0].get_train())) |   print ('Latency : {:}'.format(result.get_latency())) | ||||||
| print ('Valid Info : {:}'.format(results[0].get_eval('x-valid'))) |   print ('Train Info : {:}'.format(result.get_train())) | ||||||
| print ('Test  Info : {:}'.format(results[0].get_eval('x-test'))) |   print ('Valid Info : {:}'.format(result.get_eval('x-valid'))) | ||||||
| # for the metric after a specific epoch |   print ('Test  Info : {:}'.format(result.get_eval('x-test'))) | ||||||
| print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10))) |   # for the metric after a specific epoch | ||||||
|  |   print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10))) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| 4. Query the index of an architecture by string | 4. Query the index of an architecture by string | ||||||
|   | |||||||
| @@ -68,17 +68,18 @@ api.show(2) | |||||||
| # show the mean loss and accuracy of an architecture | # show the mean loss and accuracy of an architecture | ||||||
| info = api.query_meta_info_by_index(1)  # This is an instance of `ArchResults` | info = api.query_meta_info_by_index(1)  # This is an instance of `ArchResults` | ||||||
| res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys | res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys | ||||||
| cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency | cost_metrics = info.get_compute_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency | ||||||
|  |  | ||||||
| # get the detailed information | # get the detailed information | ||||||
| results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed | results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed | ||||||
| print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) | print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1])) | ||||||
| print ('Latency : {:}'.format(results[0].get_latency())) | for seed, result in results.items(): | ||||||
| print ('Train Info : {:}'.format(results[0].get_train())) |   print ('Latency : {:}'.format(result.get_latency())) | ||||||
| print ('Valid Info : {:}'.format(results[0].get_eval('x-valid'))) |   print ('Train Info : {:}'.format(result.get_train())) | ||||||
| print ('Test  Info : {:}'.format(results[0].get_eval('x-test'))) |   print ('Valid Info : {:}'.format(result.get_eval('x-valid'))) | ||||||
| # for the metric after a specific epoch |   print ('Test  Info : {:}'.format(result.get_eval('x-test'))) | ||||||
| print ('Train Info [10-th epoch] : {:}'.format(results[0].get_train(10))) |   # for the metric after a specific epoch | ||||||
|  |   print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10))) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| 4. Query the index of an architecture by string | 4. Query the index of an architecture by string | ||||||
|   | |||||||
| @@ -11,7 +11,7 @@ This facilitates a much larger community of researchers to focus on developing b | |||||||
| The structure of this Markdown file: | The structure of this Markdown file: | ||||||
| - [How to use NATS-Bench?](#How-to-Use-NATS-Bench) | - [How to use NATS-Bench?](#How-to-Use-NATS-Bench) | ||||||
| - [How to re-create NATS-Bench from scratch?](#how-to-re-create-nats-bench-from-scratch) | - [How to re-create NATS-Bench from scratch?](#how-to-re-create-nats-bench-from-scratch) | ||||||
| - [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nas-bench-201) | - [How to reproduce benchmarked results?](#to-reproduce-13-baseline-nas-algorithms-in-nats-bench) | ||||||
|  |  | ||||||
|  |  | ||||||
| ## How to Use [NATS-Bench](https://arxiv.org/pdf/2009.00437.pdf) | ## How to Use [NATS-Bench](https://arxiv.org/pdf/2009.00437.pdf) | ||||||
| @@ -77,8 +77,12 @@ params = api.get_net_param(12, 'cifar10', None) | |||||||
| network.load_state_dict(next(iter(params.values()))) | network.load_state_dict(next(iter(params.values()))) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ## How to Re-create NATS-Bench from Scratch | ## How to Re-create NATS-Bench from Scratch | ||||||
|  |  | ||||||
|  | You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to re-create NATS-Bench from scratch. | ||||||
|  |  | ||||||
| ### The Size Search Space | ### The Size Search Space | ||||||
|  |  | ||||||
| The following command will train all architecture candidate in the size search space with 90 epochs and use the random seed of `777`. | The following command will train all architecture candidate in the size search space with 90 epochs and use the random seed of `777`. | ||||||
| @@ -108,7 +112,9 @@ python exps/NATS-Bench/tss-collect.py | |||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  |  | ||||||
| ## To Reproduce 13 Baseline NAS Algorithms in NAS-Bench-201 | ## To Reproduce 13 Baseline NAS Algorithms in NATS-Bench | ||||||
|  |  | ||||||
|  | You need to use the [AutoDL-Projects](https://github.com/D-X-Y/AutoDL-Projects) repo to run 13 baseline NAS methods. | ||||||
|  |  | ||||||
| ### Reproduce NAS methods on the topology search space | ### Reproduce NAS methods on the topology search space | ||||||
|  |  | ||||||
| @@ -169,14 +175,14 @@ python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HO | |||||||
| python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777 | python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777 | ||||||
|  |  | ||||||
|  |  | ||||||
| Run the search strategy in FBNet-V2 | Run the channel search strategy in FBNet-V2 | ||||||
|  |  | ||||||
| python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 | python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 | ||||||
| python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 | python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777 | ||||||
| python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777 | python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777 | ||||||
|  |  | ||||||
|  |  | ||||||
| Run the search strategy in TuNAS: | Run the channel search strategy in TuNAS: | ||||||
|  |  | ||||||
| python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0 | python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0 | ||||||
| python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 | python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 | ||||||
|   | |||||||
| @@ -43,7 +43,7 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf | |||||||
|     # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) |     # alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix) | ||||||
|     # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) |     # alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix) | ||||||
|     # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) |     # alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix) | ||||||
|     alg2name['channel-wise interpaltion'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix) |     alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix) | ||||||
|     alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix) |     alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix) | ||||||
|     alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix) |     alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix) | ||||||
|   for alg, name in alg2name.items(): |   for alg, name in alg2name.items(): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user