*Searching for A Robust Neural Architecture in Four GPU Hours* was accepted at CVPR 2019.
In this paper, we propose a Gradient-based search algorithm using Differentiable Architecture Sampling (GDAS).
GDAS is based on DARTS and improves it with Gumbel-softmax sampling.
Concurrently with our submission, several NAS papers (SNAS and FBNet) also used Gumbel-softmax sampling. GDAS differs from them in how the forward and backward passes are performed; see our paper and code for details.
Experiments on CIFAR-10, CIFAR-100, ImageNet, PTB, and WT2 are reported.
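The Gumbel-softmax sampling that GDAS builds on can be illustrated with the generic hard Gumbel-softmax trick. The sketch below is a pure-Python illustration under our own assumptions (the function name `gumbel_softmax_sample` and the default `tau` are hypothetical), not the GDAS implementation. It perturbs the architecture logits of one edge with Gumbel noise, computes the relaxed (soft) sample, and picks a single operation as a one-hot vector. In an autograd framework, the one-hot vector is typically combined with the soft sample through a straight-through trick (`hard - soft.detach() + soft`), so the forward pass uses only the sampled operation while gradients flow through the soft weights. See the paper for how GDAS itself defines the forward and backward passes.

```python
import math
import random


def gumbel_softmax_sample(logits, tau=1.0, rng=random):
    """Hard Gumbel-softmax sample over the candidate operations of one edge.

    Returns (soft, index, one_hot): the relaxed probabilities, the index of the
    sampled operation, and its one-hot encoding.
    """
    # Gumbel(0, 1) noise: g = -log(-log(u)) with u ~ Uniform(0, 1); clamp u away from 0
    noise = [-math.log(-math.log(max(rng.random(), 1e-20))) for _ in logits]
    perturbed = [(l + g) / tau for l, g in zip(logits, noise)]
    # Numerically stable softmax of the perturbed, temperature-scaled logits
    m = max(perturbed)
    exps = [math.exp(p - m) for p in perturbed]
    total = sum(exps)
    soft = [e / total for e in exps]
    # Hard sample: argmax of the perturbed logits (Gumbel-max trick)
    index = max(range(len(soft)), key=soft.__getitem__)
    one_hot = [1.0 if i == index else 0.0 for i in range(len(soft))]
    return soft, index, one_hot


# Example: three candidate operations with probabilities 0.7, 0.2, 0.1
rng = random.Random(0)
logits = [math.log(p) for p in (0.7, 0.2, 0.1)]
soft, index, one_hot = gumbel_softmax_sample(logits, tau=0.5, rng=rng)
```

Because the argmax is unaffected by `tau`, the sampled index follows the distribution `softmax(logits)`; a lower `tau` only makes `soft` closer to `one_hot`.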
## Requirements and Preparation
Please install `Python>=3.6` and `PyTorch>=1.2.0`.
CIFAR and ImageNet should be downloaded and extracted into `$TORCH_HOME`.
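Before launching any script, it may help to confirm that `$TORCH_HOME` is set and points to an existing directory. The helper below is a hypothetical sketch (the name `torch_home_dir` is ours, and the sub-directory names the scripts expect for CIFAR and ImageNet are not specified here):

```python
import os


def torch_home_dir():
    """Return $TORCH_HOME, raising if it is unset or not an existing directory."""
    root = os.environ.get("TORCH_HOME")
    if not root:
        raise RuntimeError("TORCH_HOME is not set; export it before running the scripts")
    if not os.path.isdir(root):
        raise FileNotFoundError(f"TORCH_HOME={root!r} is not an existing directory")
    return root


# Example (hypothetical): list the extracted datasets under $TORCH_HOME
# print(sorted(os.listdir(torch_home_dir())))
```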
### Useful tools
1. Compute the number of parameters and FLOPs of a model:
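```python
from utils import get_model_infos
# `net` is a PyTorch model (torch.nn.Module); (1, 3, 32, 32) is the input shape
# (batch, channels, height, width), e.g., a single CIFAR image
flop, param = get_model_infos(net, (1, 3, 32, 32))
```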
2. Different NAS-searched architectures are defined [here](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
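To sanity-check the numbers reported by `get_model_infos`, it can help to work out one layer by hand. The sketch below assumes FLOPs are counted as multiply-accumulate operations (MACs) of a convolution without bias; the exact convention and units used by `get_model_infos` are not stated here, so treat this only as an illustration. The function `conv2d_cost` is hypothetical.

```python
def conv2d_cost(c_in, c_out, kernel, h_out, w_out, bias=False):
    """Parameters and multiply-accumulates (MACs) of a standard 2D convolution."""
    weights = kernel * kernel * c_in * c_out
    params = weights + (c_out if bias else 0)
    # Each output position of each output channel needs kernel*kernel*c_in MACs
    macs = weights * h_out * w_out
    return params, macs


# First 3x3 conv on a (1, 3, 32, 32) CIFAR input with 16 filters and padding 1
params, macs = conv2d_cost(c_in=3, c_out=16, kernel=3, h_out=32, w_out=32)
print(params, macs)  # 432 442368
```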
## Usage
### Reproducing the results of our searched architecture in GDAS
Please use the following scripts to train the GDAS-searched CNN on CIFAR-10, CIFAR-100, and ImageNet.
The configuration of each NAS-searched architecture is defined in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
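A DARTS-style cell genotype is commonly written as a named tuple that lists, for each intermediate node, its two (operation, input-index) pairs. The sketch below assumes that layout (the field names follow the DARTS convention and may not match `genotypes.py` exactly), and the operations in `EXAMPLE` are illustrative placeholders, not the GDAS-searched cell. The checker `check_genotype` is hypothetical.

```python
from collections import namedtuple

# DARTS-style genotype layout (assumed; check genotypes.py for the exact fields)
Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")

# Hypothetical cell: each intermediate node has two (operation, input_index) pairs;
# inputs 0 and 1 are the two previous cells' outputs, 2+ are earlier intermediate nodes
EXAMPLE = Genotype(
    normal=[
        (("sep_conv_3x3", 0), ("skip_connect", 1)),
        (("sep_conv_3x3", 1), ("sep_conv_3x3", 2)),
        (("skip_connect", 0), ("dil_conv_3x3", 2)),
        (("sep_conv_5x5", 3), ("skip_connect", 4)),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        (("max_pool_3x3", 0), ("max_pool_3x3", 1)),
        (("skip_connect", 2), ("max_pool_3x3", 0)),
        (("sep_conv_5x5", 2), ("skip_connect", 3)),
        (("skip_connect", 2), ("max_pool_3x3", 1)),
    ],
    reduce_concat=[2, 3, 4, 5],
)


def check_genotype(cell):
    """Verify that every intermediate node only reads from inputs defined before it."""
    for node_id, edges in enumerate(cell, start=2):
        for op, src in edges:
            if not 0 <= src < node_id:
                raise ValueError(f"node {node_id} reads undefined input {src} via {op}")
    return True


check_genotype(EXAMPLE.normal)
check_genotype(EXAMPLE.reduce)
```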
### Searching on the NASNet search space
Please use the following scripts to search with GDAS as in the original paper:
Note that `gdas-searched` is a string indicating the name of the save directory, and `output/search-cell-darts/GDAS-cifar10-BN1/checkpoint/seed-945-basic.pth` is the path of the checkpoint file generated by the search algorithm.
The above script does not apply heavy augmentation when training the model, so the accuracy will be lower than that reported in the original paper.
To change the default hyper-parameters for re-training, please see `./scripts/retrain-searched-net.sh` and `configs/archs/NAS-*-none.config`.
The string `|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|` represents the structure of a searched architecture; our code prints it automatically during the search procedure.
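The architecture string can be parsed mechanically: `+` separates nodes, and each `|op~k|` token names an operation and, we assume, the index `k` of the node it reads from (node 0 being the cell input). The helpers `parse_arch` and `arch_to_str` below are hypothetical, a sketch rather than the repository's own parser:

```python
def parse_arch(arch_str):
    """Turn '|op~k|+|op~k|op~k|' into a list of nodes, each a list of (op, k) pairs."""
    nodes = []
    for node_str in arch_str.split("+"):
        # Strip the outer bars, then split the per-edge tokens
        tokens = [t for t in node_str.strip("|").split("|") if t]
        edges = []
        for token in tokens:
            op, src = token.split("~")
            edges.append((op, int(src)))
        nodes.append(edges)
    return nodes


def arch_to_str(nodes):
    """Inverse of parse_arch: rebuild the '|op~k|+...' string."""
    return "+".join("|" + "|".join(f"{op}~{src}" for op, src in edges) + "|" for edges in nodes)


arch = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|"
nodes = parse_arch(arch)
# In this example, node i (1-based) has i incoming edges, one from each earlier node
for i, edges in enumerate(nodes, start=1):
    print(i, edges)
assert arch_to_str(nodes) == arch  # the round trip reproduces the original string
```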