fix small bugs in DARTS-V1 for NASNet-Space
This commit is contained in:
parent
db2760c260
commit
c2ff845d1b
@ -4,17 +4,22 @@ DARTS: Differentiable Architecture Search is accepted by ICLR 2019.
|
|||||||
In this paper, Hanxiao proposed a differentiable neural architecture search method, named as DARTS.
|
In this paper, Hanxiao proposed a differentiable neural architecture search method, named as DARTS.
|
||||||
Recently, DARTS becomes very popular due to its simplicity and performance.
|
Recently, DARTS becomes very popular due to its simplicity and performance.
|
||||||
|
|
||||||
**Run DARTS on the NAS-Bench-201 search space**:
|
## Run DARTS on the NAS-Bench-201 search space
|
||||||
```
|
```
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 1 -1
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 1 -1
|
||||||
```
|
```
|
||||||
|
|
||||||
**Run the first-order DARTS on the NASNet search space**:
|
## Run the first-order DARTS on the NASNet/DARTS search space
|
||||||
|
This command will start to use the first-order DARTS to search architectures on the DARTS search space.
|
||||||
```
|
```
|
||||||
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/DARTS1V-search-NASNet-space.sh cifar10 -1
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/DARTS1V-search-NASNet-space.sh cifar10 -1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
After searching, if you want to train the searched architecture found by the above scripts, you need to add the config of that architecture (will be printed in log) in [genotypes.py](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_infer_model/DXYs/genotypes.py).
|
||||||
|
In future, I will add a more eligent way to train the searched architecture from the DARTS search space.
|
||||||
|
|
||||||
|
|
||||||
# Citation
|
# Citation
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -199,7 +199,8 @@ def main(xargs):
|
|||||||
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||||
copy_checkpoint(model_base_path, model_best_path, logger)
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
#logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
|
logger.log('{:}'.format(search_model.show_alphas()))
|
||||||
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
if api is not None: logger.log('{:}'.format(api.query_by_arch( genotypes[epoch] )))
|
||||||
# measure elapsed time
|
# measure elapsed time
|
||||||
epoch_time.update(time.time() - start_time)
|
epoch_time.update(time.time() - start_time)
|
||||||
|
@ -53,6 +53,10 @@ class TinyNetworkDarts(nn.Module):
|
|||||||
def get_alphas(self):
|
def get_alphas(self):
|
||||||
return [self.arch_parameters]
|
return [self.arch_parameters]
|
||||||
|
|
||||||
|
def show_alphas(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() )
|
||||||
|
|
||||||
def get_message(self):
|
def get_message(self):
|
||||||
string = self.extra_repr()
|
string = self.extra_repr()
|
||||||
for i, cell in enumerate(self.cells):
|
for i, cell in enumerate(self.cells):
|
||||||
|
Loading…
Reference in New Issue
Block a user