update GDAS and SETN
This commit is contained in:
parent
7f13385f28
commit
a84f483882
16
README.md
16
README.md
@ -65,7 +65,10 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 SETN 96 -1
|
|||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k SETN 256 -1
|
||||||
```
|
```
|
||||||
|
|
||||||
Searching codes come soon!
|
The searching codes of SETN on a small search space:
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/SETN.sh cifar10 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## [Searching for A Robust Neural Architecture in Four GPU Hours](https://arxiv.org/abs/1910.04465)
|
## [Searching for A Robust Neural Architecture in Four GPU Hours](https://arxiv.org/abs/1910.04465)
|
||||||
@ -88,7 +91,16 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts/nas-infer-train.sh cifar100 GDAS_V1 96 -1
|
|||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_V1 256 -1
|
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_V1 256 -1
|
||||||
```
|
```
|
||||||
|
|
||||||
Searching codes come soon! A small example forward code segment for searching can be found in [this issue](https://github.com/D-X-Y/NAS-Projects/issues/12).
|
The GDAS searching codes on a small search space:
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/GDAS.sh cifar10 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
The baseline searching codes are DARTS:
|
||||||
|
```
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1
|
||||||
|
CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## [Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification](https://arxiv.org/abs/1903.09776)
|
## [Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification](https://arxiv.org/abs/1903.09776)
|
||||||
|
13
configs/nas-benchmark/CIFAR.config
Normal file
13
configs/nas-benchmark/CIFAR.config
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"scheduler": ["str", "cos"],
|
||||||
|
"eta_min" : ["float", "0.0"],
|
||||||
|
"epochs" : ["int", "200"],
|
||||||
|
"warmup" : ["int", "0"],
|
||||||
|
"optim" : ["str", "SGD"],
|
||||||
|
"LR" : ["float", "0.1"],
|
||||||
|
"decay" : ["float", "0.0005"],
|
||||||
|
"momentum" : ["float", "0.9"],
|
||||||
|
"nesterov" : ["bool", "1"],
|
||||||
|
"criterion": ["str", "Softmax"],
|
||||||
|
"batch_size": ["int", "256"]
|
||||||
|
}
|
13
configs/nas-benchmark/ImageNet-16.config
Normal file
13
configs/nas-benchmark/ImageNet-16.config
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"scheduler": ["str", "cos"],
|
||||||
|
"eta_min" : ["float", "0.0"],
|
||||||
|
"epochs" : ["int", "200"],
|
||||||
|
"warmup" : ["int", "0"],
|
||||||
|
"optim" : ["str", "SGD"],
|
||||||
|
"LR" : ["float", "0.1"],
|
||||||
|
"decay" : ["float", "0.0005"],
|
||||||
|
"momentum" : ["float", "0.9"],
|
||||||
|
"nesterov" : ["bool", "1"],
|
||||||
|
"criterion": ["str", "Softmax"],
|
||||||
|
"batch_size": ["int", "256"]
|
||||||
|
}
|
4
configs/nas-benchmark/ImageNet16-120-split.txt
Normal file
4
configs/nas-benchmark/ImageNet16-120-split.txt
Normal file
File diff suppressed because one or more lines are too long
13
configs/nas-benchmark/algos/DARTS.config
Normal file
13
configs/nas-benchmark/algos/DARTS.config
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"scheduler": ["str", "cos"],
|
||||||
|
"eta_min" : ["float", "0.001"],
|
||||||
|
"epochs" : ["int", "50"],
|
||||||
|
"warmup" : ["int", "0"],
|
||||||
|
"optim" : ["str", "SGD"],
|
||||||
|
"LR" : ["float", "0.025"],
|
||||||
|
"decay" : ["float", "0.0005"],
|
||||||
|
"momentum" : ["float", "0.9"],
|
||||||
|
"nesterov" : ["bool", "1"],
|
||||||
|
"criterion": ["str", "Softmax"],
|
||||||
|
"batch_size": ["int", "64"]
|
||||||
|
}
|
13
configs/nas-benchmark/algos/GDAS-noacc.config
Normal file
13
configs/nas-benchmark/algos/GDAS-noacc.config
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"scheduler": ["str", "cos"],
|
||||||
|
"eta_min" : ["float", "0.001"],
|
||||||
|
"epochs" : ["int", "50"],
|
||||||
|
"warmup" : ["int", "0"],
|
||||||
|
"optim" : ["str", "SGD"],
|
||||||
|
"LR" : ["float", "0.025"],
|
||||||
|
"decay" : ["float", "0.0005"],
|
||||||
|
"momentum" : ["float", "0.9"],
|
||||||
|
"nesterov" : ["bool", "1"],
|
||||||
|
"criterion": ["str", "Softmax"],
|
||||||
|
"batch_size": ["int", "64"]
|
||||||
|
}
|
13
configs/nas-benchmark/algos/GDAS.config
Normal file
13
configs/nas-benchmark/algos/GDAS.config
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"scheduler": ["str", "cos"],
|
||||||
|
"eta_min" : ["float", "0.001"],
|
||||||
|
"epochs" : ["int", "240"],
|
||||||
|
"warmup" : ["int", "0"],
|
||||||
|
"optim" : ["str", "SGD"],
|
||||||
|
"LR" : ["float", "0.025"],
|
||||||
|
"decay" : ["float", "0.0005"],
|
||||||
|
"momentum" : ["float", "0.9"],
|
||||||
|
"nesterov" : ["bool", "1"],
|
||||||
|
"criterion": ["str", "Softmax"],
|
||||||
|
"batch_size": ["int", "64"]
|
||||||
|
}
|
13
configs/nas-benchmark/algos/R-EA.config
Normal file
13
configs/nas-benchmark/algos/R-EA.config
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"scheduler": ["str", "cos"],
|
||||||
|
"eta_min" : ["float", "0.001"],
|
||||||
|
"epochs" : ["int", "25"],
|
||||||
|
"warmup" : ["int", "0"],
|
||||||
|
"optim" : ["str", "SGD"],
|
||||||
|
"LR" : ["float", "0.025"],
|
||||||
|
"decay" : ["float", "0.0005"],
|
||||||
|
"momentum" : ["float", "0.9"],
|
||||||
|
"nesterov" : ["bool", "1"],
|
||||||
|
"criterion": ["str", "Softmax"],
|
||||||
|
"batch_size": ["int", "64"]
|
||||||
|
}
|
13
configs/nas-benchmark/algos/RANDOM.config
Normal file
13
configs/nas-benchmark/algos/RANDOM.config
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"scheduler": ["str", "cos"],
|
||||||
|
"eta_min" : ["float", "0.001"],
|
||||||
|
"epochs" : ["int", "150"],
|
||||||
|
"warmup" : ["int", "0"],
|
||||||
|
"optim" : ["str", "SGD"],
|
||||||
|
"LR" : ["float", "0.025"],
|
||||||
|
"decay" : ["float", "0.0005"],
|
||||||
|
"momentum" : ["float", "0.9"],
|
||||||
|
"nesterov" : ["bool", "1"],
|
||||||
|
"criterion": ["str", "Softmax"],
|
||||||
|
"batch_size": ["int", "64"]
|
||||||
|
}
|
13
configs/nas-benchmark/algos/SETN.config
Normal file
13
configs/nas-benchmark/algos/SETN.config
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"scheduler": ["str", "cos"],
|
||||||
|
"eta_min" : ["float", "0.001"],
|
||||||
|
"epochs" : ["int", "400"],
|
||||||
|
"warmup" : ["int", "0"],
|
||||||
|
"optim" : ["str", "SGD"],
|
||||||
|
"LR" : ["float", "0.025"],
|
||||||
|
"decay" : ["float", "0.0005"],
|
||||||
|
"momentum" : ["float", "0.9"],
|
||||||
|
"nesterov" : ["bool", "1"],
|
||||||
|
"criterion": ["str", "Softmax"],
|
||||||
|
"batch_size": ["int", "64"]
|
||||||
|
}
|
4
configs/nas-benchmark/cifar-split.txt
Normal file
4
configs/nas-benchmark/cifar-split.txt
Normal file
File diff suppressed because one or more lines are too long
4
configs/nas-benchmark/cifar100-test-split.txt
Normal file
4
configs/nas-benchmark/cifar100-test-split.txt
Normal file
File diff suppressed because one or more lines are too long
4
configs/nas-benchmark/imagenet-16-120-test-split.txt
Normal file
4
configs/nas-benchmark/imagenet-16-120-test-split.txt
Normal file
File diff suppressed because one or more lines are too long
252
exps/algos/DARTS-V1.py
Normal file
252
exps/algos/DARTS-V1.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
##################################################
|
||||||
|
import os, sys, time, glob, random, argparse
|
||||||
|
import numpy as np
|
||||||
|
from copy import deepcopy
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from pathlib import Path
|
||||||
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
|
from config_utils import load_config, dict2config, configure2str
|
||||||
|
from datasets import get_datasets, SearchDataset
|
||||||
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
|
from utils import get_model_infos, obtain_accuracy
|
||||||
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
|
|
||||||
|
|
||||||
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||||
|
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||||
|
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
network.train()
|
||||||
|
end = time.time()
|
||||||
|
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
|
||||||
|
scheduler.update(None, 1.0 * step / len(xloader))
|
||||||
|
base_targets = base_targets.cuda(non_blocking=True)
|
||||||
|
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||||
|
# measure data loading time
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
|
||||||
|
# update the weights
|
||||||
|
w_optimizer.zero_grad()
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
base_loss = criterion(logits, base_targets)
|
||||||
|
base_loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
|
||||||
|
w_optimizer.step()
|
||||||
|
# record
|
||||||
|
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||||
|
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||||
|
base_top1.update (base_prec1.item(), base_inputs.size(0))
|
||||||
|
base_top5.update (base_prec5.item(), base_inputs.size(0))
|
||||||
|
|
||||||
|
# update the architecture-weight
|
||||||
|
a_optimizer.zero_grad()
|
||||||
|
_, logits = network(arch_inputs)
|
||||||
|
arch_loss = criterion(logits, arch_targets)
|
||||||
|
arch_loss.backward()
|
||||||
|
a_optimizer.step()
|
||||||
|
# record
|
||||||
|
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||||
|
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||||
|
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||||
|
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||||
|
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
if step % print_freq == 0 or step + 1 == len(xloader):
|
||||||
|
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
|
||||||
|
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||||
|
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
|
||||||
|
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
|
||||||
|
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
|
||||||
|
return base_losses.avg, base_top1.avg, base_top5.avg
|
||||||
|
|
||||||
|
|
||||||
|
def valid_func(xloader, network, criterion):
|
||||||
|
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||||
|
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
network.eval()
|
||||||
|
end = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||||
|
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||||
|
# measure data loading time
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
# prediction
|
||||||
|
_, logits = network(arch_inputs)
|
||||||
|
arch_loss = criterion(logits, arch_targets)
|
||||||
|
# record
|
||||||
|
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||||
|
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||||
|
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||||
|
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||||
|
|
||||||
|
|
||||||
|
def main(xargs):
|
||||||
|
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.set_num_threads( xargs.workers )
|
||||||
|
prepare_seed(xargs.rand_seed)
|
||||||
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
|
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||||
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
|
cifar_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
elif xargs.dataset.startswith('ImageNet16'):
|
||||||
|
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||||
|
imagenet16_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||||
|
config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||||
|
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
|
# To split data
|
||||||
|
train_data_v2 = deepcopy(train_data)
|
||||||
|
train_data_v2.transform = valid_data.transform
|
||||||
|
valid_data = train_data_v2
|
||||||
|
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||||
|
# data loader
|
||||||
|
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||||
|
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||||
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
|
||||||
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
|
model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||||
|
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||||
|
'space' : search_space}, None)
|
||||||
|
search_model = get_cell_based_tiny_net(model_config)
|
||||||
|
|
||||||
|
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||||
|
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||||
|
logger.log('w-optimizer : {:}'.format(w_optimizer))
|
||||||
|
logger.log('a-optimizer : {:}'.format(a_optimizer))
|
||||||
|
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
||||||
|
logger.log('criterion : {:}'.format(criterion))
|
||||||
|
flop, param = get_model_infos(search_model, xshape)
|
||||||
|
#logger.log('{:}'.format(search_model))
|
||||||
|
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||||
|
|
||||||
|
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||||
|
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||||
|
|
||||||
|
if last_info.exists(): # automatically resume from previous checkpoint
|
||||||
|
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
|
||||||
|
last_info = torch.load(last_info)
|
||||||
|
start_epoch = last_info['epoch']
|
||||||
|
checkpoint = torch.load(last_info['last_checkpoint'])
|
||||||
|
genotypes = checkpoint['genotypes']
|
||||||
|
valid_accuracies = checkpoint['valid_accuracies']
|
||||||
|
search_model.load_state_dict( checkpoint['search_model'] )
|
||||||
|
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
|
||||||
|
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
|
||||||
|
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
|
||||||
|
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||||
|
else:
|
||||||
|
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||||
|
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||||
|
|
||||||
|
# start training
|
||||||
|
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||||
|
for epoch in range(start_epoch, total_epoch):
|
||||||
|
w_scheduler.update(epoch, 0.0)
|
||||||
|
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||||
|
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
||||||
|
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
||||||
|
|
||||||
|
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||||
|
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||||
|
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||||
|
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||||
|
# check the best accuracy
|
||||||
|
valid_accuracies[epoch] = valid_a_top1
|
||||||
|
if valid_a_top1 > valid_accuracies['best']:
|
||||||
|
valid_accuracies['best'] = valid_a_top1
|
||||||
|
genotypes['best'] = search_model.genotype()
|
||||||
|
find_best = True
|
||||||
|
else: find_best = False
|
||||||
|
|
||||||
|
genotypes[epoch] = search_model.genotype()
|
||||||
|
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
|
||||||
|
# save checkpoint
|
||||||
|
save_path = save_checkpoint({'epoch' : epoch + 1,
|
||||||
|
'args' : deepcopy(xargs),
|
||||||
|
'search_model': search_model.state_dict(),
|
||||||
|
'w_optimizer' : w_optimizer.state_dict(),
|
||||||
|
'a_optimizer' : a_optimizer.state_dict(),
|
||||||
|
'w_scheduler' : w_scheduler.state_dict(),
|
||||||
|
'genotypes' : genotypes,
|
||||||
|
'valid_accuracies' : valid_accuracies},
|
||||||
|
model_base_path, logger)
|
||||||
|
last_info = save_checkpoint({
|
||||||
|
'epoch': epoch + 1,
|
||||||
|
'args' : deepcopy(args),
|
||||||
|
'last_checkpoint': save_path,
|
||||||
|
}, logger.path('info'), logger)
|
||||||
|
if find_best:
|
||||||
|
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||||
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
|
with torch.no_grad():
|
||||||
|
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
|
# measure elapsed time
|
||||||
|
epoch_time.update(time.time() - start_time)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
logger.log('\n' + '-'*100)
|
||||||
|
# check the performance from the architecture dataset
|
||||||
|
#if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||||
|
# logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||||
|
#else:
|
||||||
|
# nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||||
|
# geno = genotypes[total_epoch-1]
|
||||||
|
# logger.log('The last model is {:}'.format(geno))
|
||||||
|
# info = nas_bench.query_by_arch( geno )
|
||||||
|
# if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||||
|
# else : logger.log('{:}'.format(info))
|
||||||
|
# logger.log('-'*100)
|
||||||
|
# geno = genotypes['best']
|
||||||
|
# logger.log('The best model is {:}'.format(geno))
|
||||||
|
# info = nas_bench.query_by_arch( geno )
|
||||||
|
# if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||||
|
# else : logger.log('{:}'.format(info))
|
||||||
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser("DARTS first order")
|
||||||
|
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||||
|
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||||
|
# channels and number-of-cells
|
||||||
|
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||||
|
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||||
|
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||||
|
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||||
|
# architecture leraning rate
|
||||||
|
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||||
|
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||||
|
# log
|
||||||
|
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||||
|
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||||
|
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (nas-benchmark).')
|
||||||
|
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||||
|
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||||
|
main(args)
|
319
exps/algos/DARTS-V2.py
Normal file
319
exps/algos/DARTS-V2.py
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
##################################################
|
||||||
|
import os, sys, time, glob, random, argparse
|
||||||
|
import numpy as np
|
||||||
|
from copy import deepcopy
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from pathlib import Path
|
||||||
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
|
from config_utils import load_config, dict2config, configure2str
|
||||||
|
from datasets import get_datasets, SearchDataset
|
||||||
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
|
from utils import get_model_infos, obtain_accuracy
|
||||||
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
|
|
||||||
|
|
||||||
|
def _concat(xs):
|
||||||
|
return torch.cat([x.view(-1) for x in xs])
|
||||||
|
|
||||||
|
|
||||||
|
def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2):
|
||||||
|
R = r / _concat(vector).norm()
|
||||||
|
for p, v in zip(network.module.get_weights(), vector):
|
||||||
|
p.data.add_(R, v)
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
loss = criterion(logits, base_targets)
|
||||||
|
grads_p = torch.autograd.grad(loss, network.module.get_alphas())
|
||||||
|
|
||||||
|
for p, v in zip(network.module.get_weights(), vector):
|
||||||
|
p.data.sub_(2*R, v)
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
loss = criterion(logits, base_targets)
|
||||||
|
grads_n = torch.autograd.grad(loss, network.module.get_alphas())
|
||||||
|
|
||||||
|
for p, v in zip(network.module.get_weights(), vector):
|
||||||
|
p.data.add_(R, v)
|
||||||
|
return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]
|
||||||
|
|
||||||
|
|
||||||
|
def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets):
|
||||||
|
# _compute_unrolled_model
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
loss = criterion(logits, base_targets)
|
||||||
|
LR, WD, momentum = w_optimizer.param_groups[0]['lr'], w_optimizer.param_groups[0]['weight_decay'], w_optimizer.param_groups[0]['momentum']
|
||||||
|
with torch.no_grad():
|
||||||
|
theta = _concat(network.module.get_weights())
|
||||||
|
try:
|
||||||
|
moment = _concat(w_optimizer.state[v]['momentum_buffer'] for v in network.module.get_weights())
|
||||||
|
moment = moment.mul_(momentum)
|
||||||
|
except:
|
||||||
|
moment = torch.zeros_like(theta)
|
||||||
|
dtheta = _concat(torch.autograd.grad(loss, network.module.get_weights())) + WD*theta
|
||||||
|
params = theta.sub(LR, moment+dtheta)
|
||||||
|
unrolled_model = deepcopy(network)
|
||||||
|
model_dict = unrolled_model.state_dict()
|
||||||
|
new_params, offset = {}, 0
|
||||||
|
for k, v in network.named_parameters():
|
||||||
|
if 'arch_parameters' in k: continue
|
||||||
|
v_length = np.prod(v.size())
|
||||||
|
new_params[k] = params[offset: offset+v_length].view(v.size())
|
||||||
|
offset += v_length
|
||||||
|
model_dict.update(new_params)
|
||||||
|
unrolled_model.load_state_dict(model_dict)
|
||||||
|
|
||||||
|
unrolled_model.zero_grad()
|
||||||
|
_, unrolled_logits = unrolled_model(arch_inputs)
|
||||||
|
unrolled_loss = criterion(unrolled_logits, arch_targets)
|
||||||
|
unrolled_loss.backward()
|
||||||
|
|
||||||
|
dalpha = unrolled_model.module.arch_parameters.grad
|
||||||
|
vector = [v.grad.data for v in unrolled_model.module.get_weights()]
|
||||||
|
[implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets)
|
||||||
|
|
||||||
|
dalpha.data.sub_(LR, implicit_grads.data)
|
||||||
|
|
||||||
|
if network.module.arch_parameters.grad is None:
|
||||||
|
network.module.arch_parameters.grad = deepcopy( dalpha )
|
||||||
|
else:
|
||||||
|
network.module.arch_parameters.grad.data.copy_( dalpha.data )
|
||||||
|
return unrolled_loss.detach(), unrolled_logits.detach()
|
||||||
|
|
||||||
|
|
||||||
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||||
|
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||||
|
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
network.train()
|
||||||
|
end = time.time()
|
||||||
|
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
|
||||||
|
scheduler.update(None, 1.0 * step / len(xloader))
|
||||||
|
base_targets = base_targets.cuda(non_blocking=True)
|
||||||
|
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||||
|
# measure data loading time
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
|
||||||
|
# update the architecture-weight
|
||||||
|
a_optimizer.zero_grad()
|
||||||
|
arch_loss, arch_logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
|
||||||
|
a_optimizer.step()
|
||||||
|
# record
|
||||||
|
arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data, arch_targets.data, topk=(1, 5))
|
||||||
|
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||||
|
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||||
|
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||||
|
|
||||||
|
# update the weights
|
||||||
|
w_optimizer.zero_grad()
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
base_loss = criterion(logits, base_targets)
|
||||||
|
base_loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
|
||||||
|
w_optimizer.step()
|
||||||
|
# record
|
||||||
|
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||||
|
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||||
|
base_top1.update (base_prec1.item(), base_inputs.size(0))
|
||||||
|
base_top5.update (base_prec5.item(), base_inputs.size(0))
|
||||||
|
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
if step % print_freq == 0 or step + 1 == len(xloader):
|
||||||
|
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
|
||||||
|
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||||
|
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
|
||||||
|
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
|
||||||
|
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
|
||||||
|
return base_losses.avg, base_top1.avg, base_top5.avg
|
||||||
|
|
||||||
|
|
||||||
|
def valid_func(xloader, network, criterion):
|
||||||
|
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||||
|
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
network.eval()
|
||||||
|
end = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||||
|
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||||
|
# measure data loading time
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
# prediction
|
||||||
|
_, logits = network(arch_inputs)
|
||||||
|
arch_loss = criterion(logits, arch_targets)
|
||||||
|
# record
|
||||||
|
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||||
|
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||||
|
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||||
|
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||||
|
|
||||||
|
|
||||||
|
def main(xargs):
|
||||||
|
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.set_num_threads( xargs.workers )
|
||||||
|
prepare_seed(xargs.rand_seed)
|
||||||
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
|
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||||
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
|
cifar_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
elif xargs.dataset.startswith('ImageNet16'):
|
||||||
|
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||||
|
imagenet16_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||||
|
config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||||
|
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
|
# To split data
|
||||||
|
train_data_v2 = deepcopy(train_data)
|
||||||
|
train_data_v2.transform = valid_data.transform
|
||||||
|
valid_data = train_data_v2
|
||||||
|
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||||
|
# data loader
|
||||||
|
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||||
|
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||||
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
|
||||||
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
|
model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||||
|
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||||
|
'space' : search_space}, None)
|
||||||
|
search_model = get_cell_based_tiny_net(model_config)
|
||||||
|
|
||||||
|
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||||
|
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||||
|
logger.log('w-optimizer : {:}'.format(w_optimizer))
|
||||||
|
logger.log('a-optimizer : {:}'.format(a_optimizer))
|
||||||
|
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
||||||
|
logger.log('criterion : {:}'.format(criterion))
|
||||||
|
flop, param = get_model_infos(search_model, xshape)
|
||||||
|
#logger.log('{:}'.format(search_model))
|
||||||
|
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||||
|
|
||||||
|
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||||
|
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||||
|
|
||||||
|
if last_info.exists(): # automatically resume from previous checkpoint
|
||||||
|
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
|
||||||
|
last_info = torch.load(last_info)
|
||||||
|
start_epoch = last_info['epoch']
|
||||||
|
checkpoint = torch.load(last_info['last_checkpoint'])
|
||||||
|
genotypes = checkpoint['genotypes']
|
||||||
|
valid_accuracies = checkpoint['valid_accuracies']
|
||||||
|
search_model.load_state_dict( checkpoint['search_model'] )
|
||||||
|
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
|
||||||
|
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
|
||||||
|
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
|
||||||
|
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||||
|
else:
|
||||||
|
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||||
|
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||||
|
|
||||||
|
# start training
|
||||||
|
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||||
|
for epoch in range(start_epoch, total_epoch):
|
||||||
|
w_scheduler.update(epoch, 0.0)
|
||||||
|
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||||
|
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
||||||
|
min_LR = min(w_scheduler.get_lr())
|
||||||
|
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))
|
||||||
|
|
||||||
|
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||||
|
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||||
|
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||||
|
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||||
|
# check the best accuracy
|
||||||
|
valid_accuracies[epoch] = valid_a_top1
|
||||||
|
if valid_a_top1 > valid_accuracies['best']:
|
||||||
|
valid_accuracies['best'] = valid_a_top1
|
||||||
|
genotypes['best'] = search_model.genotype()
|
||||||
|
find_best = True
|
||||||
|
else: find_best = False
|
||||||
|
|
||||||
|
genotypes[epoch] = search_model.genotype()
|
||||||
|
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
|
||||||
|
# save checkpoint
|
||||||
|
save_path = save_checkpoint({'epoch' : epoch + 1,
|
||||||
|
'args' : deepcopy(xargs),
|
||||||
|
'search_model': search_model.state_dict(),
|
||||||
|
'w_optimizer' : w_optimizer.state_dict(),
|
||||||
|
'a_optimizer' : a_optimizer.state_dict(),
|
||||||
|
'w_scheduler' : w_scheduler.state_dict(),
|
||||||
|
'genotypes' : genotypes,
|
||||||
|
'valid_accuracies' : valid_accuracies},
|
||||||
|
model_base_path, logger)
|
||||||
|
last_info = save_checkpoint({
|
||||||
|
'epoch': epoch + 1,
|
||||||
|
'args' : deepcopy(args),
|
||||||
|
'last_checkpoint': save_path,
|
||||||
|
}, logger.path('info'), logger)
|
||||||
|
if find_best:
|
||||||
|
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||||
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
|
with torch.no_grad():
|
||||||
|
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
|
# measure elapsed time
|
||||||
|
epoch_time.update(time.time() - start_time)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
logger.log('\n' + '-'*100)
|
||||||
|
# check the performance from the architecture dataset
|
||||||
|
"""
|
||||||
|
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||||
|
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||||
|
else:
|
||||||
|
nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||||
|
geno = genotypes[total_epoch-1]
|
||||||
|
logger.log('The last model is {:}'.format(geno))
|
||||||
|
info = nas_bench.query_by_arch( geno )
|
||||||
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||||
|
else : logger.log('{:}'.format(info))
|
||||||
|
logger.log('-'*100)
|
||||||
|
geno = genotypes['best']
|
||||||
|
logger.log('The best model is {:}'.format(geno))
|
||||||
|
info = nas_bench.query_by_arch( geno )
|
||||||
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||||
|
else : logger.log('{:}'.format(info))
|
||||||
|
"""
|
||||||
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser("DARTS Second Order")
|
||||||
|
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||||
|
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||||
|
# channels and number-of-cells
|
||||||
|
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||||
|
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||||
|
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||||
|
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||||
|
# architecture leraning rate
|
||||||
|
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||||
|
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||||
|
# log
|
||||||
|
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||||
|
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||||
|
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
|
||||||
|
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||||
|
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||||
|
main(args)
|
224
exps/algos/GDAS.py
Normal file
224
exps/algos/GDAS.py
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
##################################################
|
||||||
|
import os, sys, time, glob, random, argparse
|
||||||
|
import numpy as np
|
||||||
|
from copy import deepcopy
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from pathlib import Path
|
||||||
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
|
from config_utils import load_config, dict2config, configure2str
|
||||||
|
from datasets import get_datasets, SearchDataset
|
||||||
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
|
from utils import get_model_infos, obtain_accuracy
|
||||||
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
|
|
||||||
|
|
||||||
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||||
|
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||||
|
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
network.train()
|
||||||
|
end = time.time()
|
||||||
|
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
|
||||||
|
scheduler.update(None, 1.0 * step / len(xloader))
|
||||||
|
base_targets = base_targets.cuda(non_blocking=True)
|
||||||
|
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||||
|
# measure data loading time
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
|
||||||
|
# update the weights
|
||||||
|
w_optimizer.zero_grad()
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
base_loss = criterion(logits, base_targets)
|
||||||
|
base_loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
|
||||||
|
w_optimizer.step()
|
||||||
|
# record
|
||||||
|
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||||
|
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||||
|
base_top1.update (base_prec1.item(), base_inputs.size(0))
|
||||||
|
base_top5.update (base_prec5.item(), base_inputs.size(0))
|
||||||
|
|
||||||
|
# update the architecture-weight
|
||||||
|
a_optimizer.zero_grad()
|
||||||
|
_, logits = network(arch_inputs)
|
||||||
|
arch_loss = criterion(logits, arch_targets)
|
||||||
|
arch_loss.backward()
|
||||||
|
a_optimizer.step()
|
||||||
|
# record
|
||||||
|
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||||
|
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||||
|
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||||
|
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||||
|
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
if step % print_freq == 0 or step + 1 == len(xloader):
|
||||||
|
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
|
||||||
|
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||||
|
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
|
||||||
|
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
|
||||||
|
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
|
||||||
|
return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||||
|
|
||||||
|
|
||||||
|
def main(xargs):
|
||||||
|
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.set_num_threads( xargs.workers )
|
||||||
|
prepare_seed(xargs.rand_seed)
|
||||||
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
|
train_data, _, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
|
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||||
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
|
cifar_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
elif xargs.dataset.startswith('ImageNet16'):
|
||||||
|
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||||
|
imagenet16_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||||
|
config_path = 'configs/nas-benchmark/algos/GDAS.config'
|
||||||
|
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
|
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||||
|
# data loader
|
||||||
|
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||||
|
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), config.batch_size))
|
||||||
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
|
||||||
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
|
model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||||
|
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||||
|
'space' : search_space}, None)
|
||||||
|
search_model = get_cell_based_tiny_net(model_config)
|
||||||
|
|
||||||
|
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||||
|
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||||
|
logger.log('w-optimizer : {:}'.format(w_optimizer))
|
||||||
|
logger.log('a-optimizer : {:}'.format(a_optimizer))
|
||||||
|
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
||||||
|
logger.log('criterion : {:}'.format(criterion))
|
||||||
|
flop, param = get_model_infos(search_model, xshape)
|
||||||
|
#logger.log('{:}'.format(search_model))
|
||||||
|
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||||
|
logger.log('search_space : {:}'.format(search_space))
|
||||||
|
|
||||||
|
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||||
|
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||||
|
|
||||||
|
if last_info.exists(): # automatically resume from previous checkpoint
|
||||||
|
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
|
||||||
|
last_info = torch.load(last_info)
|
||||||
|
start_epoch = last_info['epoch']
|
||||||
|
checkpoint = torch.load(last_info['last_checkpoint'])
|
||||||
|
genotypes = checkpoint['genotypes']
|
||||||
|
valid_accuracies = checkpoint['valid_accuracies']
|
||||||
|
search_model.load_state_dict( checkpoint['search_model'] )
|
||||||
|
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
|
||||||
|
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
|
||||||
|
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
|
||||||
|
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||||
|
else:
|
||||||
|
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||||
|
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||||
|
|
||||||
|
# start training
|
||||||
|
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||||
|
for epoch in range(start_epoch, total_epoch):
|
||||||
|
w_scheduler.update(epoch, 0.0)
|
||||||
|
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||||
|
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
||||||
|
search_model.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) )
|
||||||
|
logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(epoch_str, need_time, search_model.get_tau(), min(w_scheduler.get_lr())))
|
||||||
|
|
||||||
|
search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
|
||||||
|
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||||
|
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||||
|
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 ))
|
||||||
|
# check the best accuracy
|
||||||
|
valid_accuracies[epoch] = valid_a_top1
|
||||||
|
if valid_a_top1 > valid_accuracies['best']:
|
||||||
|
valid_accuracies['best'] = valid_a_top1
|
||||||
|
genotypes['best'] = search_model.genotype()
|
||||||
|
find_best = True
|
||||||
|
else: find_best = False
|
||||||
|
|
||||||
|
genotypes[epoch] = search_model.genotype()
|
||||||
|
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
|
||||||
|
# save checkpoint
|
||||||
|
save_path = save_checkpoint({'epoch' : epoch + 1,
|
||||||
|
'args' : deepcopy(xargs),
|
||||||
|
'search_model': search_model.state_dict(),
|
||||||
|
'w_optimizer' : w_optimizer.state_dict(),
|
||||||
|
'a_optimizer' : a_optimizer.state_dict(),
|
||||||
|
'w_scheduler' : w_scheduler.state_dict(),
|
||||||
|
'genotypes' : genotypes,
|
||||||
|
'valid_accuracies' : valid_accuracies},
|
||||||
|
model_base_path, logger)
|
||||||
|
last_info = save_checkpoint({
|
||||||
|
'epoch': epoch + 1,
|
||||||
|
'args' : deepcopy(args),
|
||||||
|
'last_checkpoint': save_path,
|
||||||
|
}, logger.path('info'), logger)
|
||||||
|
if find_best:
|
||||||
|
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||||
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
|
with torch.no_grad():
|
||||||
|
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
|
# measure elapsed time
|
||||||
|
epoch_time.update(time.time() - start_time)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
logger.log('\n' + '-'*100)
|
||||||
|
# check the performance from the architecture dataset
|
||||||
|
"""
|
||||||
|
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||||
|
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||||
|
else:
|
||||||
|
nas_bench = NASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||||
|
geno = genotypes[total_epoch-1]
|
||||||
|
logger.log('The last model is {:}'.format(geno))
|
||||||
|
info = nas_bench.query_by_arch( geno )
|
||||||
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||||
|
else : logger.log('{:}'.format(info))
|
||||||
|
logger.log('-'*100)
|
||||||
|
"""
|
||||||
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser("GDAS")
|
||||||
|
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||||
|
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||||
|
# channels and number-of-cells
|
||||||
|
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||||
|
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||||
|
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||||
|
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||||
|
# architecture leraning rate
|
||||||
|
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||||
|
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||||
|
parser.add_argument('--tau_min', type=float, help='The minimum tau for Gumbel')
|
||||||
|
parser.add_argument('--tau_max', type=float, help='The maximum tau for Gumbel')
|
||||||
|
# log
|
||||||
|
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||||
|
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||||
|
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
|
||||||
|
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||||
|
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||||
|
main(args)
|
281
exps/algos/SETN.py
Normal file
281
exps/algos/SETN.py
Normal file
@ -0,0 +1,281 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
##################################################
|
||||||
|
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
|
||||||
|
##################################################
|
||||||
|
import os, sys, time, glob, random, argparse
|
||||||
|
import numpy as np
|
||||||
|
from copy import deepcopy
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from pathlib import Path
|
||||||
|
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||||
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||||
|
from config_utils import load_config, dict2config, configure2str
|
||||||
|
from datasets import get_datasets, SearchDataset
|
||||||
|
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||||
|
from utils import get_model_infos, obtain_accuracy
|
||||||
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
|
from models import get_cell_based_tiny_net, get_search_spaces
|
||||||
|
|
||||||
|
|
||||||
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
|
||||||
|
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||||
|
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
network.train()
|
||||||
|
end = time.time()
|
||||||
|
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
|
||||||
|
scheduler.update(None, 1.0 * step / len(xloader))
|
||||||
|
base_targets = base_targets.cuda(non_blocking=True)
|
||||||
|
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||||
|
# measure data loading time
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
|
||||||
|
# update the weights
|
||||||
|
network.module.set_cal_mode( 'urs' )
|
||||||
|
w_optimizer.zero_grad()
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
base_loss = criterion(logits, base_targets)
|
||||||
|
base_loss.backward()
|
||||||
|
w_optimizer.step()
|
||||||
|
# record
|
||||||
|
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||||
|
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||||
|
base_top1.update (base_prec1.item(), base_inputs.size(0))
|
||||||
|
base_top5.update (base_prec5.item(), base_inputs.size(0))
|
||||||
|
|
||||||
|
# update the architecture-weight
|
||||||
|
network.module.set_cal_mode( 'joint' )
|
||||||
|
a_optimizer.zero_grad()
|
||||||
|
_, logits = network(arch_inputs)
|
||||||
|
arch_loss = criterion(logits, arch_targets)
|
||||||
|
arch_loss.backward()
|
||||||
|
a_optimizer.step()
|
||||||
|
# record
|
||||||
|
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||||
|
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||||
|
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||||
|
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||||
|
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
if step % print_freq == 0 or step + 1 == len(xloader):
|
||||||
|
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
|
||||||
|
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||||
|
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
|
||||||
|
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
|
||||||
|
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
|
||||||
|
return base_losses.avg, base_top1.avg, base_top5.avg
|
||||||
|
|
||||||
|
|
||||||
|
def valid_func(xloader, network, criterion):
|
||||||
|
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||||
|
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
|
network.train()
|
||||||
|
end = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||||
|
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||||
|
# measure data loading time
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
# prediction
|
||||||
|
_, logits = network(arch_inputs)
|
||||||
|
arch_loss = criterion(logits, arch_targets)
|
||||||
|
# record
|
||||||
|
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||||
|
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||||
|
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||||
|
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||||
|
# measure elapsed time
|
||||||
|
batch_time.update(time.time() - end)
|
||||||
|
end = time.time()
|
||||||
|
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||||
|
|
||||||
|
|
||||||
|
def main(xargs):
|
||||||
|
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.set_num_threads( xargs.workers )
|
||||||
|
prepare_seed(xargs.rand_seed)
|
||||||
|
logger = prepare_logger(args)
|
||||||
|
|
||||||
|
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||||
|
if xargs.dataset == 'cifar10' or xargs.dataset == 'cifar100':
|
||||||
|
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||||
|
cifar_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
elif xargs.dataset.startswith('ImageNet16'):
|
||||||
|
split_Fpath = 'configs/nas-benchmark/{:}-split.txt'.format(xargs.dataset)
|
||||||
|
imagenet16_split = load_config(split_Fpath, None, None)
|
||||||
|
train_split, valid_split = imagenet16_split.train, imagenet16_split.valid
|
||||||
|
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||||
|
config_path = 'configs/nas-benchmark/algos/SETN.config'
|
||||||
|
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||||
|
# To split data
|
||||||
|
train_data_v2 = deepcopy(train_data)
|
||||||
|
train_data_v2.transform = valid_data.transform
|
||||||
|
valid_data = train_data_v2
|
||||||
|
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
|
||||||
|
# data loader
|
||||||
|
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||||
|
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||||
|
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||||
|
|
||||||
|
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||||
|
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||||
|
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||||
|
'space' : search_space}, None)
|
||||||
|
search_model = get_cell_based_tiny_net(model_config)
|
||||||
|
|
||||||
|
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
|
||||||
|
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
|
||||||
|
logger.log('w-optimizer : {:}'.format(w_optimizer))
|
||||||
|
logger.log('a-optimizer : {:}'.format(a_optimizer))
|
||||||
|
logger.log('w-scheduler : {:}'.format(w_scheduler))
|
||||||
|
logger.log('criterion : {:}'.format(criterion))
|
||||||
|
flop, param = get_model_infos(search_model, xshape)
|
||||||
|
#logger.log('{:}'.format(search_model))
|
||||||
|
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
|
||||||
|
|
||||||
|
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||||
|
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
|
||||||
|
|
||||||
|
if last_info.exists(): # automatically resume from previous checkpoint
|
||||||
|
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
|
||||||
|
last_info = torch.load(last_info)
|
||||||
|
start_epoch = last_info['epoch']
|
||||||
|
checkpoint = torch.load(last_info['last_checkpoint'])
|
||||||
|
genotypes = checkpoint['genotypes']
|
||||||
|
valid_accuracies = checkpoint['valid_accuracies']
|
||||||
|
search_model.load_state_dict( checkpoint['search_model'] )
|
||||||
|
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
|
||||||
|
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
|
||||||
|
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
|
||||||
|
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||||
|
else:
|
||||||
|
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||||
|
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||||
|
|
||||||
|
# start training
|
||||||
|
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||||
|
for epoch in range(start_epoch, total_epoch):
|
||||||
|
w_scheduler.update(epoch, 0.0)
|
||||||
|
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
|
||||||
|
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
||||||
|
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
||||||
|
|
||||||
|
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
|
||||||
|
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5))
|
||||||
|
search_model.set_cal_mode('urs')
|
||||||
|
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||||
|
logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||||
|
search_model.set_cal_mode('joint')
|
||||||
|
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||||
|
logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||||
|
search_model.set_cal_mode('select')
|
||||||
|
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||||
|
logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
|
||||||
|
# check the best accuracy
|
||||||
|
valid_accuracies[epoch] = valid_a_top1
|
||||||
|
if valid_a_top1 > valid_accuracies['best']:
|
||||||
|
valid_accuracies['best'] = valid_a_top1
|
||||||
|
genotypes['best'] = search_model.genotype()
|
||||||
|
find_best = True
|
||||||
|
else: find_best = False
|
||||||
|
|
||||||
|
genotypes[epoch] = search_model.genotype()
|
||||||
|
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
|
||||||
|
# save checkpoint
|
||||||
|
save_path = save_checkpoint({'epoch' : epoch + 1,
|
||||||
|
'args' : deepcopy(xargs),
|
||||||
|
'search_model': search_model.state_dict(),
|
||||||
|
'w_optimizer' : w_optimizer.state_dict(),
|
||||||
|
'a_optimizer' : a_optimizer.state_dict(),
|
||||||
|
'w_scheduler' : w_scheduler.state_dict(),
|
||||||
|
'genotypes' : genotypes,
|
||||||
|
'valid_accuracies' : valid_accuracies},
|
||||||
|
model_base_path, logger)
|
||||||
|
last_info = save_checkpoint({
|
||||||
|
'epoch': epoch + 1,
|
||||||
|
'args' : deepcopy(args),
|
||||||
|
'last_checkpoint': save_path,
|
||||||
|
}, logger.path('info'), logger)
|
||||||
|
if find_best:
|
||||||
|
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
|
||||||
|
copy_checkpoint(model_base_path, model_best_path, logger)
|
||||||
|
with torch.no_grad():
|
||||||
|
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
|
# measure elapsed time
|
||||||
|
epoch_time.update(time.time() - start_time)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# sampling
|
||||||
|
with torch.no_grad():
|
||||||
|
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
|
||||||
|
selected_archs = set()
|
||||||
|
while len(selected_archs) < xargs.select_num:
|
||||||
|
architecture = search_model.dync_genotype()
|
||||||
|
selected_archs.add( architecture )
|
||||||
|
logger.log('select {:} architectures based on the learned arch-parameters'.format( len(selected_archs) ))
|
||||||
|
|
||||||
|
best_arch, best_acc = None, -1
|
||||||
|
state_dict = deepcopy( network.state_dict() )
|
||||||
|
for index, arch in enumerate(selected_archs):
|
||||||
|
with torch.no_grad():
|
||||||
|
search_model.set_cal_mode('dynamic', arch)
|
||||||
|
network.load_state_dict( deepcopy(state_dict) )
|
||||||
|
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
|
||||||
|
logger.log('{:} [{:03d}/{:03d}] : {:125s}, loss={:.3f}, accuracy={:.3f}%'.format(time_string(), index, len(selected_archs), str(arch), valid_a_loss , valid_a_top1))
|
||||||
|
if best_arch is None or best_acc < valid_a_top1:
|
||||||
|
best_arch, best_acc = arch, valid_a_top1
|
||||||
|
logger.log('Find the best one : {:} with accuracy={:.2f}%'.format(best_arch, best_acc))
|
||||||
|
|
||||||
|
logger.log('\n' + '-'*100)
|
||||||
|
# check the performance from the architecture dataset
|
||||||
|
"""
|
||||||
|
if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset):
|
||||||
|
logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset))
|
||||||
|
else:
|
||||||
|
nas_bench = TinyNASBenchmarkAPI(xargs.arch_nas_dataset)
|
||||||
|
geno = best_arch
|
||||||
|
logger.log('The last model is {:}'.format(geno))
|
||||||
|
info = nas_bench.query_by_arch( geno )
|
||||||
|
if info is None: logger.log('Did not find this architecture : {:}.'.format(geno))
|
||||||
|
else : logger.log('{:}'.format(info))
|
||||||
|
logger.log('-'*100)
|
||||||
|
"""
|
||||||
|
logger.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser("SETN")
|
||||||
|
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||||
|
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||||
|
# channels and number-of-cells
|
||||||
|
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||||
|
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||||
|
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||||
|
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
|
||||||
|
parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
|
||||||
|
# architecture leraning rate
|
||||||
|
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||||
|
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||||
|
# log
|
||||||
|
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||||
|
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
|
||||||
|
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
|
||||||
|
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||||
|
parser.add_argument('--rand_seed', type=int, help='manual seed')
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
|
||||||
|
main(args)
|
@ -12,7 +12,7 @@ class SearchDataset(data.Dataset):
|
|||||||
self.length = len(self.train_split)
|
self.length = len(self.train_split)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return ('{name}(name={datasetname}, length={length})'.format(name=self.__class__.__name__, **self.__dict__))
|
return ('{name}(name={datasetname}, train={tr_L}, valid={val_L})'.format(name=self.__class__.__name__, tr_L=len(self.train_split), val_L=len(self.valid_split)))
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
@ -3,11 +3,36 @@
|
|||||||
##################################################
|
##################################################
|
||||||
import torch
|
import torch
|
||||||
from os import path as osp
|
from os import path as osp
|
||||||
# our modules
|
# useful modules
|
||||||
from config_utils import dict2config
|
from config_utils import dict2config
|
||||||
from .SharedUtils import change_key
|
from .SharedUtils import change_key
|
||||||
from .clone_weights import init_from_model
|
from .clone_weights import init_from_model
|
||||||
|
|
||||||
|
# Cell-based NAS Models
|
||||||
|
def get_cell_based_tiny_net(config):
|
||||||
|
if config.name == 'DARTS-V1':
|
||||||
|
from .cell_searchs import TinyNetworkDartsV1
|
||||||
|
return TinyNetworkDartsV1(config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||||
|
elif config.name == 'DARTS-V2':
|
||||||
|
from .cell_searchs import TinyNetworkDartsV2
|
||||||
|
return TinyNetworkDartsV2(config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||||
|
elif config.name == 'GDAS':
|
||||||
|
from .cell_searchs import TinyNetworkGDAS
|
||||||
|
return TinyNetworkGDAS(config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||||
|
elif config.name == 'SETN':
|
||||||
|
from .cell_searchs import TinyNetworkSETN
|
||||||
|
return TinyNetworkSETN(config.C, config.N, config.max_nodes, config.num_classes, config.space)
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid network name : {:}'.format(config.name))
|
||||||
|
|
||||||
|
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
|
||||||
|
def get_search_spaces(xtype, name):
|
||||||
|
if xtype == 'cell':
|
||||||
|
from .cell_operations import SearchSpaceNames
|
||||||
|
return SearchSpaceNames[name]
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid search-space type is {:}'.format(xtype))
|
||||||
|
|
||||||
|
|
||||||
def get_cifar_models(config):
|
def get_cifar_models(config):
|
||||||
from .CifarResNet import CifarResNet
|
from .CifarResNet import CifarResNet
|
||||||
@ -22,9 +47,9 @@ def get_cifar_models(config):
|
|||||||
else:
|
else:
|
||||||
raise ValueError('invalid module type : {:}'.format(config.arch))
|
raise ValueError('invalid module type : {:}'.format(config.arch))
|
||||||
elif super_type.startswith('infer'):
|
elif super_type.startswith('infer'):
|
||||||
from .infers import InferWidthCifarResNet
|
from .shape_infers import InferWidthCifarResNet
|
||||||
from .infers import InferDepthCifarResNet
|
from .shape_infers import InferDepthCifarResNet
|
||||||
from .infers import InferCifarResNet
|
from .shape_infers import InferCifarResNet
|
||||||
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
|
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
|
||||||
infer_mode = super_type.split('-')[1]
|
infer_mode = super_type.split('-')[1]
|
||||||
if infer_mode == 'width':
|
if infer_mode == 'width':
|
||||||
@ -46,8 +71,8 @@ def get_imagenet_models(config):
|
|||||||
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
|
assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type)
|
||||||
infer_mode = super_type.split('-')[1]
|
infer_mode = super_type.split('-')[1]
|
||||||
if infer_mode == 'shape':
|
if infer_mode == 'shape':
|
||||||
from .infers import InferImagenetResNet
|
from .shape_infers import InferImagenetResNet
|
||||||
from .infers import InferMobileNetV2
|
from .shape_infers import InferMobileNetV2
|
||||||
if config.arch == 'resnet':
|
if config.arch == 'resnet':
|
||||||
return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual)
|
return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual)
|
||||||
elif config.arch == "MobileNetV2":
|
elif config.arch == "MobileNetV2":
|
||||||
@ -72,9 +97,9 @@ def obtain_model(config):
|
|||||||
def obtain_search_model(config):
|
def obtain_search_model(config):
|
||||||
if config.dataset == 'cifar':
|
if config.dataset == 'cifar':
|
||||||
if config.arch == 'resnet':
|
if config.arch == 'resnet':
|
||||||
from .searchs import SearchWidthCifarResNet
|
from .shape_searchs import SearchWidthCifarResNet
|
||||||
from .searchs import SearchDepthCifarResNet
|
from .shape_searchs import SearchDepthCifarResNet
|
||||||
from .searchs import SearchShapeCifarResNet
|
from .shape_searchs import SearchShapeCifarResNet
|
||||||
if config.search_mode == 'width':
|
if config.search_mode == 'width':
|
||||||
return SearchWidthCifarResNet(config.module, config.depth, config.class_num)
|
return SearchWidthCifarResNet(config.module, config.depth, config.class_num)
|
||||||
elif config.search_mode == 'depth':
|
elif config.search_mode == 'depth':
|
||||||
@ -85,7 +110,7 @@ def obtain_search_model(config):
|
|||||||
else:
|
else:
|
||||||
raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset))
|
raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset))
|
||||||
elif config.dataset == 'imagenet':
|
elif config.dataset == 'imagenet':
|
||||||
from .searchs import SearchShapeImagenetResNet
|
from .shape_searchs import SearchShapeImagenetResNet
|
||||||
assert config.search_mode == 'shape', 'invalid search-mode : {:}'.format( config.search_mode )
|
assert config.search_mode == 'shape', 'invalid search-mode : {:}'.format( config.search_mode )
|
||||||
if config.arch == 'resnet':
|
if config.arch == 'resnet':
|
||||||
return SearchShapeImagenetResNet(config.block_name, config.layers, config.deep_stem, config.class_num)
|
return SearchShapeImagenetResNet(config.block_name, config.layers, config.deep_stem, config.class_num)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
__all__ = ['OPS', 'ReLUConvBN', 'SearchSpaceNames']
|
__all__ = ['OPS', 'ReLUConvBN', 'ResNetBasicblock', 'SearchSpaceNames']
|
||||||
|
|
||||||
OPS = {
|
OPS = {
|
||||||
'none' : lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
|
'none' : lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
|
||||||
@ -14,8 +14,60 @@ OPS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||||
|
AA_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||||
|
|
||||||
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK}
|
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
||||||
|
'aa-nas' : AA_NAS_BENCHMARK}
|
||||||
|
|
||||||
|
|
||||||
|
class ReLUConvBN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
|
||||||
|
super(ReLUConvBN, self).__init__()
|
||||||
|
self.op = nn.Sequential(
|
||||||
|
nn.ReLU(inplace=False),
|
||||||
|
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
|
||||||
|
nn.BatchNorm2d(C_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.op(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetBasicblock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride):
|
||||||
|
super(ResNetBasicblock, self).__init__()
|
||||||
|
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||||
|
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
|
||||||
|
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1)
|
||||||
|
if stride == 2:
|
||||||
|
self.downsample = nn.Sequential(
|
||||||
|
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
|
||||||
|
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
|
||||||
|
elif inplanes != planes:
|
||||||
|
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
|
||||||
|
else:
|
||||||
|
self.downsample = None
|
||||||
|
self.in_dim = inplanes
|
||||||
|
self.out_dim = planes
|
||||||
|
self.stride = stride
|
||||||
|
self.num_conv = 2
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
|
||||||
|
return string
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
|
||||||
|
basicblock = self.conv_a(inputs)
|
||||||
|
basicblock = self.conv_b(basicblock)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(inputs)
|
||||||
|
else:
|
||||||
|
residual = inputs
|
||||||
|
return residual + basicblock
|
||||||
|
|
||||||
|
|
||||||
class POOLING(nn.Module):
|
class POOLING(nn.Module):
|
||||||
@ -36,20 +88,6 @@ class POOLING(nn.Module):
|
|||||||
return self.op(x)
|
return self.op(x)
|
||||||
|
|
||||||
|
|
||||||
class ReLUConvBN(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
|
|
||||||
super(ReLUConvBN, self).__init__()
|
|
||||||
self.op = nn.Sequential(
|
|
||||||
nn.ReLU(inplace=False),
|
|
||||||
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
|
|
||||||
nn.BatchNorm2d(C_out)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.op(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Identity(nn.Module):
|
class Identity(nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
@ -0,0 +1,4 @@
|
|||||||
|
from .search_model_darts_v1 import TinyNetworkDartsV1
|
||||||
|
from .search_model_darts_v2 import TinyNetworkDartsV2
|
||||||
|
from .search_model_gdas import TinyNetworkGDAS
|
||||||
|
from .search_model_setn import TinyNetworkSETN
|
@ -2,7 +2,7 @@ import math, torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from .operations import OPS, ReLUConvBN
|
from ..cell_operations import OPS
|
||||||
|
|
||||||
|
|
||||||
class SearchCell(nn.Module):
|
class SearchCell(nn.Module):
|
||||||
@ -113,84 +113,3 @@ class SearchCell(nn.Module):
|
|||||||
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
|
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
|
||||||
nodes.append( sum(inter_nodes) )
|
nodes.append( sum(inter_nodes) )
|
||||||
return nodes[-1]
|
return nodes[-1]
|
||||||
|
|
||||||
|
|
||||||
class InferCell(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, genotype, C_in, C_out, stride):
|
|
||||||
super(InferCell, self).__init__()
|
|
||||||
|
|
||||||
self.layers = nn.ModuleList()
|
|
||||||
self.node_IN = []
|
|
||||||
self.node_IX = []
|
|
||||||
self.genotype = deepcopy(genotype)
|
|
||||||
for i in range(1, len(genotype)):
|
|
||||||
node_info = genotype[i-1]
|
|
||||||
cur_index = []
|
|
||||||
cur_innod = []
|
|
||||||
for (op_name, op_in) in node_info:
|
|
||||||
if op_in == 0:
|
|
||||||
layer = OPS[op_name](C_in , C_out, stride)
|
|
||||||
else:
|
|
||||||
layer = OPS[op_name](C_out, C_out, 1)
|
|
||||||
cur_index.append( len(self.layers) )
|
|
||||||
cur_innod.append( op_in )
|
|
||||||
self.layers.append( layer )
|
|
||||||
self.node_IX.append( cur_index )
|
|
||||||
self.node_IN.append( cur_innod )
|
|
||||||
self.nodes = len(genotype)
|
|
||||||
self.in_dim = C_in
|
|
||||||
self.out_dim = C_out
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
|
|
||||||
laystr = []
|
|
||||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
|
|
||||||
y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
|
|
||||||
x = '{:}<-({:})'.format(i+1, ','.join(y))
|
|
||||||
laystr.append( x )
|
|
||||||
return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
nodes = [inputs]
|
|
||||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
|
|
||||||
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
|
|
||||||
nodes.append( node_feature )
|
|
||||||
return nodes[-1]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ResNetBasicblock(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride):
|
|
||||||
super(ResNetBasicblock, self).__init__()
|
|
||||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
|
||||||
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
|
|
||||||
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1)
|
|
||||||
if stride == 2:
|
|
||||||
self.downsample = nn.Sequential(
|
|
||||||
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
|
|
||||||
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
|
|
||||||
elif inplanes != planes:
|
|
||||||
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
|
|
||||||
else:
|
|
||||||
self.downsample = None
|
|
||||||
self.in_dim = inplanes
|
|
||||||
self.out_dim = planes
|
|
||||||
self.stride = stride
|
|
||||||
self.num_conv = 2
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
|
|
||||||
return string
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
|
|
||||||
basicblock = self.conv_a(inputs)
|
|
||||||
basicblock = self.conv_b(basicblock)
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
residual = self.downsample(inputs)
|
|
||||||
else:
|
|
||||||
residual = inputs
|
|
||||||
return residual + basicblock
|
|
||||||
|
158
lib/models/cell_searchs/genotypes.py
Normal file
158
lib/models/cell_searchs/genotypes.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_combination(space, num):
|
||||||
|
combs = []
|
||||||
|
for i in range(num):
|
||||||
|
if i == 0:
|
||||||
|
for func in space:
|
||||||
|
combs.append( [(func, i)] )
|
||||||
|
else:
|
||||||
|
new_combs = []
|
||||||
|
for string in combs:
|
||||||
|
for func in space:
|
||||||
|
xstring = string + [(func, i)]
|
||||||
|
new_combs.append( xstring )
|
||||||
|
combs = new_combs
|
||||||
|
return combs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Structure:
|
||||||
|
|
||||||
|
def __init__(self, genotype):
|
||||||
|
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
|
||||||
|
self.node_num = len(genotype) + 1
|
||||||
|
self.nodes = []
|
||||||
|
self.node_N = []
|
||||||
|
for idx, node_info in enumerate(genotype):
|
||||||
|
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
|
||||||
|
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
|
||||||
|
for node_in in node_info:
|
||||||
|
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
|
||||||
|
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
|
||||||
|
self.node_N.append( len(node_info) )
|
||||||
|
self.nodes.append( tuple(deepcopy(node_info)) )
|
||||||
|
|
||||||
|
def tolist(self, remove_str):
|
||||||
|
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
|
||||||
|
# note that we re-order the input node in this function
|
||||||
|
# return the-genotype-list and success [if unsuccess, it is not a connectivity]
|
||||||
|
genotypes = []
|
||||||
|
for node_info in self.nodes:
|
||||||
|
node_info = list( node_info )
|
||||||
|
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
|
||||||
|
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
|
||||||
|
if len(node_info) == 0: return None, False
|
||||||
|
genotypes.append( node_info )
|
||||||
|
return genotypes, True
|
||||||
|
|
||||||
|
def node(self, index):
|
||||||
|
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
|
||||||
|
return self.nodes[index]
|
||||||
|
|
||||||
|
def tostr(self):
|
||||||
|
strings = []
|
||||||
|
for node_info in self.nodes:
|
||||||
|
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
|
||||||
|
string = '|{:}|'.format(string)
|
||||||
|
strings.append( string )
|
||||||
|
return '+'.join(strings)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.nodes) + 1
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.nodes[index]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def str2structure(xstr):
|
||||||
|
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||||
|
nodestrs = xstr.split('+')
|
||||||
|
genotypes = []
|
||||||
|
for i, node_str in enumerate(nodestrs):
|
||||||
|
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||||
|
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||||
|
inputs = ( xi.split('~') for xi in inputs )
|
||||||
|
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
|
||||||
|
genotypes.append( input_infos )
|
||||||
|
return Structure( genotypes )
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def str2fullstructure(xstr, default_name='none'):
|
||||||
|
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||||
|
nodestrs = xstr.split('+')
|
||||||
|
genotypes = []
|
||||||
|
for i, node_str in enumerate(nodestrs):
|
||||||
|
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||||
|
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||||
|
inputs = ( xi.split('~') for xi in inputs )
|
||||||
|
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
|
||||||
|
all_in_nodes= list(x[1] for x in input_infos)
|
||||||
|
for j in range(i):
|
||||||
|
if j not in all_in_nodes: input_infos.append((default_name, j))
|
||||||
|
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
|
||||||
|
genotypes.append( tuple(node_info) )
|
||||||
|
return Structure( genotypes )
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def gen_all(search_space, num, return_ori):
|
||||||
|
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
|
||||||
|
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
|
||||||
|
all_archs = get_combination(search_space, 1)
|
||||||
|
for i, arch in enumerate(all_archs):
|
||||||
|
all_archs[i] = [ tuple(arch) ]
|
||||||
|
|
||||||
|
for inode in range(2, num):
|
||||||
|
cur_nodes = get_combination(search_space, inode)
|
||||||
|
new_all_archs = []
|
||||||
|
for previous_arch in all_archs:
|
||||||
|
for cur_node in cur_nodes:
|
||||||
|
new_all_archs.append( previous_arch + [tuple(cur_node)] )
|
||||||
|
all_archs = new_all_archs
|
||||||
|
if return_ori:
|
||||||
|
return all_archs
|
||||||
|
else:
|
||||||
|
return [Structure(x) for x in all_archs]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ResNet_CODE = Structure(
|
||||||
|
[(('nor_conv_3x3', 0), ), # node-1
|
||||||
|
(('nor_conv_3x3', 1), ), # node-2
|
||||||
|
(('skip_connect', 0), ('skip_connect', 2))] # node-3
|
||||||
|
)
|
||||||
|
|
||||||
|
AllConv3x3_CODE = Structure(
|
||||||
|
[(('nor_conv_3x3', 0), ), # node-1
|
||||||
|
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
|
||||||
|
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
|
||||||
|
)
|
||||||
|
|
||||||
|
AllFull_CODE = Structure(
|
||||||
|
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
|
||||||
|
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
|
||||||
|
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
|
||||||
|
)
|
||||||
|
|
||||||
|
AllConv1x1_CODE = Structure(
|
||||||
|
[(('nor_conv_1x1', 0), ), # node-1
|
||||||
|
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
|
||||||
|
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
|
||||||
|
)
|
||||||
|
|
||||||
|
AllIdentity_CODE = Structure(
|
||||||
|
[(('skip_connect', 0), ), # node-1
|
||||||
|
(('skip_connect', 0), ('skip_connect', 1)), # node-2
|
||||||
|
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
|
||||||
|
)
|
||||||
|
|
||||||
|
architectures = {'resnet' : ResNet_CODE,
|
||||||
|
'all_c3x3': AllConv3x3_CODE,
|
||||||
|
'all_c1x1': AllConv1x1_CODE,
|
||||||
|
'all_idnt': AllIdentity_CODE,
|
||||||
|
'all_full': AllFull_CODE}
|
134
lib/models/cell_searchs/search_cells.py
Normal file
134
lib/models/cell_searchs/search_cells.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
import math, random, torch
|
||||||
|
import warnings
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from copy import deepcopy
|
||||||
|
from ..cell_operations import OPS
|
||||||
|
|
||||||
|
|
||||||
|
class SearchCell(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, C_in, C_out, stride, max_nodes, op_names):
|
||||||
|
super(SearchCell, self).__init__()
|
||||||
|
|
||||||
|
self.op_names = deepcopy(op_names)
|
||||||
|
self.edges = nn.ModuleDict()
|
||||||
|
self.max_nodes = max_nodes
|
||||||
|
self.in_dim = C_in
|
||||||
|
self.out_dim = C_out
|
||||||
|
for i in range(1, max_nodes):
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
if j == 0:
|
||||||
|
xlists = [OPS[op_name](C_in , C_out, stride) for op_name in op_names]
|
||||||
|
else:
|
||||||
|
xlists = [OPS[op_name](C_in , C_out, 1) for op_name in op_names]
|
||||||
|
self.edges[ node_str ] = nn.ModuleList( xlists )
|
||||||
|
self.edge_keys = sorted(list(self.edges.keys()))
|
||||||
|
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
|
||||||
|
self.num_edges = len(self.edges)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
|
||||||
|
return string
|
||||||
|
|
||||||
|
def forward(self, inputs, weightss):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
inter_nodes = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
weights = weightss[ self.edge2index[node_str] ]
|
||||||
|
inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# GDAS
|
||||||
|
def forward_gdas(self, inputs, alphas, _tau):
|
||||||
|
avoid_zero = 0
|
||||||
|
while True:
|
||||||
|
gumbels = -torch.empty_like(alphas).exponential_().log()
|
||||||
|
logits = (alphas.log_softmax(dim=1) + gumbels) / _tau
|
||||||
|
probs = nn.functional.softmax(logits, dim=1)
|
||||||
|
index = probs.max(-1, keepdim=True)[1]
|
||||||
|
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||||
|
hardwts = one_h - probs.detach() + probs
|
||||||
|
if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()):
|
||||||
|
continue # avoid the numerical error
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
inter_nodes = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
weights = hardwts[ self.edge2index[node_str] ]
|
||||||
|
argmaxs = index[ self.edge2index[node_str] ].item()
|
||||||
|
weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )
|
||||||
|
inter_nodes.append( weigsum )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
avoid_zero += 1
|
||||||
|
if nodes[-1].sum().item() == 0:
|
||||||
|
if avoid_zero < 10: continue
|
||||||
|
else:
|
||||||
|
warnings.warn('get zero outputs with avoid_zero={:}'.format(avoid_zero))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# joint
|
||||||
|
def forward_joint(self, inputs, weightss):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
inter_nodes = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
weights = weightss[ self.edge2index[node_str] ]
|
||||||
|
aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
|
||||||
|
inter_nodes.append( aggregation )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# uniform random sampling per iteration
|
||||||
|
def forward_urs(self, inputs):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
while True: # to avoid select zero for all ops
|
||||||
|
sops, has_non_zero = [], False
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
candidates = self.edges[node_str]
|
||||||
|
select_op = random.choice(candidates)
|
||||||
|
sops.append( select_op )
|
||||||
|
if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True
|
||||||
|
if has_non_zero: break
|
||||||
|
inter_nodes = []
|
||||||
|
for j, select_op in enumerate(sops):
|
||||||
|
inter_nodes.append( select_op(nodes[j]) )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# select the argmax
|
||||||
|
def forward_select(self, inputs, weightss):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
inter_nodes = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
weights = weightss[ self.edge2index[node_str] ]
|
||||||
|
inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) )
|
||||||
|
#inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
||||||
|
|
||||||
|
# forward with a specific structure
|
||||||
|
def forward_dynamic(self, inputs, structure):
|
||||||
|
nodes = [inputs]
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
cur_op_node = structure.nodes[i-1]
|
||||||
|
inter_nodes = []
|
||||||
|
for op_name, j in cur_op_node:
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
op_index = self.op_names.index( op_name )
|
||||||
|
inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) )
|
||||||
|
nodes.append( sum(inter_nodes) )
|
||||||
|
return nodes[-1]
|
93
lib/models/cell_searchs/search_model_darts_v1.py
Normal file
93
lib/models/cell_searchs/search_model_darts_v1.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
########################################################
|
||||||
|
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
||||||
|
########################################################
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from copy import deepcopy
|
||||||
|
from ..cell_operations import ResNetBasicblock
|
||||||
|
from .search_cells import SearchCell
|
||||||
|
from .genotypes import Structure
|
||||||
|
|
||||||
|
|
||||||
|
class TinyNetworkDartsV1(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
||||||
|
super(TinyNetworkDartsV1, self).__init__()
|
||||||
|
self._C = C
|
||||||
|
self._layerN = N
|
||||||
|
self.max_nodes = max_nodes
|
||||||
|
self.stem = nn.Sequential(
|
||||||
|
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(C))
|
||||||
|
|
||||||
|
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||||
|
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||||
|
|
||||||
|
C_prev, num_edge, edge2index = C, None, None
|
||||||
|
self.cells = nn.ModuleList()
|
||||||
|
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||||
|
if reduction:
|
||||||
|
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||||
|
else:
|
||||||
|
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
||||||
|
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||||
|
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
||||||
|
self.cells.append( cell )
|
||||||
|
C_prev = cell.out_dim
|
||||||
|
self.op_names = deepcopy( search_space )
|
||||||
|
self._Layer = len(self.cells)
|
||||||
|
self.edge2index = edge2index
|
||||||
|
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||||
|
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.classifier = nn.Linear(C_prev, num_classes)
|
||||||
|
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||||
|
|
||||||
|
def get_weights(self):
|
||||||
|
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
||||||
|
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
|
||||||
|
xlist+= list( self.classifier.parameters() )
|
||||||
|
return xlist
|
||||||
|
|
||||||
|
def get_alphas(self):
|
||||||
|
return [self.arch_parameters]
|
||||||
|
|
||||||
|
def get_message(self):
|
||||||
|
string = self.extra_repr()
|
||||||
|
for i, cell in enumerate(self.cells):
|
||||||
|
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
||||||
|
return string
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||||
|
|
||||||
|
def genotype(self):
|
||||||
|
genotypes = []
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
xlist = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
with torch.no_grad():
|
||||||
|
weights = self.arch_parameters[ self.edge2index[node_str] ]
|
||||||
|
op_name = self.op_names[ weights.argmax().item() ]
|
||||||
|
xlist.append((op_name, j))
|
||||||
|
genotypes.append( tuple(xlist) )
|
||||||
|
return Structure( genotypes )
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||||
|
|
||||||
|
feature = self.stem(inputs)
|
||||||
|
for i, cell in enumerate(self.cells):
|
||||||
|
if isinstance(cell, SearchCell):
|
||||||
|
feature = cell(feature, alphas)
|
||||||
|
else:
|
||||||
|
feature = cell(feature)
|
||||||
|
|
||||||
|
out = self.lastact(feature)
|
||||||
|
out = self.global_pooling( out )
|
||||||
|
out = out.view(out.size(0), -1)
|
||||||
|
logits = self.classifier(out)
|
||||||
|
|
||||||
|
return out, logits
|
93
lib/models/cell_searchs/search_model_darts_v2.py
Normal file
93
lib/models/cell_searchs/search_model_darts_v2.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
########################################################
|
||||||
|
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
||||||
|
########################################################
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from copy import deepcopy
|
||||||
|
from ..cell_operations import ResNetBasicblock
|
||||||
|
from .search_cells import SearchCell
|
||||||
|
from .genotypes import Structure
|
||||||
|
|
||||||
|
|
||||||
|
class TinyNetworkDartsV2(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
||||||
|
super(TinyNetworkDartsV2, self).__init__()
|
||||||
|
self._C = C
|
||||||
|
self._layerN = N
|
||||||
|
self.max_nodes = max_nodes
|
||||||
|
self.stem = nn.Sequential(
|
||||||
|
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(C))
|
||||||
|
|
||||||
|
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||||
|
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||||
|
|
||||||
|
C_prev, num_edge, edge2index = C, None, None
|
||||||
|
self.cells = nn.ModuleList()
|
||||||
|
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||||
|
if reduction:
|
||||||
|
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||||
|
else:
|
||||||
|
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
||||||
|
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||||
|
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
||||||
|
self.cells.append( cell )
|
||||||
|
C_prev = cell.out_dim
|
||||||
|
self.op_names = deepcopy( search_space )
|
||||||
|
self._Layer = len(self.cells)
|
||||||
|
self.edge2index = edge2index
|
||||||
|
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||||
|
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.classifier = nn.Linear(C_prev, num_classes)
|
||||||
|
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||||
|
|
||||||
|
def get_weights(self):
|
||||||
|
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
||||||
|
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
|
||||||
|
xlist+= list( self.classifier.parameters() )
|
||||||
|
return xlist
|
||||||
|
|
||||||
|
def get_alphas(self):
|
||||||
|
return [self.arch_parameters]
|
||||||
|
|
||||||
|
def get_message(self):
|
||||||
|
string = self.extra_repr()
|
||||||
|
for i, cell in enumerate(self.cells):
|
||||||
|
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
||||||
|
return string
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||||
|
|
||||||
|
def genotype(self):
|
||||||
|
genotypes = []
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
xlist = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
with torch.no_grad():
|
||||||
|
weights = self.arch_parameters[ self.edge2index[node_str] ]
|
||||||
|
op_name = self.op_names[ weights.argmax().item() ]
|
||||||
|
xlist.append((op_name, j))
|
||||||
|
genotypes.append( tuple(xlist) )
|
||||||
|
return Structure( genotypes )
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||||
|
|
||||||
|
feature = self.stem(inputs)
|
||||||
|
for i, cell in enumerate(self.cells):
|
||||||
|
if isinstance(cell, SearchCell):
|
||||||
|
feature = cell(feature, alphas)
|
||||||
|
else:
|
||||||
|
feature = cell(feature)
|
||||||
|
|
||||||
|
out = self.lastact(feature)
|
||||||
|
out = self.global_pooling( out )
|
||||||
|
out = out.view(out.size(0), -1)
|
||||||
|
logits = self.classifier(out)
|
||||||
|
|
||||||
|
return out, logits
|
@ -6,7 +6,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from .infer_cells import ResNetBasicblock
|
from ..cell_operations import ResNetBasicblock
|
||||||
from .search_cells import SearchCell
|
from .search_cells import SearchCell
|
||||||
from .genotypes import Structure
|
from .genotypes import Structure
|
||||||
|
|
||||||
@ -44,7 +44,6 @@ class TinyNetworkGDAS(nn.Module):
|
|||||||
self.classifier = nn.Linear(C_prev, num_classes)
|
self.classifier = nn.Linear(C_prev, num_classes)
|
||||||
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||||
self.tau = 10
|
self.tau = 10
|
||||||
self.nan_count = 0
|
|
||||||
|
|
||||||
def get_weights(self):
|
def get_weights(self):
|
||||||
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
||||||
@ -52,9 +51,8 @@ class TinyNetworkGDAS(nn.Module):
|
|||||||
xlist+= list( self.classifier.parameters() )
|
xlist+= list( self.classifier.parameters() )
|
||||||
return xlist
|
return xlist
|
||||||
|
|
||||||
def set_tau(self, tau, _nan_count=0):
|
def set_tau(self, tau):
|
||||||
self.tau = tau
|
self.tau = tau
|
||||||
self.nan_count = _nan_count
|
|
||||||
|
|
||||||
def get_tau(self):
|
def get_tau(self):
|
||||||
return self.tau
|
return self.tau
|
||||||
@ -85,27 +83,10 @@ class TinyNetworkGDAS(nn.Module):
|
|||||||
return Structure( genotypes )
|
return Structure( genotypes )
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
def gumbel_softmax(_logits, _tau):
|
|
||||||
while True: # a trick to avoid the gumbels bug
|
|
||||||
gumbels = -torch.empty_like(_logits).exponential_().log()
|
|
||||||
new_logits = (_logits.log_softmax(dim=1) + gumbels) / _tau
|
|
||||||
probs = nn.functional.softmax(new_logits, dim=1)
|
|
||||||
index = probs.max(-1, keepdim=True)[1]
|
|
||||||
if index[0].item() == self.op_names.index('none') and index[3].item() == self.op_names.index('none') and index[5].item() == self.op_names.index('none'): continue
|
|
||||||
if index[1].item() == self.op_names.index('none') and index[2].item() == self.op_names.index('none') and index[3].item() == self.op_names.index('none') and index[4].item() == self.op_names.index('none'): continue
|
|
||||||
if index[3].item() == self.op_names.index('none') and index[4].item() == self.op_names.index('none') and index[5].item() == self.op_names.index('none'): continue
|
|
||||||
if index[3].item() == self.op_names.index('none') and index[0].item() == self.op_names.index('none') and index[1].item() == self.op_names.index('none'): continue
|
|
||||||
one_h = torch.zeros_like(_logits).scatter_(-1, index, 1.0)
|
|
||||||
xres = one_h - probs.detach() + probs
|
|
||||||
if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break
|
|
||||||
self.nan_count += 1
|
|
||||||
return xres, index
|
|
||||||
|
|
||||||
feature = self.stem(inputs)
|
feature = self.stem(inputs)
|
||||||
for i, cell in enumerate(self.cells):
|
for i, cell in enumerate(self.cells):
|
||||||
if isinstance(cell, SearchCell):
|
if isinstance(cell, SearchCell):
|
||||||
alphas, IDX = gumbel_softmax(self.arch_parameters, self.tau)
|
feature = cell.forward_gdas(feature, self.arch_parameters, self.tau)
|
||||||
feature = cell.forward_gdas(feature, alphas, IDX.cpu())
|
|
||||||
else:
|
else:
|
||||||
feature = cell(feature)
|
feature = cell(feature)
|
||||||
|
|
||||||
|
130
lib/models/cell_searchs/search_model_setn.py
Normal file
130
lib/models/cell_searchs/search_model_setn.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
##################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
|
######################################################################################
|
||||||
|
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||||
|
######################################################################################
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from copy import deepcopy
|
||||||
|
from ..cell_operations import ResNetBasicblock
|
||||||
|
from .search_cells import SearchCell
|
||||||
|
from .genotypes import Structure
|
||||||
|
|
||||||
|
|
||||||
|
class TinyNetworkSETN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, C, N, max_nodes, num_classes, search_space):
|
||||||
|
super(TinyNetworkSETN, self).__init__()
|
||||||
|
self._C = C
|
||||||
|
self._layerN = N
|
||||||
|
self.max_nodes = max_nodes
|
||||||
|
self.stem = nn.Sequential(
|
||||||
|
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(C))
|
||||||
|
|
||||||
|
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||||
|
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||||
|
|
||||||
|
C_prev, num_edge, edge2index = C, None, None
|
||||||
|
self.cells = nn.ModuleList()
|
||||||
|
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||||
|
if reduction:
|
||||||
|
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||||
|
else:
|
||||||
|
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space)
|
||||||
|
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||||
|
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
|
||||||
|
self.cells.append( cell )
|
||||||
|
C_prev = cell.out_dim
|
||||||
|
self.op_names = deepcopy( search_space )
|
||||||
|
self._Layer = len(self.cells)
|
||||||
|
self.edge2index = edge2index
|
||||||
|
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||||
|
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.classifier = nn.Linear(C_prev, num_classes)
|
||||||
|
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||||
|
self.mode = 'urs'
|
||||||
|
self.dynamic_cell = None
|
||||||
|
|
||||||
|
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||||
|
assert mode in ['urs', 'joint', 'select', 'dynamic']
|
||||||
|
self.mode = mode
|
||||||
|
if mode == 'dynamic': self.dynamic_cell = deepcopy( dynamic_cell )
|
||||||
|
else : self.dynamic_cell = None
|
||||||
|
|
||||||
|
def get_cal_mode(self):
|
||||||
|
return self.mode
|
||||||
|
|
||||||
|
def get_weights(self):
|
||||||
|
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
||||||
|
xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() )
|
||||||
|
xlist+= list( self.classifier.parameters() )
|
||||||
|
return xlist
|
||||||
|
|
||||||
|
def get_alphas(self):
|
||||||
|
return [self.arch_parameters]
|
||||||
|
|
||||||
|
def get_message(self):
|
||||||
|
string = self.extra_repr()
|
||||||
|
for i, cell in enumerate(self.cells):
|
||||||
|
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
||||||
|
return string
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||||
|
|
||||||
|
def genotype(self):
|
||||||
|
genotypes = []
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
xlist = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
with torch.no_grad():
|
||||||
|
weights = self.arch_parameters[ self.edge2index[node_str] ]
|
||||||
|
op_name = self.op_names[ weights.argmax().item() ]
|
||||||
|
xlist.append((op_name, j))
|
||||||
|
genotypes.append( tuple(xlist) )
|
||||||
|
return Structure( genotypes )
|
||||||
|
|
||||||
|
|
||||||
|
def dync_genotype(self):
|
||||||
|
genotypes = []
|
||||||
|
with torch.no_grad():
|
||||||
|
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||||
|
for i in range(1, self.max_nodes):
|
||||||
|
xlist = []
|
||||||
|
for j in range(i):
|
||||||
|
node_str = '{:}<-{:}'.format(i, j)
|
||||||
|
weights = alphas_cpu[ self.edge2index[node_str] ]
|
||||||
|
op_index = torch.multinomial(weights, 1).item()
|
||||||
|
op_name = self.op_names[ op_index ]
|
||||||
|
xlist.append((op_name, j))
|
||||||
|
genotypes.append( tuple(xlist) )
|
||||||
|
return Structure( genotypes )
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||||
|
with torch.no_grad():
|
||||||
|
alphas_cpu = alphas.detach().cpu()
|
||||||
|
|
||||||
|
feature = self.stem(inputs)
|
||||||
|
for i, cell in enumerate(self.cells):
|
||||||
|
if isinstance(cell, SearchCell):
|
||||||
|
if self.mode == 'urs':
|
||||||
|
feature = cell.forward_urs(feature)
|
||||||
|
elif self.mode == 'select':
|
||||||
|
feature = cell.forward_select(feature, alphas_cpu)
|
||||||
|
elif self.mode == 'joint':
|
||||||
|
feature = cell.forward_joint(feature, alphas)
|
||||||
|
elif self.mode == 'dynamic':
|
||||||
|
feature = cell.forward_dynamic(feature, self.dynamic_cell)
|
||||||
|
else: raise ValueError('invalid mode={:}'.format(self.mode))
|
||||||
|
else: feature = cell(feature)
|
||||||
|
|
||||||
|
out = self.lastact(feature)
|
||||||
|
out = self.global_pooling( out )
|
||||||
|
out = out.view(out.size(0), -1)
|
||||||
|
logits = self.classifier(out)
|
||||||
|
|
||||||
|
return out, logits
|
36
scripts-search/algos/DARTS-V1.sh
Normal file
36
scripts-search/algos/DARTS-V1.sh
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# bash ./scripts-search/algos/DARTS-V1.sh cifar10 -1
|
||||||
|
echo script name: $0
|
||||||
|
echo $# arguments
|
||||||
|
if [ "$#" -ne 2 ] ;then
|
||||||
|
echo "Input illegal number of parameters " $#
|
||||||
|
echo "Need 2 parameters for dataset and seed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
|
echo "Must set TORCH_HOME envoriment variable for data dir saving"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "TORCH_HOME : $TORCH_HOME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
dataset=$1
|
||||||
|
seed=$2
|
||||||
|
channel=16
|
||||||
|
num_cells=5
|
||||||
|
max_nodes=4
|
||||||
|
|
||||||
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
|
else
|
||||||
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
|
fi
|
||||||
|
|
||||||
|
save_dir=./output/cell-search-tiny/DARTS-V1-${dataset}
|
||||||
|
|
||||||
|
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V1.py \
|
||||||
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
|
--search_space_name aa-nas \
|
||||||
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
36
scripts-search/algos/DARTS-V2.sh
Normal file
36
scripts-search/algos/DARTS-V2.sh
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# bash ./scripts-search/algos/DARTS-V2.sh cifar10 -1
|
||||||
|
echo script name: $0
|
||||||
|
echo $# arguments
|
||||||
|
if [ "$#" -ne 2 ] ;then
|
||||||
|
echo "Input illegal number of parameters " $#
|
||||||
|
echo "Need 2 parameters for dataset and seed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
|
echo "Must set TORCH_HOME envoriment variable for data dir saving"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "TORCH_HOME : $TORCH_HOME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
dataset=$1
|
||||||
|
seed=$2
|
||||||
|
channel=16
|
||||||
|
num_cells=5
|
||||||
|
max_nodes=4
|
||||||
|
|
||||||
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
|
else
|
||||||
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
|
fi
|
||||||
|
|
||||||
|
save_dir=./output/cell-search-tiny/DARTS-V2-${dataset}
|
||||||
|
|
||||||
|
OMP_NUM_THREADS=4 python ./exps/algos/DARTS-V2.py \
|
||||||
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
|
--search_space_name aa-nas \
|
||||||
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
37
scripts-search/algos/GDAS.sh
Normal file
37
scripts-search/algos/GDAS.sh
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# bash ./scripts-search/algos/GDAS.sh cifar10 -1
|
||||||
|
echo script name: $0
|
||||||
|
echo $# arguments
|
||||||
|
if [ "$#" -ne 2 ] ;then
|
||||||
|
echo "Input illegal number of parameters " $#
|
||||||
|
echo "Need 2 parameters for dataset and seed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
|
echo "Must set TORCH_HOME envoriment variable for data dir saving"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "TORCH_HOME : $TORCH_HOME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
dataset=$1
|
||||||
|
seed=$2
|
||||||
|
channel=16
|
||||||
|
num_cells=5
|
||||||
|
max_nodes=4
|
||||||
|
|
||||||
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
|
else
|
||||||
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
|
fi
|
||||||
|
|
||||||
|
save_dir=./output/cell-search-tiny/GDAS-${dataset}
|
||||||
|
|
||||||
|
OMP_NUM_THREADS=4 python ./exps/algos/GDAS.py \
|
||||||
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
|
--search_space_name aa-nas \
|
||||||
|
--tau_max 10 --tau_min 0.1 \
|
||||||
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
38
scripts-search/algos/SETN.sh
Normal file
38
scripts-search/algos/SETN.sh
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
|
||||||
|
# bash ./scripts-search/scripts/algos/SETN.sh cifar10 -1
|
||||||
|
echo script name: $0
|
||||||
|
echo $# arguments
|
||||||
|
if [ "$#" -ne 2 ] ;then
|
||||||
|
echo "Input illegal number of parameters " $#
|
||||||
|
echo "Need 2 parameters for dataset and seed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [ "$TORCH_HOME" = "" ]; then
|
||||||
|
echo "Must set TORCH_HOME envoriment variable for data dir saving"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "TORCH_HOME : $TORCH_HOME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
dataset=$1
|
||||||
|
seed=$2
|
||||||
|
channel=16
|
||||||
|
num_cells=5
|
||||||
|
max_nodes=4
|
||||||
|
|
||||||
|
if [ "$dataset" == "cifar10" ] || [ "$dataset" == "cifar100" ]; then
|
||||||
|
data_path="$TORCH_HOME/cifar.python"
|
||||||
|
else
|
||||||
|
data_path="$TORCH_HOME/cifar.python/ImageNet16"
|
||||||
|
fi
|
||||||
|
|
||||||
|
save_dir=./output/cell-search-tiny/SETN-${dataset}
|
||||||
|
|
||||||
|
OMP_NUM_THREADS=4 python ./exps/algos/SETN.py \
|
||||||
|
--save_dir ${save_dir} --max_nodes ${max_nodes} --channel ${channel} --num_cells ${num_cells} \
|
||||||
|
--dataset ${dataset} --data_path ${data_path} \
|
||||||
|
--search_space_name aa-nas \
|
||||||
|
--arch_learning_rate 0.0003 --arch_weight_decay 0.001 \
|
||||||
|
--select_num 100 \
|
||||||
|
--workers 4 --print_freq 200 --rand_seed ${seed}
|
Loading…
Reference in New Issue
Block a user