update codes

This commit is contained in:
D-X-Y 2019-02-02 22:27:36 +11:00
parent 47327dc5a2
commit d6e9078568
3 changed files with 13 additions and 5 deletions

View File

@ -22,8 +22,9 @@ bash ./scripts-cnn/search-acc-v2.sh 3 acc2
Train the searched CNN on CIFAR
```
bash ./scripts-cnn/train-cifar.sh 0 GDAS_F1 cifar10
bash ./scripts-cnn/train-cifar.sh 0 GDAS_V1 cifar100
bash ./scripts-cnn/train-cifar.sh 0 GDAS_FG cifar10 cut
bash ./scripts-cnn/train-cifar.sh 0 GDAS_F1 cifar10 cut
bash ./scripts-cnn/train-cifar.sh 0 GDAS_V1 cifar100 cut
```
Train the searched CNN on ImageNet

View File

@ -236,7 +236,6 @@ def train(train_queue, valid_queue, model, criterion, base_optimizer, arch_optim
#inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)
data_time.update(time.time() - end)
# get a random minibatch from the search queue with replacement
try:
@ -246,6 +245,7 @@ def train(train_queue, valid_queue, model, criterion, base_optimizer, arch_optim
input_search, target_search = next(valid_iter)
target_search = target_search.cuda(non_blocking=True)
data_time.update(time.time() - end)
# update the architecture
arch_optimizer.zero_grad()

View File

@ -195,12 +195,18 @@ GDAS_F1 = Genotype(
)
# Combine DMS_V1 and DMS_F1
GDAS_CC = Genotype(
GDAS_GF = Genotype(
normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)],
normal_concat=range(2, 6),
reduce=None,
reduce_concat=range(2, 6)
)
GDAS_FG = Genotype(
normal=[('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15)],
normal_concat=range(2, 6),
reduce=[('sep_conv_5x5', 0, 0.12862831354141235), ('sep_conv_3x3', 1, 0.12783904373645782), ('sep_conv_5x5', 2, 0.12725995481014252), ('sep_conv_5x5', 1, 0.12705285847187042), ('dil_conv_5x5', 2, 0.12797553837299347), ('sep_conv_3x3', 1, 0.12737272679805756), ('sep_conv_5x5', 0, 0.12833961844444275), ('sep_conv_5x5', 1, 0.12758426368236542)],
reduce_concat=range(2, 6)
)
model_types = {'DARTS_V1': DARTS_V1,
'DARTS_V2': DARTS_V2,
@ -210,4 +216,5 @@ model_types = {'DARTS_V1': DARTS_V1,
'ENASNet' : ENASNet,
'GDAS_V1' : GDAS_V1,
'GDAS_F1' : GDAS_F1,
'GDAS_CC' : GDAS_CC}
'GDAS_GF' : GDAS_GF,
'GDAS_FG' : GDAS_FG}