102->201 / NAS->autoDL / more configs of TAS / reorganize docs / fix bugs in NAS baselines
This commit is contained in:
parent
33384a78af
commit
bb2f405961
229
README.md
229
README.md
@ -1,159 +1,92 @@
|
|||||||
# Neural Architecture Search (NAS)
|
# Auto Deep Learning (AutoDL)
|
||||||
|
|
||||||
This project contains the following neural architecture search (NAS) algorithms, implemented in [PyTorch](http://pytorch.org).
|
---------
|
||||||
More NAS resources can be found in [Awesome-NAS](https://github.com/D-X-Y/Awesome-NAS).
|
[](LICENSE.md)
|
||||||
|
|
||||||
- NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020
|
Auto Deep Learning by DXY (AutoDL-Projects) is an open source, lightweight, but useful project for researchers.
|
||||||
- Network Pruning via Transformable Architecture Search, NeurIPS 2019
|
In this project, Xuanyi Dong implemented several neural architecture search (NAS) and hyper-parameter optimization (HPO) algorithms.
|
||||||
- One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
|
He hopes to build it as an easy-to-use AutoDL toolkit in future.
|
||||||
- Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
|
|
||||||
- 10 NAS algorithms for the neural topology in `exps/algos` (see [NAS-Bench-102.md](https://github.com/D-X-Y/NAS-Projects/blob/master/NAS-Bench-102.md) for more details)
|
## **Who should consider using AutoDL-Projects**
|
||||||
- Several typical classification models, e.g., ResNet and DenseNet (see [BASELINE.md](https://github.com/D-X-Y/NAS-Projects/blob/master/BASELINE.md))
|
|
||||||
|
- Beginner who want to **try different AutoDL algorithms** for study
|
||||||
|
- Engineer who want to **try AutoDL** to investigate whether AutoDL works on your projects
|
||||||
|
- Researchers who want to **easily** implement and experiement **new** AutoDL algorithms.
|
||||||
|
|
||||||
|
## **Why should we use AutoDL-Projects**
|
||||||
|
- Simplest library dependencies: each examlpe is purely relied on PyTorch or Tensorflow (except for some basic libraries in Anaconda)
|
||||||
|
- All algorithms are in the same codebase. If you implement new algorithms, it is easy to fairly compare with many other baselines.
|
||||||
|
- I will actively support this project, because all my furture AutoDL research will be built upon this project.
|
||||||
|
|
||||||
|
|
||||||
|
## AutoDL-Projects Capabilities
|
||||||
|
|
||||||
|
At the moment, this project provides the following algorithms and scripts to run them. Please see the details in the link provided in the description column.
|
||||||
|
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tbody>
|
||||||
|
<tr align="center" valign="bottom">
|
||||||
|
<th>Type</th>
|
||||||
|
<th>Algorithms</th>
|
||||||
|
<th>Description</th>
|
||||||
|
</tr>
|
||||||
|
<tr> <!-- (1-st row) -->
|
||||||
|
<td rowspan="5" align="center" valign="middle" halign="middle"> NAS </td>
|
||||||
|
<td align="center" valign="middle"> Network Pruning via Transformable Architecture Search </td>
|
||||||
|
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NIPS-2019-TAS.md">NIPS-2019-TAS.md</a> </td>
|
||||||
|
</tr>
|
||||||
|
<tr> <!-- (2-nd row) -->
|
||||||
|
<td align="center" valign="middle"> Searching for A Robust Neural Architecture in Four GPU Hours </td>
|
||||||
|
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/CVPR-2019-GDAS.md">CVPR-2019-GDAS.md</a> </td>
|
||||||
|
</tr>
|
||||||
|
<tr> <!-- (3-rd row) -->
|
||||||
|
<td align="center" valign="middle"> One-Shot Neural Architecture Search via Self-Evaluated Template Network </td>
|
||||||
|
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/ICCV-2019-SETN.py">ICCV-2019-SETN.py</a> </td>
|
||||||
|
</tr>
|
||||||
|
<tr> <!-- (4-th row) -->
|
||||||
|
<td align="center" valign="middle"> NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search </td>
|
||||||
|
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> </td>
|
||||||
|
</tr>
|
||||||
|
<tr> <!-- (5-th row) -->
|
||||||
|
<td align="center" valign="middle"> ENAS / DARTS / REA / REINFORCE / BOHB </td>
|
||||||
|
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> </td>
|
||||||
|
</tr>
|
||||||
|
<tr> <!-- (start second block) -->
|
||||||
|
<td rowspan="1" align="center" valign="middle" halign="middle"> HPO </td>
|
||||||
|
<td align="center" valign="middle"> coming soon </td>
|
||||||
|
<td align="center" valign="middle"> coming soon </a> </td>
|
||||||
|
</tr>
|
||||||
|
<tr> <!-- (start third block) -->
|
||||||
|
<td rowspan="1" align="center" valign="middle" halign="middle"> Basic </td>
|
||||||
|
<td align="center" valign="middle"> Deep Learning-based Image Classification </td>
|
||||||
|
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/master/docs/BASELINE.md">BASELINE.md</a> </a> </td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
|
||||||
|
## History of this repo
|
||||||
|
|
||||||
|
At first, this repo is `GDAS`, which is used to reproduce results in Searching for A Robust Neural Architecture in Four GPU Hours.
|
||||||
|
After that, more functions and more NAS algorithms are continuely added in this repo. After it supports more than five algorithms, it is upgraded from `GDAS` to `NAS-Project`.
|
||||||
|
Now, since both HPO and NAS are supported in this repo, it is upgraded from `NAS-Project` to `AutoDL-Projects`.
|
||||||
|
|
||||||
|
|
||||||
## Requirements and Preparation
|
## Requirements and Preparation
|
||||||
|
|
||||||
Please install `PyTorch>=1.2.0`, `Python>=3.6`, and `opencv`.
|
Please install `Python>=3.6` and `PyTorch>=1.3.0`. (You could also run this project in lower versions of Python and PyTorch, but may have bugs).
|
||||||
|
Some visualization codes may require `opencv`.
|
||||||
|
|
||||||
CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`.
|
CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`.
|
||||||
Some methods use knowledge distillation (KD), which require pre-trained models. Please download these models from [Google Driver](https://drive.google.com/open?id=1ANmiYEGX-IQZTfH8w0aSpj-Wypg-0DR-) (or train by yourself) and save into `.latent-data`.
|
Some methods use knowledge distillation (KD), which require pre-trained models. Please download these models from [Google Drive](https://drive.google.com/open?id=1ANmiYEGX-IQZTfH8w0aSpj-Wypg-0DR-) (or train by yourself) and save into `.latent-data`.
|
||||||
|
|
||||||
### Usefull tools
|
## Citation
|
||||||
1. Compute the number of parameters and FLOPs of a model:
|
|
||||||
```
|
|
||||||
from utils import get_model_infos
|
|
||||||
flop, param = get_model_infos(net, (1,3,32,32))
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/NAS-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
|
|
||||||
|
|
||||||
|
|
||||||
## [NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr)
|
|
||||||
|
|
||||||
We build a new benchmark for neural architecture search, please see more details in [NAS-Bench-102.md](https://github.com/D-X-Y/NAS-Projects/blob/master/NAS-Bench-102.md).
|
|
||||||
|
|
||||||
The benchmark data file (v1.0) is `NAS-Bench-102-v1_0-e61699.pth`, which can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs).
|
|
||||||
|
|
||||||
Now you can simply use our API by `pip install nas-bench-102`.
|
|
||||||
|
|
||||||
## [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717)
|
|
||||||
[](https://paperswithcode.com/sota/network-pruning-on-cifar-100?p=network-pruning-via-transformable)
|
|
||||||
|
|
||||||
In this paper, we proposed a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network.
|
|
||||||
You could see the highlight of our Transformable Architecture Search (TAS) at our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html).
|
|
||||||
|
|
||||||
<p float="left">
|
|
||||||
<img src="https://d-x-y.github.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/>
|
|
||||||
<img src="https://d-x-y.github.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
Use `bash ./scripts/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`.
|
|
||||||
If you do not have `ILSVRC2012` data, pleasee comment L12 in `./scripts/prepare.sh`.
|
|
||||||
|
|
||||||
args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed.
|
|
||||||
|
|
||||||
#### Search for the depth configuration of ResNet:
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-depth-gumbel.sh cifar10 ResNet110 CIFARX 0.57 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Search for the width configuration of ResNet:
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-width-gumbel.sh cifar10 ResNet110 CIFARX 0.57 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Search for both depth and width configuration of ResNet:
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-shape-cifar.sh cifar10 ResNet56 CIFARX 0.47 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Training the searched shape config from TAS
|
|
||||||
If you want to directly train a model with searched configuration of TAS, try these:
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts/tas-infer-train.sh cifar10 C010-ResNet32 -1
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts/tas-infer-train.sh cifar100 C100-ResNet32 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
### Model Configuration
|
|
||||||
The searched shapes for ResNet-20/32/56/110/164 in Table 3 in the original paper are listed in [`configs/NeurIPS-2019`](https://github.com/D-X-Y/NAS-Projects/tree/master/configs/NeurIPS-2019).
|
|
||||||
|
|
||||||
|
|
||||||
## [One-Shot Neural Architecture Search via Self-Evaluated Template Network](https://arxiv.org/abs/1910.05733)
|
|
||||||
|
|
||||||
<img align="right" src="https://d-x-y.github.com/resources/paper-icon/ICCV-2019-SETN.png" width="450">
|
|
||||||
|
|
||||||
<strong>Highlight</strong>: we equip one-shot NAS with an architecture sampler and train network weights using uniformly sampling.
|
|
||||||
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
Please use the following scripts to train the searched SETN-searched CNN on CIFAR-10, CIFAR-100, and ImageNet.
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 SETN 96 -1
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 SETN 96 -1
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
The searching codes of SETN on a small search space:
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## [Searching for A Robust Neural Architecture in Four GPU Hours](https://arxiv.org/abs/1910.04465)
|
|
||||||
|
|
||||||
|
|
||||||
<img align="right" src="https://d-x-y.github.com/resources/paper-icon/CVPR-2019-GDAS.png" width="300">
|
|
||||||
|
|
||||||
We proposed a Gradient-based searching algorithm using Differentiable Architecture Sampling (GDAS). GDAS is baseed on DARTS and improves it with Gumbel-softmax sampling.
|
|
||||||
Experiments on CIFAR-10, CIFAR-100, ImageNet, PTB, and WT2 are reported.
|
|
||||||
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
#### Reproducing the results of our searched architecture in GDAS
|
|
||||||
Please use the following scripts to train the searched GDAS-searched CNN on CIFAR-10, CIFAR-100, and ImageNet.
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 GDAS_V1 96 -1
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 GDAS_V1 96 -1
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_V1 256 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Searching on the NASNet search space
|
|
||||||
Please use the following scripts to use GDAS to search as in the original paper:
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Searching on a small search space (NAS-Bench-102)
|
|
||||||
The GDAS searching codes on a small search space:
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
The baseline searching codes are DARTS:
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Training the searched architecture
|
|
||||||
To train the searched architecture found by the above scripts, please use the following codes:
|
|
||||||
```
|
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5
|
|
||||||
```
|
|
||||||
`|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|` represents the structure of a searched architecture. My codes will automatically print it during the searching procedure.
|
|
||||||
|
|
||||||
|
|
||||||
# Citation
|
|
||||||
|
|
||||||
If you find that this project helps your research, please consider citing some of the following papers:
|
If you find that this project helps your research, please consider citing some of the following papers:
|
||||||
```
|
```
|
||||||
@inproceedings{dong2020nasbench102,
|
@inproceedings{dong2020nasbench201,
|
||||||
title = {NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search},
|
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
|
||||||
author = {Dong, Xuanyi and Yang, Yi},
|
author = {Dong, Xuanyi and Yang, Yi},
|
||||||
booktitle = {International Conference on Learning Representations (ICLR)},
|
booktitle = {International Conference on Learning Representations (ICLR)},
|
||||||
url = {https://openreview.net/forum?id=HJxyZkBKDr},
|
url = {https://openreview.net/forum?id=HJxyZkBKDr},
|
||||||
@ -180,3 +113,11 @@ If you find that this project helps your research, please consider citing some o
|
|||||||
year = {2019}
|
year = {2019}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Related Projects
|
||||||
|
|
||||||
|
- [Awesome-NAS](https://github.com/D-X-Y/Awesome-NAS) : A curated list of neural architecture search and related resources.
|
||||||
|
- [AutoML Freiburg-Hannover](https://www.automl.org/) : A website maintained by Frank Hutter's team, containing many AutoML resources.
|
||||||
|
|
||||||
|
# License
|
||||||
|
The entire codebase is under [MIT license](LICENSE.md)
|
||||||
|
14
configs/NeurIPS-2019/ImageNet-ResNet18V1.config
Normal file
14
configs/NeurIPS-2019/ImageNet-ResNet18V1.config
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"dataset" : ["str" , "imagenet"],
|
||||||
|
"arch" : ["str" , "resnet"],
|
||||||
|
"block_name" : ["str" , "BasicBlock"],
|
||||||
|
"layers" : ["int" , ["2", "2", "2", "2"]],
|
||||||
|
"deep_stem" : ["bool" , "0"],
|
||||||
|
"zero_init_residual" : ["bool" , "1"],
|
||||||
|
"class_num" : ["int" , "1000"],
|
||||||
|
"search_mode" : ["str" , "shape"],
|
||||||
|
"xchannels" : ["int" , ["3", "64", "25", "64", "38", "19", "128", "128", "38", "38", "256", "256", "256", "256", "512", "512", "512", "512"]],
|
||||||
|
"xblocks" : ["int" , ["1", "1", "2", "2"]],
|
||||||
|
"super_type" : ["str" , "infer-shape"],
|
||||||
|
"estimated_FLOP" : ["float" , "1120.44032"]
|
||||||
|
}
|
14
configs/NeurIPS-2019/ImageNet-ResNet50V1.config
Normal file
14
configs/NeurIPS-2019/ImageNet-ResNet50V1.config
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"dataset" : ["str" , "imagenet"],
|
||||||
|
"arch" : ["str" , "resnet"],
|
||||||
|
"block_name" : ["str" , "Bottleneck"],
|
||||||
|
"layers" : ["int" , ["3", "4", "6", "3"]],
|
||||||
|
"deep_stem" : ["bool" , "0"],
|
||||||
|
"zero_init_residual" : ["bool" , "1"],
|
||||||
|
"class_num" : ["int" , "1000"],
|
||||||
|
"search_mode" : ["str" , "shape"],
|
||||||
|
"xchannels" : ["int" , ["3", "45", "45", "30", "102", "33", "60", "154", "68", "70", "180", "38", "38", "307", "38", "38", "410", "64", "128", "358", "38", "51", "256", "76", "76", "512", "76", "76", "512", "179", "256", "614", "100", "102", "307", "179", "230", "614", "204", "102", "307", "153", "153", "1228", "512", "512", "1434", "512", "512", "1844"]],
|
||||||
|
"xblocks" : ["int" , ["3", "4", "5", "3"]],
|
||||||
|
"super_type" : ["str" , "infer-shape"],
|
||||||
|
"estimated_FLOP" : ["float" , "2291.316289"]
|
||||||
|
}
|
76
docs/CVPR-2019-GDAS.md
Normal file
76
docs/CVPR-2019-GDAS.md
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# [Searching for A Robust Neural Architecture in Four GPU Hours](https://arxiv.org/abs/1910.04465)
|
||||||
|
|
||||||
|
<img align="right" src="https://d-x-y.github.com/resources/paper-icon/CVPR-2019-GDAS.png" width="300">
|
||||||
|
|
||||||
|
Searching for A Robust Neural Architecture in Four GPU Hours is accepted at CVPR 2019.
|
||||||
|
In this paper, we proposed a Gradient-based searching algorithm using Differentiable Architecture Sampling (GDAS).
|
||||||
|
GDAS is baseed on DARTS and improves it with Gumbel-softmax sampling.
|
||||||
|
Concurrently at the submission period, several NAS papers (SNAS and FBNet) also utilized Gumbel-softmax sampling. We are different at how to forward and backward, see more details in our paper and codes.
|
||||||
|
Experiments on CIFAR-10, CIFAR-100, ImageNet, PTB, and WT2 are reported.
|
||||||
|
|
||||||
|
|
||||||
|
## Requirements and Preparation
|
||||||
|
|
||||||
|
Please install `Python>=3.6` and `PyTorch>=1.2.0`.
|
||||||
|
|
||||||
|
CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`.
|
||||||
|
|
||||||
|
### Usefull tools
|
||||||
|
1. Compute the number of parameters and FLOPs of a model:
|
||||||
|
```
|
||||||
|
from utils import get_model_infos
|
||||||
|
flop, param = get_model_infos(net, (1,3,32,32))
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Reproducing the results of our searched architecture in GDAS
|
||||||
|
Please use the following scripts to train the searched GDAS-searched CNN on CIFAR-10, CIFAR-100, and ImageNet.
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 GDAS_V1 96 -1
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 GDAS_V1 96 -1
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_V1 256 -1
|
||||||
|
```
|
||||||
|
If you are interested in the configs of each NAS-searched architecture, they are defined at [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
|
||||||
|
|
||||||
|
### Searching on the NASNet search space
|
||||||
|
Please use the following scripts to use GDAS to search as in the original paper:
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/GDAS-search-NASNet-space.sh cifar10 1 -1
|
||||||
|
```
|
||||||
|
If you want to train the searched architecture found by the above scripts, you need to add the config of that architecture (will be printed in log) in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
|
||||||
|
|
||||||
|
### Searching on a small search space (NAS-Bench-201)
|
||||||
|
The GDAS searching codes on a small search space:
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
The baseline searching codes are DARTS:
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
**After searching**, if you want to train the searched architecture found by the above scripts, please use the following codes:
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-201/train-a-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5
|
||||||
|
```
|
||||||
|
`|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|` represents the structure of a searched architecture. My codes will automatically print it during the searching procedure.
|
||||||
|
|
||||||
|
|
||||||
|
# Citation
|
||||||
|
|
||||||
|
If you find that this project helps your research, please consider citing the following paper:
|
||||||
|
```
|
||||||
|
@inproceedings{dong2019search,
|
||||||
|
title = {Searching for A Robust Neural Architecture in Four GPU Hours},
|
||||||
|
author = {Dong, Xuanyi and Yang, Yi},
|
||||||
|
booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||||
|
pages = {1761--1770},
|
||||||
|
year = {2019}
|
||||||
|
}
|
||||||
|
```
|
50
docs/ICCV-2019-SETN.py
Normal file
50
docs/ICCV-2019-SETN.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
# [One-Shot Neural Architecture Search via Self-Evaluated Template Network](https://arxiv.org/abs/1910.05733)
|
||||||
|
|
||||||
|
<img align="right" src="https://d-x-y.github.com/resources/paper-icon/ICCV-2019-SETN.png" width="450">
|
||||||
|
|
||||||
|
<strong>Highlight</strong>: we equip one-shot NAS with an architecture sampler and train network weights using uniformly sampling.
|
||||||
|
|
||||||
|
One-Shot Neural Architecture Search via Self-Evaluated Template Network is accepted by ICCV 2019.
|
||||||
|
|
||||||
|
|
||||||
|
## Requirements and Preparation
|
||||||
|
|
||||||
|
Please install `Python>=3.6` and `PyTorch>=1.2.0`.
|
||||||
|
|
||||||
|
### Usefull tools
|
||||||
|
1. Compute the number of parameters and FLOPs of a model:
|
||||||
|
```
|
||||||
|
from utils import get_model_infos
|
||||||
|
flop, param = get_model_infos(net, (1,3,32,32))
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
|
||||||
|
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
Please use the following scripts to train the searched SETN-searched CNN on CIFAR-10, CIFAR-100, and ImageNet.
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar10 SETN 96 -1
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 SETN 96 -1
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
The searching codes of SETN on a small search space (NAS-Bench-201).
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
# Citation
|
||||||
|
|
||||||
|
If you find that this project helps your research, please consider citing the following paper:
|
||||||
|
```
|
||||||
|
@inproceedings{dong2019one,
|
||||||
|
title = {One-Shot Neural Architecture Search via Self-Evaluated Template Network},
|
||||||
|
author = {Dong, Xuanyi and Yang, Yi},
|
||||||
|
booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
|
||||||
|
pages = {3681--3690},
|
||||||
|
year = {2019}
|
||||||
|
}
|
||||||
|
```
|
@ -1,40 +1,40 @@
|
|||||||
# [NAS-BENCH-102: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr)
|
# [NAS-BENCH-201: Extending the Scope of Reproducible Neural Architecture Search](https://openreview.net/forum?id=HJxyZkBKDr)
|
||||||
|
|
||||||
We propose an algorithm-agnostic NAS benchmark (NAS-Bench-102) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithms.
|
We propose an algorithm-agnostic NAS benchmark (NAS-Bench-201) with a fixed search space, which provides a unified benchmark for almost any up-to-date NAS algorithms.
|
||||||
The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph.
|
The design of our search space is inspired by that used in the most popular cell-based searching algorithms, where a cell is represented as a directed acyclic graph.
|
||||||
Each edge here is associated with an operation selected from a predefined operation set.
|
Each edge here is associated with an operation selected from a predefined operation set.
|
||||||
For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-102 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total.
|
For it to be applicable for all NAS algorithms, the search space defined in NAS-Bench-201 includes 4 nodes and 5 associated operation options, which generates 15,625 neural cell candidates in total.
|
||||||
|
|
||||||
In this Markdown file, we provide:
|
In this Markdown file, we provide:
|
||||||
- [How to Use NAS-Bench-102](#how-to-use-nas-bench-102)
|
- [How to Use NAS-Bench-201](#how-to-use-nas-bench-201)
|
||||||
- [Instruction to re-generate NAS-Bench-102](#instruction-to-re-generate-nas-bench-102)
|
- [Instruction to re-generate NAS-Bench-201](#instruction-to-re-generate-nas-bench-201)
|
||||||
- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-102)
|
- [10 NAS algorithms evaluated in our paper](#to-reproduce-10-baseline-nas-algorithms-in-nas-bench-201)
|
||||||
|
|
||||||
Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
Note: please use `PyTorch >= 1.2.0` and `Python >= 3.6.0`.
|
||||||
|
|
||||||
Simply type `pip install nas-bench-102` to install our api.
|
Simply type `pip install nas-bench-201` to install our api.
|
||||||
|
|
||||||
If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/NAS-Projects/issues) or email me.
|
If you have any questions or issues, please post it at [here](https://github.com/D-X-Y/AutoDL-Projects/issues) or email me.
|
||||||
|
|
||||||
### Preparation and Download
|
### Preparation and Download
|
||||||
|
|
||||||
The benchmark file of NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
|
The benchmark file of NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs) or [Baidu-Wangpan (code:6u5d)](https://pan.baidu.com/s/1CiaNH6C12zuZf7q-Ilm09w).
|
||||||
You can move it to anywhere you want and send its path to our API for initialization.
|
You can move it to anywhere you want and send its path to our API for initialization.
|
||||||
- v1.0: `NAS-Bench-102-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
|
- v1.0: `NAS-Bench-201-v1_0-e61699.pth`, where `e61699` is the last six digits for this file. It contains all information except for the trained weights of each trial.
|
||||||
- v1.0: The full data of each architecture can be download from [Google Drive](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
|
- v1.0: The full data of each architecture can be download from [Google Drive](https://drive.google.com/open?id=1X2i-JXaElsnVLuGgM4tP-yNwtsspXgdQ) (about 226GB). This compressed folder has 15625 files containing the the trained weights.
|
||||||
- v1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
|
- v1.0: Checkpoints for 3 runs of each baseline NAS algorithm are provided in [Google Drive](https://drive.google.com/open?id=1eAgLZQAViP3r6dA0_ZOOGG9zPLXhGwXi).
|
||||||
|
|
||||||
The training and evaluation data used in NAS-Bench-102 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
|
The training and evaluation data used in NAS-Bench-201 can be downloaded from [Google Drive](https://drive.google.com/open?id=1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7) or [Baidu-Wangpan (code:4fg7)](https://pan.baidu.com/s/1XAzavPKq3zcat1yBA1L2tQ).
|
||||||
It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-102 or similar NAS datasets or training models by yourself, you need these data.
|
It is recommended to put these data into `$TORCH_HOME` (`~/.torch/` by default). If you want to generate NAS-Bench-201 or similar NAS datasets or training models by yourself, you need these data.
|
||||||
|
|
||||||
## How to Use NAS-Bench-102
|
## How to Use NAS-Bench-201
|
||||||
|
|
||||||
1. Creating an API instance from a file:
|
1. Creating an API instance from a file:
|
||||||
```
|
```
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
api = API('$path_to_meta_nas_bench_file')
|
api = API('$path_to_meta_nas_bench_file')
|
||||||
api = API('NAS-Bench-102-v1_0-e61699.pth')
|
api = API('NAS-Bench-201-v1_0-e61699.pth')
|
||||||
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth'))
|
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'))
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Show the number of architectures `len(api)` and each architecture `api[i]`:
|
2. Show the number of architectures `len(api)` and each architecture `api[i]`:
|
||||||
@ -72,16 +72,16 @@ index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1
|
|||||||
api.show(index)
|
api.show(index)
|
||||||
```
|
```
|
||||||
|
|
||||||
5. For other usages, please see `lib/nas_102_api/api.py`
|
5. For other usages, please see `lib/nas_201_api/api.py`
|
||||||
|
|
||||||
|
|
||||||
### Detailed Instruction
|
### Detailed Instruction
|
||||||
|
|
||||||
In `nas_102_api`, we define three classes: `NASBench102API`, `ArchResults`, `ResultsCount`.
|
In `nas_201_api`, we define three classes: `NASBench201API`, `ArchResults`, `ResultsCount`.
|
||||||
|
|
||||||
`ResultsCount` maintains all information of a specific trial. One can instantiate ResultsCount and get the info via the following codes (`000157-FULL.pth` saves all information of all trials of 157-th architecture):
|
`ResultsCount` maintains all information of a specific trial. One can instantiate ResultsCount and get the info via the following codes (`000157-FULL.pth` saves all information of all trials of 157-th architecture):
|
||||||
```
|
```
|
||||||
from nas_102_api import ResultsCount
|
from nas_201_api import ResultsCount
|
||||||
xdata = torch.load('000157-FULL.pth')
|
xdata = torch.load('000157-FULL.pth')
|
||||||
odata = xdata['full']['all_results'][('cifar10-valid', 777)]
|
odata = xdata['full']['all_results'][('cifar10-valid', 777)]
|
||||||
result = ResultsCount.create_from_state_dict( odata )
|
result = ResultsCount.create_from_state_dict( odata )
|
||||||
@ -100,7 +100,7 @@ network.load_state_dict(result.get_net_param())
|
|||||||
|
|
||||||
`ArchResults` maintains all information of all trials of an architecture. Please see the following usages:
|
`ArchResults` maintains all information of all trials of an architecture. Please see the following usages:
|
||||||
```
|
```
|
||||||
from nas_102_api import ArchResults
|
from nas_201_api import ArchResults
|
||||||
xdata = torch.load('000157-FULL.pth')
|
xdata = torch.load('000157-FULL.pth')
|
||||||
archRes = ArchResults.create_from_state_dict(xdata['less']) # load trials trained with 12 epochs
|
archRes = ArchResults.create_from_state_dict(xdata['less']) # load trials trained with 12 epochs
|
||||||
archRes = ArchResults.create_from_state_dict(xdata['full']) # load trials trained with 200 epochs
|
archRes = ArchResults.create_from_state_dict(xdata['full']) # load trials trained with 200 epochs
|
||||||
@ -112,28 +112,30 @@ print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False)) # print the
|
|||||||
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss/accuracy/time of a randomly selected trial
|
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss/accuracy/time of a randomly selected trial
|
||||||
```
|
```
|
||||||
|
|
||||||
`NASBench102API` is the topest level api. Please see the following usages:
|
`NASBench201API` is the topest level api. Please see the following usages:
|
||||||
```
|
```
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
api = API('NAS-Bench-102-v1_0-e61699.pth') # This will load all the information of NAS-Bench-102 except the trained weights
|
api = API('NAS-Bench-201-v1_0-e61699.pth') # This will load all the information of NAS-Bench-201 except the trained weights
|
||||||
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-102-v1_0-e61699.pth')) # The same as the above line while I usually save NAS-Bench-102-v1_0-e61699.pth in ~/.torch/.
|
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth')) # The same as the above line while I usually save NAS-Bench-201-v1_0-e61699.pth in ~/.torch/.
|
||||||
api.show(-1) # show info of all architectures
|
api.show(-1) # show info of all architectures
|
||||||
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-102-4-v1.0-archive'), 3) # This code will reload the information 3-th architecture with the trained weights
|
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-201-4-v1.0-archive'), 3) # This code will reload the information 3-th architecture with the trained weights
|
||||||
|
|
||||||
weights = api.get_net_param(3, 'cifar10', None) # Obtaining the weights of all trials for the 3-th architecture on cifar10. It will returns a dict, where the key is the seed and the value is the trained weights.
|
weights = api.get_net_param(3, 'cifar10', None) # Obtaining the weights of all trials for the 3-th architecture on cifar10. It will returns a dict, where the key is the seed and the value is the trained weights.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Instruction to Re-Generate NAS-Bench-102
|
## Instruction to Re-Generate NAS-Bench-201
|
||||||
|
|
||||||
1. generate the meta file for NAS-Bench-102 using the following script, where `NAS-BENCH-102` indicates the name and `4` indicates the maximum number of nodes in a cell.
|
There are four steps to build NAS-Bench-201.
|
||||||
|
|
||||||
|
1. generate the meta file for NAS-Bench-201 using the following script, where `NAS-BENCH-201` indicates the name and `4` indicates the maximum number of nodes in a cell.
|
||||||
```
|
```
|
||||||
bash scripts-search/NAS-Bench-102/meta-gen.sh NAS-BENCH-102 4
|
bash scripts-search/NAS-Bench-201/meta-gen.sh NAS-BENCH-201 4
|
||||||
```
|
```
|
||||||
|
|
||||||
2. train earch architecture on a single GPU (see commands in `output/NAS-BENCH-102-4/BENCH-102-N4.opt-full.script`, which is automatically generated by step-1).
|
2. train earch architecture on a single GPU (see commands in `output/NAS-BENCH-201-4/BENCH-201-N4.opt-full.script`, which is automatically generated by step-1).
|
||||||
```
|
```
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-models.sh 0 0 389 -1 '777 888 999'
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-201/train-models.sh 0 0 389 -1 '777 888 999'
|
||||||
```
|
```
|
||||||
This command will train 390 architectures (id from 0 to 389) using the following four kinds of splits with three random seeds (777, 888, 999).
|
This command will train 390 architectures (id from 0 to 389) using the following four kinds of splits with three random seeds (777, 888, 999).
|
||||||
|
|
||||||
@ -144,54 +146,55 @@ This command will train 390 architectures (id from 0 to 389) using the following
|
|||||||
| CIFAR-100 | train | valid / test |
|
| CIFAR-100 | train | valid / test |
|
||||||
| ImageNet-16-120 | train | valid / test |
|
| ImageNet-16-120 | train | valid / test |
|
||||||
|
|
||||||
Note that the above `train`, `valid`, and `test` indicate the proposed splits in our NAS-Bench-102, and they might be different with the original splits.
|
Note that the above `train`, `valid`, and `test` indicate the proposed splits in our NAS-Bench-201, and they might be different with the original splits.
|
||||||
|
|
||||||
3. calculate the latency, merge the results of all architectures, and simplify the results.
|
3. calculate the latency, merge the results of all architectures, and simplify the results.
|
||||||
(see commands in `output/NAS-BENCH-102-4/meta-node-4.cal-script.txt` which is automatically generated by step-1).
|
(see commands in `output/NAS-BENCH-201-4/meta-node-4.cal-script.txt` which is automatically generated by step-1).
|
||||||
```
|
```
|
||||||
OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python exps/NAS-Bench-102/statistics.py --mode cal --target_dir 000000-000389-C16-N5
|
OMP_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python exps/NAS-Bench-201/statistics.py --mode cal --target_dir 000000-000389-C16-N5
|
||||||
```
|
```
|
||||||
|
|
||||||
4. merge all results into a single file for NAS-Bench-102-API.
|
4. merge all results into a single file for NAS-Bench-201-API.
|
||||||
```
|
```
|
||||||
OMP_NUM_THREADS=4 python exps/NAS-Bench-102/statistics.py --mode merge
|
OMP_NUM_THREADS=4 python exps/NAS-Bench-201/statistics.py --mode merge
|
||||||
```
|
```
|
||||||
This command will generate a single file `output/NAS-BENCH-102-4/simplifies/C16-N5-final-infos.pth` contains all the data for NAS-Bench-102.
|
This command will generate a single file `output/NAS-BENCH-201-4/simplifies/C16-N5-final-infos.pth` contains all the data for NAS-Bench-201.
|
||||||
This generated file will serve as the input for our NAS-Bench-102 API.
|
This generated file will serve as the input for our NAS-Bench-201 API.
|
||||||
|
|
||||||
[option] train a single architecture on a single GPU.
|
[option] train a single architecture on a single GPU.
|
||||||
```
|
```
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh resnet 16 5
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-201/train-a-net.sh resnet 16 5
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-102/train-a-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-201/train-a-net.sh '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|' 16 5
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## To Reproduce 10 Baseline NAS Algorithms in NAS-Bench-102
|
## To Reproduce 10 Baseline NAS Algorithms in NAS-Bench-201
|
||||||
|
|
||||||
We have tried our best to implement each method. However, still, some algorithms might obtain non-optimal results since their hyper-parameters might not fit our NAS-Bench-102.
|
We have tried our best to implement each method. However, still, some algorithms might obtain non-optimal results since their hyper-parameters might not fit our NAS-Bench-201.
|
||||||
If researchers can provide better results with different hyper-parameters, we are happy to update results according to the new experimental results. We also welcome more NAS algorithms to test on our dataset and would include them accordingly.
|
If researchers can provide better results with different hyper-parameters, we are happy to update results according to the new experimental results. We also welcome more NAS algorithms to test on our dataset and would include them accordingly.
|
||||||
|
|
||||||
**Note that** you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download)
|
**Note that** you need to prepare the training and test data as described in [Preparation and Download](#preparation-and-download)
|
||||||
|
|
||||||
- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1`, where `cifar10` can be replaced with `cifar100` or `ImageNet16-120`.
|
- [1] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 1 -1`, where `cifar10` can be replaced with `cifar100` or `ImageNet16-120`.
|
||||||
- [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1`
|
- [2] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1`
|
||||||
- [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1`
|
- [3] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1`
|
||||||
- [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1`
|
- [4] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 1 -1`
|
||||||
- [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 -1`
|
- [5] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/ENAS.sh cifar10 1 -1`
|
||||||
- [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 -1`
|
- [6] `CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 1 -1`
|
||||||
- [7] `bash ./scripts-search/algos/R-EA.sh -1`
|
- [7] `bash ./scripts-search/algos/R-EA.sh -1`
|
||||||
- [8] `bash ./scripts-search/algos/Random.sh -1`
|
- [8] `bash ./scripts-search/algos/Random.sh -1`
|
||||||
- [9] `bash ./scripts-search/algos/REINFORCE.sh -1`
|
- [9] `bash ./scripts-search/algos/REINFORCE.sh 0.5 -1`
|
||||||
- [10] `bash ./scripts-search/algos/BOHB.sh -1`
|
- [10] `bash ./scripts-search/algos/BOHB.sh -1`
|
||||||
|
|
||||||
|
In commands [1-6], the first args `cifar10` indicates the dataset name, the second args `1` indicates the behavior of BN, and the first args `-1` indicates the random seed.
|
||||||
|
|
||||||
|
|
||||||
# Citation
|
# Citation
|
||||||
|
|
||||||
If you find that NAS-Bench-102 helps your research, please consider citing it:
|
If you find that NAS-Bench-201 helps your research, please consider citing it:
|
||||||
```
|
```
|
||||||
@inproceedings{dong2020nasbench102,
|
@inproceedings{dong2020nasbench201,
|
||||||
title = {NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search},
|
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
|
||||||
author = {Dong, Xuanyi and Yang, Yi},
|
author = {Dong, Xuanyi and Yang, Yi},
|
||||||
booktitle = {International Conference on Learning Representations (ICLR)},
|
booktitle = {International Conference on Learning Representations (ICLR)},
|
||||||
url = {https://openreview.net/forum?id=HJxyZkBKDr},
|
url = {https://openreview.net/forum?id=HJxyZkBKDr},
|
71
docs/NIPS-2019-TAS.md
Normal file
71
docs/NIPS-2019-TAS.md
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# [Network Pruning via Transformable Architecture Search](https://arxiv.org/abs/1905.09717)
|
||||||
|
|
||||||
|
[](https://paperswithcode.com/sota/network-pruning-on-cifar-100?p=network-pruning-via-transformable)
|
||||||
|
|
||||||
|
Network Pruning via Transformable Architecture Search is accepted by NeurIPS 2019.
|
||||||
|
In this paper, we proposed a differentiable searching strategy for transformable architectures, i.e., searching for the depth and width of a deep neural network.
|
||||||
|
You could see the highlight of our Transformable Architecture Search (TAS) at our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html).
|
||||||
|
|
||||||
|
<p float="left">
|
||||||
|
<img src="https://d-x-y.github.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/>
|
||||||
|
<img src="https://d-x-y.github.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
|
||||||
|
## Requirements and Preparation
|
||||||
|
|
||||||
|
Please install `Python>=3.6` and `PyTorch>=1.2.0`.
|
||||||
|
|
||||||
|
CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`.
|
||||||
|
The proposed method utilized knowledge distillation (KD), which require pre-trained models. Please download these models from [Google Drive](https://drive.google.com/open?id=1ANmiYEGX-IQZTfH8w0aSpj-Wypg-0DR-) (or train by yourself) and save into `.latent-data`.
|
||||||
|
|
||||||
|
**LOGS**:
|
||||||
|
We provide some logs at [Google Drive](https://drive.google.com/open?id=1_qUY4DTtuW_l6ZonynQAC9ttqy35fxZ-). It includes (1) logs of training searched shape of ResNet-18 and ResNet-50 on ImageNet, (2) logs of searching and training for ResNet-164 on CIFAR, (3) logs of searching and training for ResNet56 on CIFAR-10, (4) logs of searching and training for ResNet110 on CIFAR-100.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
Use `bash ./scripts/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`.
|
||||||
|
If you do not have `ILSVRC2012` data, pleasee comment L12 in `./scripts/prepare.sh`.
|
||||||
|
|
||||||
|
args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed.
|
||||||
|
|
||||||
|
**Model Configuration**
|
||||||
|
|
||||||
|
The searched shapes for ResNet-20/32/56/110/164 and ResNet-18/50 in Table 3/4 in the original paper are listed in [`configs/NeurIPS-2019`](https://github.com/D-X-Y/AutoDL-Projects/tree/master/configs/NeurIPS-2019).
|
||||||
|
|
||||||
|
**Search for the depth configuration of ResNet**
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-depth-gumbel.sh cifar10 ResNet110 CIFARX 0.57 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Search for the width configuration of ResNet**
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-width-gumbel.sh cifar10 ResNet110 CIFARX 0.57 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Search for both depth and width configuration of ResNet**
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts-search/search-shape-cifar.sh cifar10 ResNet56 CIFARX 0.47 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
**Training the searched shape config from TAS:**
|
||||||
|
If you want to directly train a model with searched configuration of TAS, try these:
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts/tas-infer-train.sh cifar10 C010-ResNet32 -1
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 bash ./scripts/tas-infer-train.sh cifar100 C100-ResNet32 -1
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/tas-infer-train.sh imagenet-1k ImageNet-ResNet18V1 -1
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/tas-infer-train.sh imagenet-1k ImageNet-ResNet50V1 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
# Citation
|
||||||
|
|
||||||
|
If you find that this project helps your research, please consider citing the following paper:
|
||||||
|
```
|
||||||
|
@inproceedings{dong2019tas,
|
||||||
|
title = {Network Pruning via Transformable Architecture Search},
|
||||||
|
author = {Dong, Xuanyi and Yang, Yi},
|
||||||
|
booktitle = {Neural Information Processing Systems (NeurIPS)},
|
||||||
|
year = {2019}
|
||||||
|
}
|
||||||
|
```
|
@ -48,7 +48,7 @@ def main(xargs):
|
|||||||
# Create an instance of the model
|
# Create an instance of the model
|
||||||
config = dict2config({'name': 'GDAS',
|
config = dict2config({'name': 'GDAS',
|
||||||
'C' : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes,
|
'C' : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes,
|
||||||
'num_classes': 10, 'space': 'nas-bench-102', 'affine': True}, None)
|
'num_classes': 10, 'space': 'nas-bench-201', 'affine': True}, None)
|
||||||
model = get_cell_based_tiny_net(config)
|
model = get_cell_based_tiny_net(config)
|
||||||
#import pdb; pdb.set_trace()
|
#import pdb; pdb.set_trace()
|
||||||
#model.build(((64, 32, 32, 3), (1,)))
|
#model.build(((64, 32, 32, 3), (1,)))
|
||||||
@ -126,7 +126,7 @@ def main(xargs):
|
|||||||
print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas()))
|
print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas()))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
# training details
|
# training details
|
||||||
parser.add_argument('--epochs' , type=int , default= 250 , help='')
|
parser.add_argument('--epochs' , type=int , default= 250 , help='')
|
||||||
parser.add_argument('--tau_max' , type=float, default= 10 , help='')
|
parser.add_argument('--tau_max' , type=float, default= 10 , help='')
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
##################################################
|
||||||
# python exps/NAS-Bench-102/check.py --base_save_dir
|
# python exps/NAS-Bench-201/check.py --base_save_dir
|
||||||
##################################################
|
##################################################
|
||||||
import os, sys, time, argparse, collections
|
import os, sys, time, argparse, collections
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -67,8 +67,8 @@ def check_files(save_dir, meta_file, basestr):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='NAS Benchmark 102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(description='NAS Benchmark 201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-102-4', help='The base-name of folder to save checkpoints and log.')
|
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
|
||||||
parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell.')
|
parser.add_argument('--max_node', type=int, default=4, help='The maximum node in a cell.')
|
||||||
parser.add_argument('--channel', type=int, default=16, help='The number of channels.')
|
parser.add_argument('--channel', type=int, default=16, help='The number of channels.')
|
||||||
parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.')
|
parser.add_argument('--num_cells', type=int, default=5, help='The number of cells in one stage.')
|
||||||
@ -78,7 +78,7 @@ if __name__ == '__main__':
|
|||||||
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
|
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
|
||||||
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
|
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
|
||||||
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
|
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
|
||||||
print ('check NAS-Bench-102 in {:}'.format(save_dir))
|
print ('check NAS-Bench-201 in {:}'.format(save_dir))
|
||||||
|
|
||||||
basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
|
basestr = 'C{:}-N{:}'.format(args.channel, args.num_cells)
|
||||||
check_files(save_dir, meta_path, basestr)
|
check_files(save_dir, meta_path, basestr)
|
@ -8,15 +8,15 @@ def read(fname='README.md'):
|
|||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name = "nas_bench_102",
|
name = "nas_bench_201",
|
||||||
version = "1.0",
|
version = "1.0",
|
||||||
author = "Xuanyi Dong",
|
author = "Xuanyi Dong",
|
||||||
author_email = "dongxuanyi888@gmail.com",
|
author_email = "dongxuanyi888@gmail.com",
|
||||||
description = "API for NAS-Bench-102 (a benchmark for neural architecture search).",
|
description = "API for NAS-Bench-201 (a benchmark for neural architecture search).",
|
||||||
license = "MIT",
|
license = "MIT",
|
||||||
keywords = "NAS Dataset API DeepLearning",
|
keywords = "NAS Dataset API DeepLearning",
|
||||||
url = "https://github.com/D-X-Y/NAS-Projects",
|
url = "https://github.com/D-X-Y/NAS-Bench-201",
|
||||||
packages=['nas_102_api'],
|
packages=['nas_201_api'],
|
||||||
long_description=read('README.md'),
|
long_description=read('README.md'),
|
||||||
long_description_content_type='text/markdown',
|
long_description_content_type='text/markdown',
|
||||||
classifiers=[
|
classifiers=[
|
@ -1,5 +1,5 @@
|
|||||||
###############################################################
|
###############################################################
|
||||||
# NAS-Bench-102, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||||
###############################################################
|
###############################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019-2020 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019-2020 #
|
||||||
###############################################################
|
###############################################################
|
||||||
@ -213,7 +213,7 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, se
|
|||||||
|
|
||||||
|
|
||||||
def generate_meta_info(save_dir, max_node, divide=40):
|
def generate_meta_info(save_dir, max_node, divide=40):
|
||||||
aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-102')
|
aa_nas_bench_ss = get_search_spaces('cell', 'nas-bench-201')
|
||||||
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
|
archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False)
|
||||||
print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2)))
|
print ('There are {:} archs vs {:}.'.format(len(archs), len(aa_nas_bench_ss) ** ((max_node-1)*max_node/2)))
|
||||||
|
|
||||||
@ -249,15 +249,15 @@ def generate_meta_info(save_dir, max_node, divide=40):
|
|||||||
torch.save(info, save_name)
|
torch.save(info, save_name)
|
||||||
print ('save the meta file into {:}'.format(save_name))
|
print ('save the meta file into {:}'.format(save_name))
|
||||||
|
|
||||||
script_name_full = save_dir / 'BENCH-102-N{:}.opt-full.script'.format(max_node)
|
script_name_full = save_dir / 'BENCH-201-N{:}.opt-full.script'.format(max_node)
|
||||||
script_name_less = save_dir / 'BENCH-102-N{:}.opt-less.script'.format(max_node)
|
script_name_less = save_dir / 'BENCH-201-N{:}.opt-less.script'.format(max_node)
|
||||||
full_file = open(str(script_name_full), 'w')
|
full_file = open(str(script_name_full), 'w')
|
||||||
less_file = open(str(script_name_less), 'w')
|
less_file = open(str(script_name_less), 'w')
|
||||||
gaps = total_arch // divide
|
gaps = total_arch // divide
|
||||||
for start in range(0, total_arch, gaps):
|
for start in range(0, total_arch, gaps):
|
||||||
xend = min(start+gaps, total_arch)
|
xend = min(start+gaps, total_arch)
|
||||||
full_file.write('bash ./scripts-search/NAS-Bench-102/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1))
|
full_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 0 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1))
|
||||||
less_file.write('bash ./scripts-search/NAS-Bench-102/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1))
|
less_file.write('bash ./scripts-search/NAS-Bench-201/train-models.sh 1 {:5d} {:5d} -1 \'777 888 999\'\n'.format(start, xend-1))
|
||||||
print ('save the training script into {:} and {:}'.format(script_name_full, script_name_less))
|
print ('save the training script into {:} and {:}'.format(script_name_full, script_name_less))
|
||||||
full_file.close()
|
full_file.close()
|
||||||
less_file.close()
|
less_file.close()
|
||||||
@ -267,14 +267,14 @@ def generate_meta_info(save_dir, max_node, divide=40):
|
|||||||
with open(str(script_name), 'w') as cfile:
|
with open(str(script_name), 'w') as cfile:
|
||||||
for start in range(0, total_arch, gaps):
|
for start in range(0, total_arch, gaps):
|
||||||
xend = min(start+gaps, total_arch)
|
xend = min(start+gaps, total_arch)
|
||||||
cfile.write('{:} python exps/NAS-Bench-102/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1))
|
cfile.write('{:} python exps/NAS-Bench-201/statistics.py --mode cal --target_dir {:06d}-{:06d}-C16-N5\n'.format(macro, start, xend-1))
|
||||||
print ('save the post-processing script into {:}'.format(script_name))
|
print ('save the post-processing script into {:}'.format(script_name))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
#mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()]
|
#mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()]
|
||||||
#parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
#parser = argparse.ArgumentParser(description='Algorithm-Agnostic NAS Benchmark', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument('--mode' , type=str, required=True, help='The script mode.')
|
parser.add_argument('--mode' , type=str, required=True, help='The script mode.')
|
||||||
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||||
parser.add_argument('--max_node', type=int, help='The maximum node in a cell.')
|
parser.add_argument('--max_node', type=int, help='The maximum node in a cell.')
|
@ -12,9 +12,9 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
|||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from config_utils import load_config, dict2config
|
from config_utils import load_config, dict2config
|
||||||
from datasets import get_datasets
|
from datasets import get_datasets
|
||||||
# NAS-Bench-102 related module or function
|
# NAS-Bench-201 related module or function
|
||||||
from models import CellStructure, get_cell_based_tiny_net
|
from models import CellStructure, get_cell_based_tiny_net
|
||||||
from nas_102_api import ArchResults, ResultsCount
|
from nas_201_api import ArchResults, ResultsCount
|
||||||
from functions import pure_evaluate
|
from functions import pure_evaluate
|
||||||
|
|
||||||
|
|
||||||
@ -271,9 +271,9 @@ def merge_all(save_dir, meta_file, basestr):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='NAS-BENCH-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument('--mode' , type=str, choices=['cal', 'merge'], help='The running mode for this script.')
|
parser.add_argument('--mode' , type=str, choices=['cal', 'merge'], help='The running mode for this script.')
|
||||||
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-102-4', help='The base-name of folder to save checkpoints and log.')
|
parser.add_argument('--base_save_dir', type=str, default='./output/NAS-BENCH-201-4', help='The base-name of folder to save checkpoints and log.')
|
||||||
parser.add_argument('--target_dir' , type=str, help='The target directory.')
|
parser.add_argument('--target_dir' , type=str, help='The target directory.')
|
||||||
parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.')
|
parser.add_argument('--max_node' , type=int, default=4, help='The maximum node in a cell.')
|
||||||
parser.add_argument('--channel' , type=int, default=16, help='The number of channels.')
|
parser.add_argument('--channel' , type=int, default=16, help='The number of channels.')
|
@ -1,7 +1,7 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
########################################################
|
########################################################
|
||||||
# python exps/NAS-Bench-102/test-correlation.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth
|
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
||||||
########################################################
|
########################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -18,7 +18,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces, CellStructure
|
from models import get_cell_based_tiny_net, get_search_spaces, CellStructure
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
def valid_func(xloader, network, criterion):
|
def valid_func(xloader, network, criterion):
|
||||||
@ -197,9 +197,9 @@ def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser("Analysis of NAS-Bench-102")
|
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
|
||||||
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
|
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||||
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
|
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
vis_save_dir = Path(args.save_dir)
|
vis_save_dir = Path(args.save_dir)
|
@ -1,7 +1,7 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
##################################################
|
||||||
# python exps/NAS-Bench-102/visualize.py --api_path $HOME/.torch/NAS-Bench-102-v1_0-e61699.pth
|
# python exps/NAS-Bench-201/visualize.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
|
||||||
##################################################
|
##################################################
|
||||||
import os, sys, time, argparse, collections
|
import os, sys, time, argparse, collections
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -19,7 +19,7 @@ import matplotlib.pyplot as plt
|
|||||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
from log_utils import time_string
|
from log_utils import time_string
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -367,13 +367,66 @@ def write_video(save_dir):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def plot_results_nas_v2(api, dataset_xset_a, dataset_xset_b, root, file_name, y_lims):
|
||||||
|
#print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
|
||||||
|
print ('root-path : {:} and {:}'.format(dataset_xset_a, dataset_xset_b))
|
||||||
|
checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth',
|
||||||
|
'./output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth',
|
||||||
|
'./output/search-cell-nas-bench-201/RAND-cifar10/results.pth',
|
||||||
|
'./output/search-cell-nas-bench-201/BOHB-cifar10/results.pth'
|
||||||
|
]
|
||||||
|
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
|
||||||
|
All_Accs_A, All_Accs_B = OrderedDict(), OrderedDict()
|
||||||
|
for legend, checkpoint in zip(legends, checkpoints):
|
||||||
|
all_indexes = torch.load(checkpoint, map_location='cpu')
|
||||||
|
accuracies_A, accuracies_B = [], []
|
||||||
|
accuracies = []
|
||||||
|
for x in all_indexes:
|
||||||
|
info = api.arch2infos_full[ x ]
|
||||||
|
metrics = info.get_metrics(dataset_xset_a[0], dataset_xset_a[1], None, False)
|
||||||
|
accuracies_A.append( metrics['accuracy'] )
|
||||||
|
metrics = info.get_metrics(dataset_xset_b[0], dataset_xset_b[1], None, False)
|
||||||
|
accuracies_B.append( metrics['accuracy'] )
|
||||||
|
accuracies.append( (accuracies_A[-1], accuracies_B[-1]) )
|
||||||
|
if indexes is None: indexes = list(range(len(all_indexes)))
|
||||||
|
accuracies = sorted(accuracies)
|
||||||
|
All_Accs_A[legend] = [x[0] for x in accuracies]
|
||||||
|
All_Accs_B[legend] = [x[1] for x in accuracies]
|
||||||
|
|
||||||
|
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||||
|
dpi, width, height = 300, 3400, 2600
|
||||||
|
LabelSize, LegendFontsize = 28, 28
|
||||||
|
figsize = width / float(dpi), height / float(dpi)
|
||||||
|
fig = plt.figure(figsize=figsize)
|
||||||
|
x_axis = np.arange(0, 600)
|
||||||
|
plt.xlim(0, max(indexes))
|
||||||
|
plt.ylim(y_lims[0], y_lims[1])
|
||||||
|
interval_x, interval_y = 100, y_lims[2]
|
||||||
|
plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
|
||||||
|
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||||
|
plt.grid()
|
||||||
|
plt.xlabel('The index of runs', fontsize=LabelSize)
|
||||||
|
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||||
|
|
||||||
|
for idx, legend in enumerate(legends):
|
||||||
|
plt.plot(indexes, All_Accs_B[legend], color=color_set[idx], linestyle='--', label='{:}'.format(legend), lw=1, alpha=0.5)
|
||||||
|
plt.plot(indexes, All_Accs_A[legend], color=color_set[idx], linestyle='-', lw=1)
|
||||||
|
for All_Accs in [All_Accs_A, All_Accs_B]:
|
||||||
|
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(All_Accs[legend]), np.std(All_Accs[legend]), np.mean(All_Accs[legend]), np.std(All_Accs[legend])))
|
||||||
|
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||||
|
save_path = root / '{:}'.format(file_name)
|
||||||
|
print('save figure into {:}\n'.format(save_path))
|
||||||
|
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
|
def plot_results_nas(api, dataset, xset, root, file_name, y_lims):
|
||||||
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
|
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
|
||||||
checkpoints = ['./output/search-cell-nas-bench-102/R-EA-cifar10/results.pth',
|
checkpoints = ['./output/search-cell-nas-bench-201/R-EA-cifar10/results.pth',
|
||||||
'./output/search-cell-nas-bench-102/REINFORCE-cifar10/results.pth',
|
'./output/search-cell-nas-bench-201/REINFORCE-cifar10/results.pth',
|
||||||
'./output/search-cell-nas-bench-102/RAND-cifar10/results.pth',
|
'./output/search-cell-nas-bench-201/RAND-cifar10/results.pth',
|
||||||
'./output/search-cell-nas-bench-102/BOHB-cifar10/results.pth'
|
'./output/search-cell-nas-bench-201/BOHB-cifar10/results.pth'
|
||||||
]
|
]
|
||||||
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
|
legends, indexes = ['REA', 'REINFORCE', 'RANDOM', 'BOHB'], None
|
||||||
All_Accs = OrderedDict()
|
All_Accs = OrderedDict()
|
||||||
@ -422,19 +475,19 @@ def just_show(api):
|
|||||||
xlist = np.array(xlist)
|
xlist = np.array(xlist)
|
||||||
print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean()))
|
print ('{:4s} : mean-time={:.2f} s'.format(xkey, xlist.mean()))
|
||||||
|
|
||||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/',
|
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
|
||||||
'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/',
|
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
|
||||||
'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/',
|
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
|
||||||
'GDAS' : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/',
|
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
|
||||||
'SETN' : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/',
|
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
|
||||||
'ENAS' : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/',
|
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
|
||||||
}
|
}
|
||||||
xseeds = {'RSPS' : [5349, 59613, 5983],
|
xseeds = {'RSPS' : [5349, 59613, 5983],
|
||||||
'DARTS-V1': [11416, 72873, 81184],
|
'DARTS-V1': [11416, 72873, 81184],
|
||||||
'DARTS-V2': [43330, 79405, 79423],
|
'DARTS-V2': [43330, 79405, 79423],
|
||||||
'GDAS' : [19677, 884, 95950],
|
'GDAS' : [19677, 884, 95950],
|
||||||
'SETN' : [20518, 61817, 89144],
|
'SETN' : [20518, 61817, 89144],
|
||||||
'ENAS' : [30801, 75610, 97745],
|
'ENAS' : [3231, 34238, 96929],
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_accs(xdata, index=-1):
|
def get_accs(xdata, index=-1):
|
||||||
@ -480,23 +533,26 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
|
|||||||
plt.xlabel('The searching epoch', fontsize=LabelSize)
|
plt.xlabel('The searching epoch', fontsize=LabelSize)
|
||||||
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||||
|
|
||||||
xpaths = {'RSPS' : 'output/search-cell-nas-bench-102/RANDOM-NAS-cifar10/checkpoint/',
|
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
|
||||||
'DARTS-V1': 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/',
|
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
|
||||||
'DARTS-V2': 'output/search-cell-nas-bench-102/DARTS-V2-cifar10/checkpoint/',
|
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
|
||||||
'GDAS' : 'output/search-cell-nas-bench-102/GDAS-cifar10/checkpoint/',
|
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
|
||||||
'SETN' : 'output/search-cell-nas-bench-102/SETN-cifar10/checkpoint/',
|
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
|
||||||
'ENAS' : 'output/search-cell-nas-bench-102/ENAS-cifar10/checkpoint/',
|
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
|
||||||
}
|
}
|
||||||
xseeds = {'RSPS' : [5349, 59613, 5983],
|
xseeds = {'RSPS' : [5349, 59613, 5983],
|
||||||
'DARTS-V1': [11416, 72873, 81184],
|
'DARTS-V1': [11416, 72873, 81184, 28640],
|
||||||
'DARTS-V2': [43330, 79405, 79423],
|
'DARTS-V2': [43330, 79405, 79423],
|
||||||
'GDAS' : [19677, 884, 95950],
|
'GDAS' : [19677, 884, 95950],
|
||||||
'SETN' : [20518, 61817, 89144],
|
'SETN' : [20518, 61817, 89144],
|
||||||
'ENAS' : [30801, 75610, 97745],
|
'ENAS' : [3231, 34238, 96929],
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_accs(xdata):
|
def get_accs(xdata):
|
||||||
epochs, xresults = xdata['epoch'], []
|
epochs, xresults = xdata['epoch'], []
|
||||||
|
if -1 in xdata['genotypes']:
|
||||||
|
metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False)
|
||||||
|
else:
|
||||||
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
|
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
|
||||||
xresults.append( metrics['accuracy'] )
|
xresults.append( metrics['accuracy'] )
|
||||||
for iepoch in range(epochs):
|
for iepoch in range(epochs):
|
||||||
@ -528,12 +584,120 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
|
|||||||
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||||
|
|
||||||
|
|
||||||
|
def show_nas_sharing_w_v2(api, data_sub_a, data_sub_b, vis_save_dir, file_name, y_lims, x_maxs):
|
||||||
|
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||||
|
dpi, width, height = 300, 3400, 2600
|
||||||
|
LabelSize, LegendFontsize = 28, 28
|
||||||
|
figsize = width / float(dpi), height / float(dpi)
|
||||||
|
fig = plt.figure(figsize=figsize)
|
||||||
|
#x_maxs = 250
|
||||||
|
plt.xlim(0, x_maxs+1)
|
||||||
|
plt.ylim(y_lims[0], y_lims[1])
|
||||||
|
interval_x, interval_y = x_maxs // 5, y_lims[2]
|
||||||
|
plt.xticks(np.arange(0, x_maxs+1, interval_x), fontsize=LegendFontsize)
|
||||||
|
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||||
|
plt.grid()
|
||||||
|
plt.xlabel('The searching epoch', fontsize=LabelSize)
|
||||||
|
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||||
|
|
||||||
|
xpaths = {'RSPS' : 'output/search-cell-nas-bench-201/RANDOM-NAS-cifar10/checkpoint/',
|
||||||
|
'DARTS-V1': 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/',
|
||||||
|
'DARTS-V2': 'output/search-cell-nas-bench-201/DARTS-V2-cifar10/checkpoint/',
|
||||||
|
'GDAS' : 'output/search-cell-nas-bench-201/GDAS-cifar10/checkpoint/',
|
||||||
|
'SETN' : 'output/search-cell-nas-bench-201/SETN-cifar10/checkpoint/',
|
||||||
|
'ENAS' : 'output/search-cell-nas-bench-201/ENAS-cifar10/checkpoint/',
|
||||||
|
}
|
||||||
|
xseeds = {'RSPS' : [5349, 59613, 5983],
|
||||||
|
'DARTS-V1': [11416, 72873, 81184, 28640],
|
||||||
|
'DARTS-V2': [43330, 79405, 79423],
|
||||||
|
'GDAS' : [19677, 884, 95950],
|
||||||
|
'SETN' : [20518, 61817, 89144],
|
||||||
|
'ENAS' : [3231, 34238, 96929],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_accs(xdata, dataset, subset):
|
||||||
|
epochs, xresults = xdata['epoch'], []
|
||||||
|
if -1 in xdata['genotypes']:
|
||||||
|
metrics = api.arch2infos_full[ api.query_index_by_arch(xdata['genotypes'][-1]) ].get_metrics(dataset, subset, None, False)
|
||||||
|
else:
|
||||||
|
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
|
||||||
|
xresults.append( metrics['accuracy'] )
|
||||||
|
for iepoch in range(epochs):
|
||||||
|
genotype = xdata['genotypes'][iepoch]
|
||||||
|
index = api.query_index_by_arch(genotype)
|
||||||
|
metrics = api.arch2infos_full[index].get_metrics(dataset, subset, None, False)
|
||||||
|
xresults.append( metrics['accuracy'] )
|
||||||
|
return xresults
|
||||||
|
|
||||||
|
if x_maxs == 50:
|
||||||
|
xox, xxxstrs = 'v2', ['DARTS-V1', 'DARTS-V2']
|
||||||
|
elif x_maxs == 250:
|
||||||
|
xox, xxxstrs = 'v1', ['RSPS', 'GDAS', 'SETN', 'ENAS']
|
||||||
|
else: raise ValueError('invalid x_maxs={:}'.format(x_maxs))
|
||||||
|
|
||||||
|
for idx, method in enumerate(xxxstrs):
|
||||||
|
xkey = method
|
||||||
|
all_paths = [ '{:}/seed-{:}-basic.pth'.format(xpaths[xkey], seed) for seed in xseeds[xkey] ]
|
||||||
|
all_datas = [torch.load(xpath, map_location='cpu') for xpath in all_paths]
|
||||||
|
accyss_A = np.array( [get_accs(xdatas, data_sub_a[0], data_sub_a[1]) for xdatas in all_datas] )
|
||||||
|
accyss_B = np.array( [get_accs(xdatas, data_sub_b[0], data_sub_b[1]) for xdatas in all_datas] )
|
||||||
|
epochs = list(range(accyss_A.shape[1]))
|
||||||
|
for j, accyss in enumerate([accyss_A, accyss_B]):
|
||||||
|
plt.plot(epochs, [accyss[:,i].mean() for i in epochs], color=color_set[idx*2+j], linestyle='-' if j==0 else '--', label='{:} ({:})'.format(method, 'VALID' if j == 0 else 'TEST'), lw=2, alpha=0.9)
|
||||||
|
plt.fill_between(epochs, [accyss[:,i].mean()-accyss[:,i].std() for i in epochs], [accyss[:,i].mean()+accyss[:,i].std() for i in epochs], alpha=0.2, color=color_set[idx*2+j])
|
||||||
|
#plt.legend(loc=4, fontsize=LegendFontsize)
|
||||||
|
plt.legend(loc=0, fontsize=LegendFontsize)
|
||||||
|
save_path = vis_save_dir / '{:}-{:}'.format(xox, file_name)
|
||||||
|
print('save figure into {:}\n'.format(save_path))
|
||||||
|
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||||
|
|
||||||
|
|
||||||
|
def show_reinforce(api, root, dataset, xset, file_name, y_lims):
|
||||||
|
print ('root-path={:}, dataset={:}, xset={:}'.format(root, dataset, xset))
|
||||||
|
LRs = ['0.01', '0.02', '0.1', '0.2', '0.5', '1.0', '1.5', '2.0', '2.5', '3.0']
|
||||||
|
checkpoints = ['./output/search-cell-nas-bench-201/REINFORCE-cifar10-{:}/results.pth'.format(x) for x in LRs]
|
||||||
|
acc_lr_dict, indexes = {}, None
|
||||||
|
for lr, checkpoint in zip(LRs, checkpoints):
|
||||||
|
all_indexes, accuracies = torch.load(checkpoint, map_location='cpu'), []
|
||||||
|
for x in all_indexes:
|
||||||
|
info = api.arch2infos_full[ x ]
|
||||||
|
metrics = info.get_metrics(dataset, xset, None, False)
|
||||||
|
accuracies.append( metrics['accuracy'] )
|
||||||
|
if indexes is None: indexes = list(range(len(accuracies)))
|
||||||
|
acc_lr_dict[lr] = np.array( sorted(accuracies) )
|
||||||
|
print ('LR={:.3f}, mean={:}, std={:}'.format(float(lr), acc_lr_dict[lr].mean(), acc_lr_dict[lr].std()))
|
||||||
|
|
||||||
|
color_set = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
|
||||||
|
dpi, width, height = 300, 3400, 2600
|
||||||
|
LabelSize, LegendFontsize = 28, 22
|
||||||
|
figsize = width / float(dpi), height / float(dpi)
|
||||||
|
fig = plt.figure(figsize=figsize)
|
||||||
|
x_axis = np.arange(0, 600)
|
||||||
|
plt.xlim(0, max(indexes))
|
||||||
|
plt.ylim(y_lims[0], y_lims[1])
|
||||||
|
interval_x, interval_y = 100, y_lims[2]
|
||||||
|
plt.xticks(np.arange(0, max(indexes), interval_x), fontsize=LegendFontsize)
|
||||||
|
plt.yticks(np.arange(y_lims[0],y_lims[1], interval_y), fontsize=LegendFontsize)
|
||||||
|
plt.grid()
|
||||||
|
plt.xlabel('The index of runs', fontsize=LabelSize)
|
||||||
|
plt.ylabel('The accuracy (%)', fontsize=LabelSize)
|
||||||
|
|
||||||
|
for idx, LR in enumerate(LRs):
|
||||||
|
legend = 'LR={:.2f}'.format(float(LR))
|
||||||
|
color, linestyle = color_set[idx // 2], '-' if idx % 2 == 0 else '-.'
|
||||||
|
plt.plot(indexes, acc_lr_dict[LR], color=color, linestyle=linestyle, label=legend, lw=2, alpha=0.8)
|
||||||
|
print ('{:} : mean = {:}, std = {:} :: {:.2f}$\\pm${:.2f}'.format(legend, np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR]), np.mean(acc_lr_dict[LR]), np.std(acc_lr_dict[LR])))
|
||||||
|
plt.legend(loc=4, fontsize=LegendFontsize)
|
||||||
|
save_path = root / '{:}-{:}-{:}.pdf'.format(dataset, xset, file_name)
|
||||||
|
print('save figure into {:}\n'.format(save_path))
|
||||||
|
fig.savefig(str(save_path), dpi=dpi, bbox_inches='tight', format='pdf')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
|
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-201/visuals', help='The base-name of folder to save checkpoints and log.')
|
||||||
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
|
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-201 benchmark file.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
vis_save_dir = Path(args.save_dir)
|
vis_save_dir = Path(args.save_dir)
|
||||||
@ -548,6 +712,9 @@ if __name__ == '__main__':
|
|||||||
#visualize_relative_ranking(vis_save_dir)
|
#visualize_relative_ranking(vis_save_dir)
|
||||||
|
|
||||||
api = API(args.api_path)
|
api = API(args.api_path)
|
||||||
|
show_reinforce(api, vis_save_dir, 'cifar10-valid' , 'x-valid', 'REINFORCE-CIFAR-10', (75, 95, 5))
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
|
|
||||||
for x_maxs in [50, 250]:
|
for x_maxs in [50, 250]:
|
||||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
@ -555,12 +722,19 @@ if __name__ == '__main__':
|
|||||||
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||||
|
|
||||||
|
show_nas_sharing_w_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test') , vis_save_dir, 'DARTS-CIFAR010.pdf', (0, 100,10), 50)
|
||||||
|
show_nas_sharing_w_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ) , vis_save_dir, 'DARTS-CIFAR100.pdf', (0, 100,10), 50)
|
||||||
|
show_nas_sharing_w_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ) , vis_save_dir, 'DARTS-ImageNet.pdf', (0, 100,10), 50)
|
||||||
|
#just_show(api)
|
||||||
"""
|
"""
|
||||||
just_show(api)
|
|
||||||
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||||
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||||
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||||
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||||
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||||
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||||
|
plot_results_nas_v2(api, ('cifar10-valid' , 'x-valid'), ('cifar10' , 'ori-test'), vis_save_dir, 'nas-com-v2-cifar010.pdf', (85,95, 1))
|
||||||
|
plot_results_nas_v2(api, ('cifar100' , 'x-valid'), ('cifar100' , 'x-test' ), vis_save_dir, 'nas-com-v2-cifar100.pdf', (60,75, 3))
|
||||||
|
plot_results_nas_v2(api, ('ImageNet16-120', 'x-valid'), ('ImageNet16-120', 'x-test' ), vis_save_dir, 'nas-com-v2-imagenet.pdf', (35,48, 2))
|
||||||
"""
|
"""
|
@ -1,9 +1,10 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
##################################################
|
###################################################################
|
||||||
# required to install hpbandster #################
|
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
|
||||||
# bash ./scripts-search/algos/BOHB.sh -1 #
|
# required to install hpbandster ##################################
|
||||||
##################################################
|
# bash ./scripts-search/algos/BOHB.sh -1 ##################
|
||||||
|
###################################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np, collections
|
import numpy as np, collections
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -17,7 +18,7 @@ from datasets import get_datasets, SearchDataset
|
|||||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
from models import CellStructure, get_search_spaces
|
from models import CellStructure, get_search_spaces
|
||||||
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
|
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
|
||||||
import ConfigSpace
|
import ConfigSpace
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
########################################################
|
########################################################
|
||||||
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
||||||
########################################################
|
########################################################
|
||||||
@ -17,7 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
########################################################
|
########################################################
|
||||||
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
||||||
########################################################
|
########################################################
|
||||||
@ -17,7 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
def _concat(xs):
|
def _concat(xs):
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
##################################################
|
##########################################################################
|
||||||
|
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
|
||||||
|
##########################################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -15,7 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger):
|
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
###########################################################################
|
###########################################################################
|
||||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||||
###########################################################################
|
###########################################################################
|
||||||
@ -17,7 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
##################################################
|
##############################################################################
|
||||||
|
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
|
||||||
|
##############################################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -15,7 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger):
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
##################################################
|
##############################################################################
|
||||||
import os, sys, time, glob, random, argparse
|
import os, sys, time, glob, random, argparse
|
||||||
import numpy as np, collections
|
import numpy as np, collections
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -15,7 +15,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_search_spaces
|
from models import get_search_spaces
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
from R_EA import train_and_eval, random_architecture_func
|
from R_EA import train_and_eval, random_architecture_func
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
##################################################################
|
##################################################################
|
||||||
# Regularized Evolution for Image Classifier Architecture Search #
|
# Regularized Evolution for Image Classifier Architecture Search #
|
||||||
##################################################################
|
##################################################################
|
||||||
@ -16,7 +16,7 @@ from datasets import get_datasets, SearchDataset
|
|||||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
from models import CellStructure, get_search_spaces
|
from models import CellStructure, get_search_spaces
|
||||||
|
|
||||||
|
|
||||||
@ -31,30 +31,8 @@ class Model(object):
|
|||||||
return '{:}'.format(self.arch)
|
return '{:}'.format(self.arch)
|
||||||
|
|
||||||
|
|
||||||
def valid_func(xloader, network, criterion):
|
# This function is to mimic the training and evaluatinig procedure for a single architecture `arch`.
|
||||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
# The time_cost is calculated as the total training time for a few (e.g., 12 epochs) plus the evaluation time for one epoch.
|
||||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
|
||||||
network.train()
|
|
||||||
end = time.time()
|
|
||||||
with torch.no_grad():
|
|
||||||
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
|
||||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
|
||||||
# measure data loading time
|
|
||||||
data_time.update(time.time() - end)
|
|
||||||
# prediction
|
|
||||||
_, logits = network(arch_inputs)
|
|
||||||
arch_loss = criterion(logits, arch_targets)
|
|
||||||
# record
|
|
||||||
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
|
||||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
|
||||||
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
|
||||||
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
|
||||||
# measure elapsed time
|
|
||||||
batch_time.update(time.time() - end)
|
|
||||||
end = time.time()
|
|
||||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
|
||||||
|
|
||||||
|
|
||||||
def train_and_eval(arch, nas_bench, extra_info):
|
def train_and_eval(arch, nas_bench, extra_info):
|
||||||
if nas_bench is not None:
|
if nas_bench is not None:
|
||||||
arch_index = nas_bench.query_index_by_arch( arch )
|
arch_index = nas_bench.query_index_by_arch( arch )
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
######################################################################################
|
######################################################################################
|
||||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||||
######################################################################################
|
######################################################################################
|
||||||
@ -17,7 +17,7 @@ from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_che
|
|||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net, get_search_spaces
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||||
|
@ -17,7 +17,7 @@ from datasets import get_datasets, SearchDataset
|
|||||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
from utils import get_model_infos, obtain_accuracy
|
from utils import get_model_infos, obtain_accuracy
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
from models import CellStructure, get_search_spaces
|
from models import CellStructure, get_search_spaces
|
||||||
from R_EA import train_and_eval
|
from R_EA import train_and_eval
|
||||||
|
|
||||||
@ -128,6 +128,7 @@ def main(xargs, nas_bench):
|
|||||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
policy = Policy(xargs.max_nodes, search_space)
|
policy = Policy(xargs.max_nodes, search_space)
|
||||||
optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
|
optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
|
||||||
|
#optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
|
||||||
eps = np.finfo(np.float32).eps.item()
|
eps = np.finfo(np.float32).eps.item()
|
||||||
baseline = ExponentialMovingAverage(xargs.EMA_momentum)
|
baseline = ExponentialMovingAverage(xargs.EMA_momentum)
|
||||||
logger.log('policy : {:}'.format(policy))
|
logger.log('policy : {:}'.format(policy))
|
||||||
@ -141,13 +142,14 @@ def main(xargs, nas_bench):
|
|||||||
# attempts = 0
|
# attempts = 0
|
||||||
x_start_time = time.time()
|
x_start_time = time.time()
|
||||||
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
|
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
|
||||||
total_steps, total_costs = 0, 0
|
total_steps, total_costs, trace = 0, 0, []
|
||||||
#for istep in range(xargs.RL_steps):
|
#for istep in range(xargs.RL_steps):
|
||||||
while total_costs < xargs.time_budget:
|
while total_costs < xargs.time_budget:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
log_prob, action = select_action( policy )
|
log_prob, action = select_action( policy )
|
||||||
arch = policy.generate_arch( action )
|
arch = policy.generate_arch( action )
|
||||||
reward, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
reward, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
||||||
|
trace.append( (reward, arch) )
|
||||||
# accumulate time
|
# accumulate time
|
||||||
if total_costs + cost_time < xargs.time_budget:
|
if total_costs + cost_time < xargs.time_budget:
|
||||||
total_costs += cost_time
|
total_costs += cost_time
|
||||||
@ -166,7 +168,8 @@ def main(xargs, nas_bench):
|
|||||||
#logger.log('----> {:}'.format(policy.arch_parameters))
|
#logger.log('----> {:}'.format(policy.arch_parameters))
|
||||||
#logger.log('')
|
#logger.log('')
|
||||||
|
|
||||||
best_arch = policy.genotype()
|
# best_arch = policy.genotype() # first version
|
||||||
|
best_arch = max(trace, key=lambda x: x[0])[1]
|
||||||
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
|
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
|
||||||
info = nas_bench.query_by_arch( best_arch )
|
info = nas_bench.query_by_arch( best_arch )
|
||||||
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
|
||||||
|
@ -8,11 +8,11 @@ from collections import OrderedDict
|
|||||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
from nas_102_api import NASBench102API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
def test_nas_api():
|
def test_nas_api():
|
||||||
from nas_102_api import ArchResults
|
from nas_201_api import ArchResults
|
||||||
xdata = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-102-4/simplifies/architectures/000157-FULL.pth')
|
xdata = torch.load('/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth')
|
||||||
for key in ['full', 'less']:
|
for key in ['full', 'less']:
|
||||||
print ('\n------------------------- {:} -------------------------'.format(key))
|
print ('\n------------------------- {:} -------------------------'.format(key))
|
||||||
archRes = ArchResults.create_from_state_dict(xdata[key])
|
archRes = ArchResults.create_from_state_dict(xdata[key])
|
||||||
@ -81,8 +81,8 @@ def test_one_shot_model(ckpath, use_train):
|
|||||||
from config_utils import load_config, dict2config
|
from config_utils import load_config, dict2config
|
||||||
from utils.nas_utils import evaluate_one_shot
|
from utils.nas_utils import evaluate_one_shot
|
||||||
use_train = int(use_train) > 0
|
use_train = int(use_train) > 0
|
||||||
#ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
|
#ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
|
||||||
#ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
|
#ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
|
||||||
print ('ckpath : {:}'.format(ckpath))
|
print ('ckpath : {:}'.format(ckpath))
|
||||||
ckp = torch.load(ckpath)
|
ckp = torch.load(ckpath)
|
||||||
xargs = ckp['args']
|
xargs = ckp['args']
|
||||||
@ -103,7 +103,7 @@ def test_one_shot_model(ckpath, use_train):
|
|||||||
search_model = get_cell_based_tiny_net(model_config)
|
search_model = get_cell_based_tiny_net(model_config)
|
||||||
search_model.load_state_dict( ckp['search_model'] )
|
search_model.load_state_dict( ckp['search_model'] )
|
||||||
search_model = search_model.cuda()
|
search_model = search_model.cuda()
|
||||||
api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth')
|
api = API('/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth')
|
||||||
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
|
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
##################################################
|
##################################################
|
||||||
import os, sys, time, random, argparse
|
import os, sys, time, random, argparse
|
||||||
from .share_args import add_shared_args
|
from .share_args import add_shared_args
|
||||||
|
@ -19,7 +19,7 @@ def get_cell_based_tiny_net(config):
|
|||||||
super_type = getattr(config, 'super_type', 'basic')
|
super_type = getattr(config, 'super_type', 'basic')
|
||||||
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
|
group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM']
|
||||||
if super_type == 'basic' and config.name in group_names:
|
if super_type == 'basic' and config.name in group_names:
|
||||||
from .cell_searchs import nas102_super_nets as nas_super_nets
|
from .cell_searchs import nas201_super_nets as nas_super_nets
|
||||||
try:
|
try:
|
||||||
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
|
return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats)
|
||||||
except:
|
except:
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
|
##################################################
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from ..cell_operations import OPS
|
from ..cell_operations import OPS
|
||||||
|
|
||||||
|
|
||||||
|
# Cell for NAS-Bench-201
|
||||||
class InferCell(nn.Module):
|
class InferCell(nn.Module):
|
||||||
|
|
||||||
def __init__(self, genotype, C_in, C_out, stride):
|
def __init__(self, genotype, C_in, C_out, stride):
|
||||||
|
@ -1,9 +1,13 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
|
##################################################
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from ..cell_operations import ResNetBasicblock
|
from ..cell_operations import ResNetBasicblock
|
||||||
from .cells import InferCell
|
from .cells import InferCell
|
||||||
|
|
||||||
|
|
||||||
|
# The macro structure for architectures in NAS-Bench-201
|
||||||
class TinyNetwork(nn.Module):
|
class TinyNetwork(nn.Module):
|
||||||
|
|
||||||
def __init__(self, C, N, genotype, num_classes):
|
def __init__(self, C, N, genotype, num_classes):
|
||||||
|
@ -21,12 +21,11 @@ OPS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||||
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||||
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
|
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
|
||||||
|
|
||||||
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
||||||
'aa-nas' : NAS_BENCH_102,
|
'nas-bench-201': NAS_BENCH_201,
|
||||||
'nas-bench-102': NAS_BENCH_102,
|
|
||||||
'darts' : DARTS_SPACE}
|
'darts' : DARTS_SPACE}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
##################################################
|
||||||
# The macro structure is defined in NAS-Bench-102
|
# The macro structure is defined in NAS-Bench-201
|
||||||
from .search_model_darts import TinyNetworkDarts
|
from .search_model_darts import TinyNetworkDarts
|
||||||
from .search_model_gdas import TinyNetworkGDAS
|
from .search_model_gdas import TinyNetworkGDAS
|
||||||
from .search_model_setn import TinyNetworkSETN
|
from .search_model_setn import TinyNetworkSETN
|
||||||
@ -12,7 +12,7 @@ from .genotypes import Structure as CellStructure, architectures as
|
|||||||
from .search_model_gdas_nasnet import NASNetworkGDAS
|
from .search_model_gdas_nasnet import NASNetworkGDAS
|
||||||
|
|
||||||
|
|
||||||
nas102_super_nets = {'DARTS-V1': TinyNetworkDarts,
|
nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
|
||||||
'DARTS-V2': TinyNetworkDarts,
|
'DARTS-V2': TinyNetworkDarts,
|
||||||
'GDAS' : TinyNetworkGDAS,
|
'GDAS' : TinyNetworkGDAS,
|
||||||
'SETN' : TinyNetworkSETN,
|
'SETN' : TinyNetworkSETN,
|
||||||
|
@ -9,11 +9,11 @@ from copy import deepcopy
|
|||||||
from ..cell_operations import OPS
|
from ..cell_operations import OPS
|
||||||
|
|
||||||
|
|
||||||
# This module is used for NAS-Bench-102, represents a small search space with a complete DAG
|
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
|
||||||
class NAS102SearchCell(nn.Module):
|
class NAS201SearchCell(nn.Module):
|
||||||
|
|
||||||
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
|
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False, track_running_stats=True):
|
||||||
super(NAS102SearchCell, self).__init__()
|
super(NAS201SearchCell, self).__init__()
|
||||||
|
|
||||||
self.op_names = deepcopy(op_names)
|
self.op_names = deepcopy(op_names)
|
||||||
self.edges = nn.ModuleDict()
|
self.edges = nn.ModuleDict()
|
||||||
|
@ -7,7 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from ..cell_operations import ResNetBasicblock
|
from ..cell_operations import ResNetBasicblock
|
||||||
from .search_cells import NAS102SearchCell as SearchCell
|
from .search_cells import NAS201SearchCell as SearchCell
|
||||||
from .genotypes import Structure
|
from .genotypes import Structure
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from ..cell_operations import ResNetBasicblock
|
from ..cell_operations import ResNetBasicblock
|
||||||
from .search_cells import NAS102SearchCell as SearchCell
|
from .search_cells import NAS201SearchCell as SearchCell
|
||||||
from .genotypes import Structure
|
from .genotypes import Structure
|
||||||
from .search_model_enas_utils import Controller
|
from .search_model_enas_utils import Controller
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from ..cell_operations import ResNetBasicblock
|
from ..cell_operations import ResNetBasicblock
|
||||||
from .search_cells import NAS102SearchCell as SearchCell
|
from .search_cells import NAS201SearchCell as SearchCell
|
||||||
from .genotypes import Structure
|
from .genotypes import Structure
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import torch, random
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from ..cell_operations import ResNetBasicblock
|
from ..cell_operations import ResNetBasicblock
|
||||||
from .search_cells import NAS102SearchCell as SearchCell
|
from .search_cells import NAS201SearchCell as SearchCell
|
||||||
from .genotypes import Structure
|
from .genotypes import Structure
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import torch, random
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from ..cell_operations import ResNetBasicblock
|
from ..cell_operations import ResNetBasicblock
|
||||||
from .search_cells import NAS102SearchCell as SearchCell
|
from .search_cells import NAS201SearchCell as SearchCell
|
||||||
from .genotypes import Structure
|
from .genotypes import Structure
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
##################################################
|
||||||
from .api import NASBench102API
|
from .api import NASBench201API
|
||||||
from .api import ArchResults, ResultsCount
|
from .api import ArchResults, ResultsCount
|
||||||
|
|
||||||
NAS_BENCH_102_API_VERSION="v1.0"
|
NAS_BENCH_201_API_VERSION="v1.0"
|
@ -1,9 +1,9 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
############################################################################################
|
############################################################################################
|
||||||
# NAS-Bench-102: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
||||||
############################################################################################
|
############################################################################################
|
||||||
# NAS-Bench-102-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
# NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
@ -38,11 +38,11 @@ def print_information(information, extra_info=None, show=False):
|
|||||||
return strings
|
return strings
|
||||||
|
|
||||||
|
|
||||||
class NASBench102API(object):
|
class NASBench201API(object):
|
||||||
|
|
||||||
def __init__(self, file_path_or_dict, verbose=True):
|
def __init__(self, file_path_or_dict, verbose=True):
|
||||||
if isinstance(file_path_or_dict, str):
|
if isinstance(file_path_or_dict, str):
|
||||||
if verbose: print('try to create the NAS-Bench-102 api from {:}'.format(file_path_or_dict))
|
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
||||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||||
file_path_or_dict = torch.load(file_path_or_dict)
|
file_path_or_dict = torch.load(file_path_or_dict)
|
||||||
elif isinstance(file_path_or_dict, dict):
|
elif isinstance(file_path_or_dict, dict):
|
@ -2,3 +2,4 @@
|
|||||||
from .CifarNet import NetworkCIFAR as CifarNet
|
from .CifarNet import NetworkCIFAR as CifarNet
|
||||||
from .ImageNet import NetworkImageNet as ImageNet
|
from .ImageNet import NetworkImageNet as ImageNet
|
||||||
from .genotypes import Networks
|
from .genotypes import Networks
|
||||||
|
from .genotypes import build_genotype_from_dict
|
||||||
|
@ -167,3 +167,6 @@ Networks = {'DARTS_V1': DARTS_V1,
|
|||||||
'PNASNet' : PNASNet,
|
'PNASNet' : PNASNet,
|
||||||
'SETN' : SETN,
|
'SETN' : SETN,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def build_genotype_from_dict(xdict):
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
##################################################
|
##################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
##################################################
|
##################################################
|
||||||
|
# I write this package to make AutoDL-Projects to be compatible with the old GDAS projects.
|
||||||
|
# Ideally, this package will be merged into lib/models/cell_infers in future.
|
||||||
|
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
|
||||||
|
##################################################
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def obtain_nas_infer_model(config):
|
def obtain_nas_infer_model(config):
|
||||||
|
@ -14,10 +14,10 @@ OPS = {
|
|||||||
'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
|
'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||||
|
|
||||||
SearchSpaceNames = {
|
SearchSpaceNames = {
|
||||||
'nas-bench-102': NAS_BENCH_102,
|
'nas-bench-201': NAS_BENCH_201,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# bash scripts-search/NAS-Bench-102/build.sh
|
# bash scripts-search/NAS-Bench-201/build.sh
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 0 ] ;then
|
if [ "$#" -ne 0 ] ;then
|
||||||
@ -8,17 +8,17 @@ if [ "$#" -ne 0 ] ;then
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
save_dir=./output/nas_bench_102_package
|
save_dir=./output/nas_bench_201_package
|
||||||
echo "Prepare to build the package in ${save_dir}"
|
echo "Prepare to build the package in ${save_dir}"
|
||||||
rm -rf ${save_dir}
|
rm -rf ${save_dir}
|
||||||
mkdir -p ${save_dir}
|
mkdir -p ${save_dir}
|
||||||
|
|
||||||
#cp NAS-Bench-102.md ${save_dir}/README.md
|
#cp NAS-Bench-201.md ${save_dir}/README.md
|
||||||
sed '125,187d' NAS-Bench-102.md > ${save_dir}/README.md
|
sed '125,187d' NAS-Bench-201.md > ${save_dir}/README.md
|
||||||
cp LICENSE.md ${save_dir}/LICENSE.md
|
cp LICENSE.md ${save_dir}/LICENSE.md
|
||||||
cp -r lib/nas_102_api ${save_dir}/
|
cp -r lib/nas_201_api ${save_dir}/
|
||||||
rm -rf ${save_dir}/nas_102_api/__pycache__
|
rm -rf ${save_dir}/nas_201_api/__pycache__
|
||||||
cp exps/NAS-Bench-102/dist-setup.py ${save_dir}/setup.py
|
cp exps/NAS-Bench-201/dist-setup.py ${save_dir}/setup.py
|
||||||
|
|
||||||
cd ${save_dir}
|
cd ${save_dir}
|
||||||
# python setup.py sdist bdist_wheel
|
# python setup.py sdist bdist_wheel
|
@ -1,5 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# bash scripts-search/NAS-Bench-102/meta-gen.sh NAS-BENCH-102 4
|
# bash scripts-search/NAS-Bench-201/meta-gen.sh NAS-BENCH-201 4
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 2 ] ;then
|
if [ "$#" -ne 2 ] ;then
|
||||||
@ -13,4 +13,4 @@ node=$2
|
|||||||
|
|
||||||
save_dir=./output/${name}-${node}
|
save_dir=./output/${name}-${node}
|
||||||
|
|
||||||
python ./exps/NAS-Bench-102/main.py --mode meta --save_dir ${save_dir} --max_node ${node}
|
python ./exps/NAS-Bench-201/main.py --mode meta --save_dir ${save_dir} --max_node ${node}
|
@ -1,5 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# bash ./scripts-search/NAS-Bench-102/train-a-net.sh resnet 16 5
|
# bash ./scripts-search/NAS-Bench-201/train-a-net.sh resnet 16 5
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 3 ] ;then
|
if [ "$#" -ne 3 ] ;then
|
||||||
@ -18,9 +18,9 @@ model=$1
|
|||||||
channel=$2
|
channel=$2
|
||||||
num_cells=$3
|
num_cells=$3
|
||||||
|
|
||||||
save_dir=./output/NAS-BENCH-102-4/
|
save_dir=./output/NAS-BENCH-201-4/
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/NAS-Bench-102/main.py \
|
OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
|
||||||
--mode specific-${model} --save_dir ${save_dir} --max_node 4 \
|
--mode specific-${model} --save_dir ${save_dir} --max_node 4 \
|
||||||
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
||||||
--use_less 0 \
|
--use_less 0 \
|
@ -20,7 +20,7 @@ xend=$3
|
|||||||
arch_index=$4
|
arch_index=$4
|
||||||
all_seeds=$5
|
all_seeds=$5
|
||||||
|
|
||||||
save_dir=./output/NAS-BENCH-102-4/
|
save_dir=./output/NAS-BENCH-201-4/
|
||||||
|
|
||||||
if [ ${arch_index} == "-1" ]; then
|
if [ ${arch_index} == "-1" ]; then
|
||||||
mode=new
|
mode=new
|
||||||
@ -28,7 +28,7 @@ else
|
|||||||
mode=cover
|
mode=cover
|
||||||
fi
|
fi
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/NAS-Bench-102/main.py \
|
OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
|
||||||
--mode ${mode} --save_dir ${save_dir} --max_node 4 \
|
--mode ${mode} --save_dir ${save_dir} --max_node 4 \
|
||||||
--use_less ${use_less} \
|
--use_less ${use_less} \
|
||||||
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
@ -19,7 +19,7 @@ seed=$1
|
|||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/BOHB-${dataset}
|
save_dir=./output/search-cell-${space}/BOHB-${dataset}
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/BOHB.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
--n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \
|
--n_iters 50 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1
|
# bash ./scripts-search/algos/DARTS-V1.sh cifar10 0 -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 2 ] ;then
|
if [ "$#" -ne 3 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 2 parameters for dataset and seed"
|
echo "Need 3 parameters for dataset, tracking_status, and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -15,11 +15,12 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=$1
|
dataset=$1
|
||||||
seed=$2
|
BN=$2
|
||||||
|
seed=$3
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
@ -27,14 +28,14 @@ else
|
|||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}
|
save_dir=./output/search-cell-${space}/DARTS-V1-${dataset}-BN${BN}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--config_path configs/nas-benchmark/algos/DARTS.config \
|
--config_path configs/nas-benchmark/algos/DARTS.config \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--track_running_stats 1 \
|
--track_running_stats ${BN} \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1
|
# bash ./scripts-search/algos/DARTS-V2.sh cifar10 0 -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 2 ] ;then
|
if [ "$#" -ne 3 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 2 parameters for dataset and seed"
|
echo "Need 3 parameters for dataset, tracking_status, and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -15,11 +15,12 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=$1
|
dataset=$1
|
||||||
seed=$2
|
BN=$2
|
||||||
|
seed=$3
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
@ -27,14 +28,14 @@ else
|
|||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}
|
save_dir=./output/search-cell-${space}/DARTS-V2-${dataset}-BN${BN}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--config_path configs/nas-benchmark/algos/DARTS.config \
|
--config_path configs/nas-benchmark/algos/DARTS.config \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--track_running_stats 1 \
|
--track_running_stats ${BN} \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Efficient Neural Architecture Search via Parameter Sharing, ICML 2018
|
# Efficient Neural Architecture Search via Parameter Sharing, ICML 2018
|
||||||
# bash ./scripts-search/scripts/algos/ENAS.sh cifar10 -1
|
# bash ./scripts-search/scripts/algos/ENAS.sh cifar10 0 -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 2 ] ;then
|
if [ "$#" -ne 3 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 2 parameters for dataset and seed"
|
echo "Need 3 parameters for dataset, BN-tracking-status, and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -16,11 +16,12 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=$1
|
dataset=$1
|
||||||
seed=$2
|
BN=$2
|
||||||
|
seed=$3
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
@ -28,14 +29,14 @@ else
|
|||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/ENAS-${dataset}
|
save_dir=./output/search-cell-${space}/ENAS-${dataset}-BN${BN}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/ENAS.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/ENAS.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--track_running_stats 1 \
|
--track_running_stats ${BN} \
|
||||||
--config_path ./configs/nas-benchmark/algos/ENAS.config \
|
--config_path ./configs/nas-benchmark/algos/ENAS.config \
|
||||||
--controller_entropy_weight 0.0001 \
|
--controller_entropy_weight 0.0001 \
|
||||||
--controller_bl_dec 0.99 \
|
--controller_bl_dec 0.99 \
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# bash ./scripts-search/algos/GDAS.sh cifar10 -1
|
# bash ./scripts-search/algos/GDAS.sh cifar10 0 -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 2 ] ;then
|
if [ "$#" -ne 3 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 2 parameters for dataset and seed"
|
echo "Need 3 parameters for dataset, BN-tracking, and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -15,11 +15,12 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=$1
|
dataset=$1
|
||||||
seed=$2
|
BN=$2
|
||||||
|
seed=$3
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
@ -27,14 +28,14 @@ else
|
|||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/GDAS-${dataset}
|
save_dir=./output/search-cell-${space}/GDAS-${dataset}-BN${BN}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--config_path configs/nas-benchmark/algos/GDAS.config \
|
--config_path configs/nas-benchmark/algos/GDAS.config \
|
||||||
--tau_max 10 --tau_min 0.1 --track_running_stats 1 \
|
--tau_max 10 --tau_min 0.1 --track_running_stats ${BN} \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
9
scripts-search/algos/GRID-RL.sh
Normal file
9
scripts-search/algos/GRID-RL.sh
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
echo script name: $0
|
||||||
|
|
||||||
|
lrs="0.01 0.02 0.1 0.2 0.5 1.0 1.5 2.0 2.5 3.0"
|
||||||
|
|
||||||
|
for lr in ${lrs}
|
||||||
|
do
|
||||||
|
bash ./scripts-search/algos/REINFORCE.sh ${lr} -1
|
||||||
|
done
|
@ -20,7 +20,7 @@ seed=$1
|
|||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/R-EA-${dataset}
|
save_dir=./output/search-cell-${space}/R-EA-${dataset}
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/R_EA.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
--ea_cycles 100 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \
|
--ea_cycles 100 --ea_population 10 --ea_sample_size 3 --ea_fast_by_api 1 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019
|
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019
|
||||||
# bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 -1
|
# bash ./scripts-search/algos/RANDOM-NAS.sh cifar10 0 -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 2 ] ;then
|
if [ "$#" -ne 3 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 2 parameters for dataset and seed"
|
echo "Need 3 parameters for dataset, BN-tracking-status, and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -16,11 +16,12 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=$1
|
dataset=$1
|
||||||
seed=$2
|
BN=$2
|
||||||
|
seed=$3
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
@ -28,14 +29,14 @@ else
|
|||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/RANDOM-NAS-${dataset}
|
save_dir=./output/search-cell-${space}/RANDOM-NAS-${dataset}-BN${BN}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/RANDOM-NAS.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/RANDOM-NAS.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--track_running_stats 1 \
|
--track_running_stats ${BN} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--config_path ./configs/nas-benchmark/algos/RANDOM.config \
|
--config_path ./configs/nas-benchmark/algos/RANDOM.config \
|
||||||
--select_num 100 \
|
--select_num 100 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# bash ./scripts-search/algos/REINFORCE.sh -1
|
# bash ./scripts-search/algos/REINFORCE.sh 0.001 -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 1 ] ;then
|
if [ "$#" -ne 2 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 1 parameters for seed"
|
echo "Need 2 parameters for LR and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -15,19 +15,20 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=cifar10
|
dataset=cifar10
|
||||||
seed=$1
|
LR=$1
|
||||||
|
seed=$2
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/REINFORCE-${dataset}
|
save_dir=./output/search-cell-${space}/REINFORCE-${dataset}-${LR}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/reinforce.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
--learning_rate 0.001 --EMA_momentum 0.9 \
|
--learning_rate ${LR} --EMA_momentum 0.9 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -19,7 +19,7 @@ seed=$1
|
|||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/RAND-${dataset}
|
save_dir=./output/search-cell-${space}/RAND-${dataset}
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ OMP_NUM_THREADS=4 python ./exps/algos/RANDOM.py \
|
|||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} \
|
--dataset ${dataset} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--time_budget 12000 \
|
--time_budget 12000 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
# --random_num 100 \
|
# --random_num 100 \
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
|
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
|
||||||
# bash ./scripts-search/scripts/algos/SETN.sh cifar10 -1
|
# bash ./scripts-search/scripts/algos/SETN.sh cifar10 0 -1
|
||||||
echo script name: $0
|
echo script name: $0
|
||||||
echo $# arguments
|
echo $# arguments
|
||||||
if [ "$#" -ne 2 ] ;then
|
if [ "$#" -ne 3 ] ;then
|
||||||
echo "Input illegal number of parameters " $#
|
echo "Input illegal number of parameters " $#
|
||||||
echo "Need 2 parameters for dataset and seed"
|
echo "Need 3 parameters for dataset, BN-tracking-status, and seed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
if [ "$TORCH_HOME" = "" ]; then
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
@ -16,11 +16,12 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
dataset=$1
|
dataset=$1
|
||||||
seed=$2
|
BN=$2
|
||||||
|
seed=$3
|
||||||
channel=16
|
channel=16
|
||||||
num_cells=5
|
num_cells=5
|
||||||
max_nodes=4
|
max_nodes=4
|
||||||
space=nas-bench-102
|
space=nas-bench-201
|
||||||
|
|
||||||
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
data_path="$TORCH_HOME/cifar.python"
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
@ -28,15 +29,15 @@ else
|
|||||||
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
save_dir=./output/search-cell-${space}/SETN-${dataset}
|
save_dir=./output/search-cell-${space}/SETN-${dataset}-BN${BN}
|
||||||
|
|
||||||
OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \
|
OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \
|
||||||
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
--dataset ${dataset} --data_path ${data_path} \
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
--search_space_name ${space} \
|
--search_space_name ${space} \
|
||||||
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-102-v1_0-e61699.pth \
|
--arch_nas_dataset ${TORCH_HOME}/NAS-Bench-201-v1_0-e61699.pth \
|
||||||
--config_path configs/nas-benchmark/algos/SETN.config \
|
--config_path configs/nas-benchmark/algos/SETN.config \
|
||||||
--track_running_stats 1 \
|
--track_running_stats ${BN} \
|
||||||
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
--select_num 100 \
|
--select_num 100 \
|
||||||
--workers 4 --print_freq 200 --rand_seed ${seed}
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
||||||
|
@ -30,6 +30,7 @@ elif [ ${dataset} == 'imagenet-1k' ]; then
|
|||||||
workers=28
|
workers=28
|
||||||
cutout_length=-1
|
cutout_length=-1
|
||||||
else
|
else
|
||||||
|
exit 1
|
||||||
echo 'Unknown dataset: '${dataset}
|
echo 'Unknown dataset: '${dataset}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -22,30 +22,44 @@ batch=256
|
|||||||
|
|
||||||
save_dir=./output/search-shape/TAS-INFER-${dataset}-${model}
|
save_dir=./output/search-shape/TAS-INFER-${dataset}-${model}
|
||||||
|
|
||||||
|
if [ ${dataset} == 'cifar10' ] || [ ${dataset} == 'cifar100' ]; then
|
||||||
|
xpath=$TORCH_HOME/cifar.python
|
||||||
|
opt_config=./configs/opts/CIFAR-E300-W5-L1-COS.config
|
||||||
|
workers=4
|
||||||
|
elif [ ${dataset} == 'imagenet-1k' ]; then
|
||||||
|
xpath=$TORCH_HOME/ILSVRC2012
|
||||||
|
#opt_config=./configs/opts/ImageNet-E120-Cos-Smooth.config
|
||||||
|
opt_config=./configs/opts/RImageNet-E120-Cos-Soft.config
|
||||||
|
workers=28
|
||||||
|
else
|
||||||
|
echo 'Unknown dataset: '${dataset}
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
python --version
|
python --version
|
||||||
|
|
||||||
# normal training
|
# normal training
|
||||||
xsave_dir=${save_dir}-NMT
|
xsave_dir=${save_dir}-NMT
|
||||||
OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \
|
OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \
|
||||||
--data_path $TORCH_HOME/cifar.python \
|
--data_path ${xpath} \
|
||||||
--model_config ./configs/NeurIPS-2019/${model}.config \
|
--model_config ./configs/NeurIPS-2019/${model}.config \
|
||||||
--optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \
|
--optim_config ${opt_config} \
|
||||||
--procedure basic \
|
--procedure basic \
|
||||||
--save_dir ${xsave_dir} \
|
--save_dir ${xsave_dir} \
|
||||||
--cutout_length -1 \
|
--cutout_length -1 \
|
||||||
--batch_size ${batch} --rand_seed ${rseed} --workers 6 \
|
--batch_size ${batch} --rand_seed ${rseed} --workers ${workers} \
|
||||||
--eval_frequency 1 --print_freq 100 --print_freq_eval 200
|
--eval_frequency 1 --print_freq 100 --print_freq_eval 200
|
||||||
|
|
||||||
# KD training
|
# KD training
|
||||||
xsave_dir=${save_dir}-KDT
|
xsave_dir=${save_dir}-KDT
|
||||||
OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \
|
OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \
|
||||||
--data_path $TORCH_HOME/cifar.python \
|
--data_path ${xpath} \
|
||||||
--model_config ./configs/NeurIPS-2019/${model}.config \
|
--model_config ./configs/NeurIPS-2019/${model}.config \
|
||||||
--optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \
|
--optim_config ${opt_config} \
|
||||||
--KD_checkpoint ./.latent-data/basemodels/${dataset}/${model}.pth \
|
--KD_checkpoint ./.latent-data/basemodels/${dataset}/${model}.pth \
|
||||||
--procedure Simple-KD \
|
--procedure Simple-KD \
|
||||||
--save_dir ${xsave_dir} \
|
--save_dir ${xsave_dir} \
|
||||||
--KD_alpha 0.9 --KD_temperature 4 \
|
--KD_alpha 0.9 --KD_temperature 4 \
|
||||||
--cutout_length -1 \
|
--cutout_length -1 \
|
||||||
--batch_size ${batch} --rand_seed ${rseed} --workers 6 \
|
--batch_size ${batch} --rand_seed ${rseed} --workers ${workers} \
|
||||||
--eval_frequency 1 --print_freq 100 --print_freq_eval 200
|
--eval_frequency 1 --print_freq 100 --print_freq_eval 200
|
||||||
|
Loading…
Reference in New Issue
Block a user