diff --git a/docs/ICLR-2019-DARTS.md b/docs/ICLR-2019-DARTS.md index 9b321ce..a9a6bfe 100644 --- a/docs/ICLR-2019-DARTS.md +++ b/docs/ICLR-2019-DARTS.md @@ -4,17 +4,22 @@ DARTS: Differentiable Architecture Search is accepted by ICLR 2019. In this paper, Hanxiao proposed a differentiable neural architecture search method, named as DARTS. Recently, DARTS becomes very popular due to its simplicity and performance. -**Run DARTS on the NAS-Bench-201 search space**: +## Run DARTS on the NAS-Bench-201 search space ``` CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1 CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1 ``` -**Run the first-order DARTS on the NASNet search space**: +## Run the first-order DARTS on the NASNet/DARTS search space +This command will start to use the first-order DARTS to search architectures on the DARTS search space. ``` CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/DARTS1V-search-NASNet-space.sh cifar10 -1 ``` +After searching, if you want to train the searched architecture found by the above scripts, you need to add the config of that architecture (will be printed in the log) in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py). +In the future, I will add a more elegant way to train the searched architecture from the DARTS search space. 
+ + # Citation ``` diff --git a/exps/algos/DARTS-V1.py b/exps/algos/DARTS-V1.py index f0369ce..b06f779 100644 --- a/exps/algos/DARTS-V1.py +++ b/exps/algos/DARTS-V1.py @@ -199,7 +199,8 @@ def main(xargs): logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1)) copy_checkpoint(model_base_path, model_best_path, logger) with torch.no_grad(): - logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) + #logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() )) + logger.log('{:}'.format(search_model.show_alphas())) if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] ))) # measure elapsed time epoch_time.update(time.time() - start_time) diff --git a/lib/models/cell_searchs/search_model_darts.py b/lib/models/cell_searchs/search_model_darts.py index fd6f4cf..e7e61a7 100644 --- a/lib/models/cell_searchs/search_model_darts.py +++ b/lib/models/cell_searchs/search_model_darts.py @@ -53,6 +53,10 @@ class TinyNetworkDarts(nn.Module): def get_alphas(self): return [self.arch_parameters] + def show_alphas(self): + with torch.no_grad(): + return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() ) + def get_message(self): string = self.extra_repr() for i, cell in enumerate(self.cells):