Update q-config and black for procedures/utils

D-X-Y 2021-03-07 03:09:47 +00:00
parent 349d9fcc9f
commit 55c9734c31
22 changed files with 1938 additions and 1390 deletions

View File

@@ -147,5 +147,8 @@ If you find that this project helps your research, please consider citing the re
If you want to contribute to this repo, please see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
Besides, please follow [CODE-OF-CONDUCT.md](.github/CODE-OF-CONDUCT.md).
We use [black](https://github.com/psf/black) as the Python code formatter.
Please run `black . -l 120`.

# License
The entire codebase is under the [MIT license](LICENSE.md).

View File

@@ -0,0 +1,82 @@
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market all
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors:
        - class: RobustZScoreNorm
          kwargs:
              fields_group: feature
              clip_outlier: true
        - class: Fillna
          kwargs:
              fields_group: feature
    learn_processors:
        - class: DropnaLabel
        - class: CSRankNorm
          kwargs:
              fields_group: label
    label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy.strategy
        kwargs:
            topk: 50
            n_drop: 5
    backtest:
        verbose: False
        limit_threshold: 0.095
        account: 100000000
        benchmark: *benchmark
        deal_price: close
        open_cost: 0.0005
        close_cost: 0.0015
        min_cost: 5
task:
    model:
        class: DNNModelPytorch
        module_path: qlib.contrib.model.pytorch_nn
        kwargs:
            loss: mse
            input_dim: 360
            output_dim: 1
            lr: 0.002
            lr_decay: 0.96
            lr_decay_steps: 100
            optimizer: adam
            max_steps: 8000
            batch_size: 4096
            GPU: 0
    dataset:
        class: DatasetH
        module_path: qlib.data.dataset
        kwargs:
            handler:
                class: Alpha360
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs: {}
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              ana_long_short: False
              ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              config: *port_analysis_config
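A config like the one above is consumed through qlib's standard workflow machinery. As a rough, hypothetical sketch (not part of this commit; the yaml file name and the qlib.init call are assumptions), it can be loaded and instantiated with the same init_instance_by_config pattern used elsewhere in this commit:

# Hypothetical sketch: load the workflow config above and build its pieces.
import yaml
import qlib
from qlib.utils import init_instance_by_config

with open("workflow_config_mlp_Alpha360.yaml") as stream:  # assumed file name
    config = yaml.safe_load(stream)

qlib.init(**config.get("qlib_init", {}))                      # provider_uri / region from the config
dataset = init_instance_by_config(config["task"]["dataset"])  # DatasetH wrapping the Alpha360 handler
model = init_instance_by_config(config["task"]["model"])      # DNNModelPytorch in this config
model.fit(dataset)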

View File

@@ -0,0 +1,85 @@
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market all
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors:
        - class: RobustZScoreNorm
          kwargs:
              fields_group: feature
              clip_outlier: true
        - class: Fillna
          kwargs:
              fields_group: feature
    learn_processors:
        - class: DropnaLabel
        - class: CSRankNorm
          kwargs:
              fields_group: label
    label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy.strategy
        kwargs:
            topk: 50
            n_drop: 5
    backtest:
        verbose: False
        limit_threshold: 0.095
        account: 100000000
        benchmark: *benchmark
        deal_price: close
        open_cost: 0.0005
        close_cost: 0.0015
        min_cost: 5
task:
    model:
        class: SFM
        module_path: qlib.contrib.model.pytorch_sfm
        kwargs:
            d_feat: 6
            hidden_size: 64
            output_dim: 32
            freq_dim: 25
            dropout_W: 0.5
            dropout_U: 0.5
            n_epochs: 20
            lr: 1e-3
            batch_size: 1600
            early_stop: 20
            eval_steps: 5
            loss: mse
            optimizer: adam
            GPU: 0
    dataset:
        class: DatasetH
        module_path: qlib.data.dataset
        kwargs:
            handler:
                class: Alpha360
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs: {}
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              ana_long_short: False
              ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              config: *port_analysis_config
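One detail worth flagging: the handler is Alpha360 (a 360-dimensional feature vector) while d_feat is 6. The following is an assumption based on how qlib's PyTorch models conventionally consume Alpha360, not something stated in this diff: the 360 features are treated as 60 daily steps of 6 raw fields and reshaped into a sequence before the recurrent model sees them, roughly:

# Hypothetical illustration of why d_feat is 6 for an Alpha360 handler.
import torch

batch = torch.randn(32, 360)                 # a batch of Alpha360 feature vectors
sequence = batch.reshape(len(batch), 6, -1)  # [32, 6, 60]: 6 fields x 60 time steps
sequence = sequence.permute(0, 2, 1)         # [32, 60, 6]: time-major input for the model
print(sequence.shape)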

View File

@@ -4,6 +4,8 @@
# python exps/trading/baselines.py --alg GRU
# python exps/trading/baselines.py --alg LSTM
# python exps/trading/baselines.py --alg ALSTM
# python exps/trading/baselines.py --alg MLP
# python exps/trading/baselines.py --alg SFM
# python exps/trading/baselines.py --alg XGBoost
# python exps/trading/baselines.py --alg LightGBM
#####################################################
@@ -17,6 +19,10 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

from procedures.q_exps import update_gpu
from procedures.q_exps import update_market
from procedures.q_exps import run_exp

import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R
@@ -31,15 +37,19 @@ def retrieve_configs():
    alg2names = OrderedDict()
    alg2names["GRU"] = "workflow_config_gru_Alpha360.yaml"
    alg2names["LSTM"] = "workflow_config_lstm_Alpha360.yaml"
    alg2names["MLP"] = "workflow_config_mlp_Alpha360.yaml"
    # A dual-stage attention-based recurrent neural network for time series prediction, IJCAI-2017
    alg2names["ALSTM"] = "workflow_config_alstm_Alpha360.yaml"
    # XGBoost: A Scalable Tree Boosting System, KDD-2016
    alg2names["XGBoost"] = "workflow_config_xgboost_Alpha360.yaml"
    # LightGBM: A Highly Efficient Gradient Boosting Decision Tree, NeurIPS-2017
    alg2names["LightGBM"] = "workflow_config_lightgbm_Alpha360.yaml"
    # State Frequency Memory (SFM): Stock Price Prediction via Discovering Multi-Frequency Trading Patterns, KDD-2017
    alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml"

    # find the yaml paths
    alg2paths = OrderedDict()
    print("Start retrieving the algorithm configurations")
    for idx, (alg, name) in enumerate(alg2names.items()):
        path = config_dir / name
        assert path.exists(), "{:} does not exist.".format(path)
@@ -48,56 +58,6 @@ def retrieve_configs():
    return alg2paths
(The helper functions below are removed from this file by the commit; they now live in procedures/q_exps.py and are imported above.)

def update_gpu(config, gpu):
    config = config.copy()
    if "GPU" in config["task"]["model"]:
        config["task"]["model"]["GPU"] = gpu
    return config


def update_market(config, market):
    config = config.copy()
    config["market"] = market
    config["data_handler_config"]["instruments"] = market
    return config


def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
    # model initiation
    print("")
    print("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri))
    print("dataset={:}".format(dataset))
    model = init_instance_by_config(task_config["model"])

    # start exp
    with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri):
        log_file = R.get_recorder().root_uri / "{:}.log".format(experiment_name)
        set_log_basic_config(log_file)

        # train model
        R.log_params(**flatten_dict(task_config))
        model.fit(dataset)
        recorder = R.get_recorder()
        R.save_objects(**{"model.pkl": model})

        # generate records: prediction, backtest, and analysis
        for record in task_config["record"]:
            record = record.copy()
            if record["class"] == "SignalRecord":
                srconf = {"model": model, "dataset": dataset, "recorder": recorder}
                record["kwargs"].update(srconf)
                sr = init_instance_by_config(record)
                sr.generate()
            else:
                rconf = {"recorder": recorder}
                record["kwargs"].update(rconf)
                ar = init_instance_by_config(record)
                ar.generate()


def main(xargs, exp_yaml):
    assert Path(exp_yaml).exists(), "{:} does not exist.".format(exp_yaml)
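For orientation, a small hypothetical use of the mapping that retrieve_configs returns (the names mirror the comments at the top of the file; not part of this commit):

# Hypothetical usage of the algorithm-to-config mapping built above.
alg2paths = retrieve_configs()
for alg in ("GRU", "LSTM", "ALSTM", "MLP", "SFM", "XGBoost", "LightGBM"):
    print(alg, "->", alg2paths[alg])  # each value is the resolved yaml path under config_dir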

View File

@@ -1,25 +1,36 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .starts import prepare_seed
from .starts import prepare_logger
from .starts import get_machine_info
from .starts import save_checkpoint
from .starts import copy_checkpoint
from .optimizers import get_optim_scheduler
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
from .funcs_nasbench import get_nas_bench_loaders


def get_procedures(procedure):
    from .basic_main import basic_train, basic_valid
    from .search_main import search_train, search_valid
    from .search_main_v2 import search_train_v2
    from .simple_KD_main import simple_KD_train, simple_KD_valid

    train_funcs = {
        "basic": basic_train,
        "search": search_train,
        "Simple-KD": simple_KD_train,
        "search-v2": search_train_v2,
    }
    valid_funcs = {
        "basic": basic_valid,
        "search": search_valid,
        "Simple-KD": simple_KD_valid,
        "search-v2": search_valid,
    }

    train_func = train_funcs[procedure]
    valid_func = valid_funcs[procedure]
    return train_func, valid_func
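A minimal, hypothetical usage sketch of get_procedures (the keys mirror the dictionaries above; not part of this commit):

# Hypothetical usage: pick the train/valid procedures by name.
train_func, valid_func = get_procedures("basic")
# train_func is basic_train and valid_func is basic_valid; "search", "Simple-KD",
# and "search-v2" select the corresponding procedures from the tables above.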

View File

@@ -3,73 +3,100 @@
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy


def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
    loss, acc1, acc5 = procedure(
        xloader, network, criterion, scheduler, optimizer, "train", optim_config, extra_info, print_freq, logger
    )
    return loss, acc1, acc5


def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
    with torch.no_grad():
        loss, acc1, acc5 = procedure(
            xloader, network, criterion, None, None, "valid", None, extra_info, print_freq, logger
        )
    return loss, acc1, acc5


def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))

    # logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
    logger.log(
        "[{:5s}] config :: auxiliary={:}".format(mode, config.auxiliary if hasattr(config, "auxiliary") else -1)
    )
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        targets = targets.cuda(non_blocking=True)
        if mode == "train":
            optimizer.zero_grad()

        features, logits = network(inputs)
        if isinstance(logits, list):
            assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
            logits, logits_aux = logits
        else:
            logits, logits_aux = logits, None
        loss = criterion(logits, targets)
        if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
            loss_aux = criterion(logits_aux, targets)
            loss += config.auxiliary * loss_aux

        if mode == "train":
            loss.backward()
            optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 or (i + 1) == len(xloader):
            Sstr = (
                " {:5s} ".format(mode.upper())
                + time_string()
                + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
            )
            if scheduler is not None:
                Sstr += " {:}".format(scheduler.get_min_info())
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=losses, top1=top1, top5=top5
            )
            Istr = "Size={:}".format(list(inputs.size()))
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)

    logger.log(
        " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
            mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
        )
    )
    return losses.avg, top1.avg, top5.avg
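For context, a hypothetical sketch of how basic_train and basic_valid are typically driven per epoch; the loaders, network, criterion, scheduler, optimizer, optim_config, and logger are placeholders that are not defined in this diff:

# Hypothetical per-epoch driver built on the two procedures above.
for epoch in range(optim_config.epochs):
    scheduler.update(epoch, 0.0)
    train_loss, train_acc1, train_acc5 = basic_train(
        train_loader, network, criterion, scheduler, optimizer, optim_config, "epoch={:03d}".format(epoch), 100, logger
    )
    valid_loss, valid_acc1, valid_acc5 = basic_valid(
        valid_loader, network, criterion, optim_config, "epoch={:03d}".format(epoch), 100, logger
    )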

View File

@@ -5,199 +5,348 @@ import os, time, copy, torch, pathlib
import datasets
from config_utils import load_config
from procedures import prepare_seed, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net


__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]


def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies, device = [], torch.cuda.current_device()
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(device=device, non_blocking=True)
            inputs = inputs.cuda(device=device, non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2:
        latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies


def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    device = torch.cuda.current_device()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))

        targets = targets.cuda(device=device, non_blocking=True)
        if mode == "train":
            optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == "train":
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum


def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger):
    prepare_seed(seed)  # random seed
    net = get_cell_based_tiny_net(arch_config)
    # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
    flop, param = get_model_infos(net, opt_config.xshape)
    logger.log("Network : {:}".format(net.get_message()), False)
    logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed))
    logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
    default_device = torch.cuda.current_device()
    network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device)
    criterion = criterion.cuda(device=default_device)
    # start training
    start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup
    train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
    train_times, valid_times, lrs = {}, {}, {}
    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)
        lr = min(scheduler.get_lr())
        train_loss, train_acc1, train_acc5, train_tm = procedure(
            train_loader, network, criterion, scheduler, optimizer, "train"
        )
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm
        lrs[epoch] = lr
        with torch.no_grad():
            for key, xloder in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
                    xloder, network, criterion, None, None, "valid"
                )
                valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
                valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
                valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
                valid_times["{:}@{:}".format(key, epoch)] = valid_tm

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True))
        logger.log(
            "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format(
                time_string(),
                need_time,
                epoch,
                total_epoch,
                train_loss,
                train_acc1,
                train_acc5,
                valid_loss,
                valid_acc1,
                valid_acc5,
                lr,
            )
        )
    info_seed = {
        "flop": flop,
        "param": param,
        "arch_config": arch_config._asdict(),
        "opt_config": opt_config._asdict(),
        "total_epoch": total_epoch,
        "train_losses": train_losses,
        "train_acc1es": train_acc1es,
        "train_acc5es": train_acc5es,
        "train_times": train_times,
        "valid_losses": valid_losses,
        "valid_acc1es": valid_acc1es,
        "valid_acc5es": valid_acc5es,
        "valid_times": valid_times,
        "learning_rates": lrs,
        "net_state_dict": net.state_dict(),
        "net_string": "{:}".format(net),
        "finish-train": True,
    }
    return info_seed


def get_nas_bench_loaders(workers):
    torch.set_num_threads(workers)

    root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve()
    torch_dir = pathlib.Path(os.environ["TORCH_HOME"])
    # cifar
    cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config"
    cifar_config = load_config(cifar_config_path, None, None)
    get_datasets = datasets.get_datasets  # a function to return the dataset
    break_line = "-" * 150
    print("{:} Create data-loader for all datasets".format(time_string()))
    print(break_line)
    TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1)
    print(
        "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
            len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num
        )
    )
    cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None)
    assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [
        1,
        2,
        3,
        4,
        6,
        8,
        9,
        10,
        12,
        14,
    ]
    temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
    temp_dataset.transform = VALID_CIFAR10.transform
    # data loader
    trainval_cifar10_loader = torch.utils.data.DataLoader(
        TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
    )
    train_cifar10_loader = torch.utils.data.DataLoader(
        TRAIN_CIFAR10,
        batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train),
        num_workers=workers,
        pin_memory=True,
    )
    valid_cifar10_loader = torch.utils.data.DataLoader(
        temp_dataset,
        batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid),
        num_workers=workers,
        pin_memory=True,
    )
    test__cifar10_loader = torch.utils.data.DataLoader(
        VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
    )
    print(
        "CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format(
            len(trainval_cifar10_loader), cifar_config.batch_size
        )
    )
    print(
        "CIFAR-10 : train-loader has {:3d} batch with {:} per batch".format(
            len(train_cifar10_loader), cifar_config.batch_size
        )
    )
    print(
        "CIFAR-10 : valid-loader has {:3d} batch with {:} per batch".format(
            len(valid_cifar10_loader), cifar_config.batch_size
        )
    )
    print(
        "CIFAR-10 : test--loader has {:3d} batch with {:} per batch".format(
            len(test__cifar10_loader), cifar_config.batch_size
        )
    )
    print(break_line)
    # CIFAR-100
    TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1)
    print(
        "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
            len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num
        )
    )
    cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None)
    assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [
        0,
        2,
        6,
        7,
        9,
        11,
        12,
        17,
        20,
        24,
    ]
    train_cifar100_loader = torch.utils.data.DataLoader(
        TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
    )
    valid_cifar100_loader = torch.utils.data.DataLoader(
        VALID_CIFAR100,
        batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
        num_workers=workers,
        pin_memory=True,
    )
    test__cifar100_loader = torch.utils.data.DataLoader(
        VALID_CIFAR100,
        batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
        num_workers=workers,
        pin_memory=True,
    )
    print("CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader)))
    print("CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)))
    print("CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader)))
    print(break_line)

    imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config"
    imagenet16_config = load_config(imagenet16_config_path, None, None)
    TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets(
        "ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1
    )
    print(
        "original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
            len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num
        )
    )
    imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None)
    assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [
        0,
        4,
        5,
        10,
        11,
        13,
        14,
        15,
        17,
        20,
    ]
    train_imagenet_loader = torch.utils.data.DataLoader(
        TRAIN_ImageNet16_120,
        batch_size=imagenet16_config.batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    valid_imagenet_loader = torch.utils.data.DataLoader(
        VALID_ImageNet16_120,
        batch_size=imagenet16_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid),
        num_workers=workers,
        pin_memory=True,
    )
    test__imagenet_loader = torch.utils.data.DataLoader(
        VALID_ImageNet16_120,
        batch_size=imagenet16_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest),
        num_workers=workers,
        pin_memory=True,
    )
    print(
        "ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch".format(
            len(train_imagenet_loader), imagenet16_config.batch_size
        )
    )
    print(
        "ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch".format(
            len(valid_imagenet_loader), imagenet16_config.batch_size
        )
    )
    print(
        "ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch".format(
            len(test__imagenet_loader), imagenet16_config.batch_size
        )
    )
    # 'cifar10', 'cifar100', 'ImageNet16-120'
    loaders = {
        "cifar10@trainval": trainval_cifar10_loader,
        "cifar10@train": train_cifar10_loader,
        "cifar10@valid": valid_cifar10_loader,
        "cifar10@test": test__cifar10_loader,
        "cifar100@train": train_cifar100_loader,
        "cifar100@valid": valid_cifar100_loader,
        "cifar100@test": test__cifar100_loader,
        "ImageNet16-120@train": train_imagenet_loader,
        "ImageNet16-120@valid": valid_imagenet_loader,
        "ImageNet16-120@test": test__imagenet_loader,
    }
    return loaders
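As a rough, hypothetical driver for the two helpers above (arch_config and opt_config are namedtuple-style placeholders and logger is a repo logger; none of them are defined in this diff):

# Hypothetical driver combining get_nas_bench_loaders and evaluate_for_seed.
loaders = get_nas_bench_loaders(workers=4)
valid_loaders = {"cifar10-valid": loaders["cifar10@valid"], "cifar10-test": loaders["cifar10@test"]}
info = evaluate_for_seed(arch_config, opt_config, loaders["cifar10@train"], valid_loaders, seed=777, logger=logger)
print(info["total_epoch"], info["finish-train"])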

View File

@@ -8,197 +8,201 @@ from torch.optim import Optimizer


class _LRScheduler(object):
    def __init__(self, optimizer, warmup_epochs, epochs):
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups))
        self.max_epochs = epochs
        self.warmup_epochs = warmup_epochs
        self.current_epoch = 0
        self.current_iter = 0

    def extra_repr(self):
        return ""

    def __repr__(self):
        return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
            name=self.__class__.__name__, **self.__dict__
        ) + ", {:})".format(
            self.extra_repr()
        )

    def state_dict(self):
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        raise NotImplementedError

    def get_min_info(self):
        lrs = self.get_lr()
        return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
            min(lrs), max(lrs), self.current_epoch, self.current_iter
        )

    def get_min_lr(self):
        return min(self.get_lr())

    def update(self, cur_epoch, cur_iter):
        if cur_epoch is not None:
            assert isinstance(cur_epoch, int) and cur_epoch >= 0, "invalid cur-epoch : {:}".format(cur_epoch)
            self.current_epoch = cur_epoch
        if cur_iter is not None:
            assert isinstance(cur_iter, float) and cur_iter >= 0, "invalid cur-iter : {:}".format(cur_iter)
            self.current_iter = cur_iter
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr


class CosineAnnealingLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, T-max={:}, eta-min={:}".format("cosine", self.T_max, self.eta_min)

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                # if last_epoch < self.T_max:
                # if last_epoch < self.max_epochs:
                lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
                # else:
                #     lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
            elif self.current_epoch >= self.max_epochs:
                lr = self.eta_min
            else:
                lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
            lrs.append(lr)
        return lrs


class MultiStepLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
        assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(len(milestones), len(gammas))
        self.milestones = milestones
        self.gammas = gammas
        super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
            "multistep", self.milestones, self.gammas, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                idx = bisect_right(self.milestones, last_epoch)
                lr = base_lr
                for x in self.gammas[:idx]:
                    lr *= x
            else:
                lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
            lrs.append(lr)
        return lrs


class ExponentialLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, gamma):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, gamma={:}, base-lrs={:}".format("exponential", self.gamma, self.base_lrs)

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                lr = base_lr * (self.gamma ** last_epoch)
            else:
                lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
            lrs.append(lr)
        return lrs


class LinearLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
        self.max_LR = max_LR
        self.min_LR = min_LR
        super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
            "LinearLR", self.max_LR, self.min_LR, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
                lr = base_lr * (1 - ratio)
            else:
                lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
            lrs.append(lr)
        return lrs


class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss


def get_optim_scheduler(parameters, config):
    assert (
        hasattr(config, "optim") and hasattr(config, "scheduler") and hasattr(config, "criterion")
    ), "config must have optim / scheduler / criterion keys instead of {:}".format(config)
    if config.optim == "SGD":
        optim = torch.optim.SGD(
            parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov
        )
    elif config.optim == "RMSprop":
        optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
    else:
        raise ValueError("invalid optim : {:}".format(config.optim))

    if config.scheduler == "cos":
        T_max = getattr(config, "T_max", config.epochs)
        scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
    elif config.scheduler == "multistep":
        scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
    elif config.scheduler == "exponential":
        scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
    elif config.scheduler == "linear":
        scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
    else:
        raise ValueError("invalid scheduler : {:}".format(config.scheduler))

    if config.criterion == "Softmax":
        criterion = torch.nn.CrossEntropyLoss()
    elif config.criterion == "SmoothSoftmax":
        criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
    else:
        raise ValueError("invalid criterion : {:}".format(config.criterion))
    return optim, scheduler, criterion
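A minimal, hypothetical sketch of calling get_optim_scheduler with a namespace-style config; the attribute names mirror those read by the function above, the values are illustrative and not from this commit:

# Hypothetical usage of get_optim_scheduler with a simple attribute-style config.
from types import SimpleNamespace
import torch.nn as nn

config = SimpleNamespace(
    optim="SGD", LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
    scheduler="cos", warmup=5, epochs=100, eta_min=0.0,
    criterion="SmoothSoftmax", class_num=10, label_smooth=0.1,
)
model = nn.Linear(16, 10)
optimizer, scheduler, criterion = get_optim_scheduler(model.parameters(), config)
scheduler.update(0, 0.0)          # during the first `warmup` epochs the LR ramps up linearly from 0
print(scheduler.get_min_info())   # e.g. #LR=[0.000000~0.000000] epoch=000, iter=0.00#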

View File

@@ -7,11 +7,12 @@ from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.utils import flatten_dict
from qlib.log import set_log_basic_config
from qlib.log import get_module_logger


def update_gpu(config, gpu):
    config = config.copy()
    if "task" in config and "model" in config["task"] and "GPU" in config["task"]["model"]:
        config["task"]["model"]["GPU"] = gpu
    elif "model" in config and "GPU" in config["model"]:
        config["model"]["GPU"] = gpu
@@ -29,11 +30,6 @@ def update_market(config, market):
def run_exp(task_config, dataset, experiment_name, recorder_name, uri):

    model = init_instance_by_config(task_config["model"])

    # start exp
@@ -41,6 +37,10 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
        log_file = R.get_recorder().root_uri / "{:}.log".format(experiment_name)
        set_log_basic_config(log_file)
        logger = get_module_logger("q.run_exp")
        logger.info("task_config={:}".format(task_config))
        logger.info("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri))
        logger.info("dataset={:}".format(dataset))

        # train model
        R.log_params(**flatten_dict(task_config))

View File

@ -3,124 +3,170 @@
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
from models import change_key


def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
    expected_flop = torch.mean(expected_flop)
    if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP
        loss = -torch.log(expected_flop)
    # elif flop_cur > flop_need + flop_tolerant:  # Too Large FLOP
    elif flop_cur > flop_need:  # Too Large FLOP
        loss = torch.log(expected_flop)
    else:  # Required FLOP
        loss = None
    if loss is None:
        return 0, 0
    else:
        return loss, loss.item()
def search_train(
    search_loader,
    network,
    criterion,
    scheduler,
    base_optimizer,
    arch_optimizer,
    optim_config,
    extra_info,
    print_freq,
    logger,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
    epoch_str, flop_need, flop_weight, flop_tolerant = (
        extra_info["epoch-str"],
        extra_info["FLOP-exp"],
        extra_info["FLOP-weight"],
        extra_info["FLOP-tolerant"],
    )

    network.train()
    logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
    end = time.time()
    network.apply(change_key("search_mode", "search"))
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
        scheduler.update(None, 1.0 * step / len(search_loader))
        # calculate prediction and loss
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        base_optimizer.zero_grad()
        logits, expected_flop = network(base_inputs)
        # network.apply( change_key('search_mode', 'basic') )
        # features, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        base_optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        top1.update(prec1.item(), base_inputs.size(0))
        top5.update(prec5.item(), base_inputs.size(0))

        # update the architecture
        arch_optimizer.zero_grad()
        logits, expected_flop = network(arch_inputs)
        flop_cur = network.module.get_flop("genotype", None, None)
        flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
        acls_loss = criterion(logits, arch_targets)
        arch_loss = acls_loss + flop_loss * flop_weight
        arch_loss.backward()
        arch_optimizer.step()
        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
        arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or (step + 1) == len(search_loader):
            Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=base_losses, top1=top1, top5=top5
            )
            Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
                aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
            )
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
            # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
            # print(network.module.get_arch_info())
            # print(network.module.width_attentions[0])
            # print(network.module.width_attentions[1])
    logger.log(
        " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            baseloss=base_losses.avg,
            archloss=arch_losses.avg,
        )
    )
    return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )

    network.eval()
    network.apply(change_key("search_mode", "search"))
    end = time.time()
    # logger.log('Starting evaluating {:}'.format(epoch_info))
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(xloader):
            # measure data loading time
            data_time.update(time.time() - end)
            # calculate prediction and loss
            targets = targets.cuda(non_blocking=True)

            logits, expected_flop = network(inputs)
            loss = criterion(logits, targets)
            # record
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0 or (i + 1) == len(xloader):
                Sstr = "**VALID** " + time_string() + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
                Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                    batch_time=batch_time, data_time=data_time
                )
                Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                    loss=losses, top1=top1, top5=top5
                )
                Istr = "Size={:}".format(list(inputs.size()))
                logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)

    logger.log(
        " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
            top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
        )
    )

    return losses.avg, top1.avg, top5.avg
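A quick numeric sanity check for get_flop_loss above (numbers invented): with a 300 MB budget and a 10 MB tolerance, an architecture currently at 250 MB is penalized with -log(E[FLOP]), whose gradient pushes the expected FLOP count up; above the budget the sign flips, and inside the [290, 300] window the function returns (0, 0).

import torch

loss, scale = get_flop_loss(torch.tensor([280.0]), flop_cur=250.0, flop_need=300.0, flop_tolerant=10.0)
# loss == -log(280.0) here; scale is the same value as a Python float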

View File

@ -3,85 +3,118 @@
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy
from models import change_key


def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
    expected_flop = torch.mean(expected_flop)
    if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP
        loss = -torch.log(expected_flop)
    # elif flop_cur > flop_need + flop_tolerant:  # Too Large FLOP
    elif flop_cur > flop_need:  # Too Large FLOP
        loss = torch.log(expected_flop)
    else:  # Required FLOP
        loss = None
    if loss is None:
        return 0, 0
    else:
        return loss, loss.item()
def search_train_v2(
    search_loader,
    network,
    criterion,
    scheduler,
    base_optimizer,
    arch_optimizer,
    optim_config,
    extra_info,
    print_freq,
    logger,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
    epoch_str, flop_need, flop_weight, flop_tolerant = (
        extra_info["epoch-str"],
        extra_info["FLOP-exp"],
        extra_info["FLOP-weight"],
        extra_info["FLOP-tolerant"],
    )

    network.train()
    logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
    end = time.time()
    network.apply(change_key("search_mode", "search"))
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
        scheduler.update(None, 1.0 * step / len(search_loader))
        # calculate prediction and loss
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        base_optimizer.zero_grad()
        logits, expected_flop = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        base_optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        top1.update(prec1.item(), base_inputs.size(0))
        top5.update(prec5.item(), base_inputs.size(0))

        # update the architecture
        arch_optimizer.zero_grad()
        logits, expected_flop = network(arch_inputs)
        flop_cur = network.module.get_flop("genotype", None, None)
        flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
        acls_loss = criterion(logits, arch_targets)
        arch_loss = acls_loss + flop_loss * flop_weight
        arch_loss.backward()
        arch_optimizer.step()
        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
        arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or (step + 1) == len(search_loader):
            Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=base_losses, top1=top1, top5=top5
            )
            Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
                aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
            )
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
            # num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
            # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
            # print(network.module.get_arch_info())
            # print(network.module.width_attentions[0])
            # print(network.module.width_attentions[1])
    logger.log(
        " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            baseloss=base_losses.avg,
            archloss=arch_losses.avg,
        )
    )
    return base_losses.avg, arch_losses.avg, top1.avg, top5.avg

View File

@ -3,92 +3,143 @@
#####################################################
import os, sys, time, torch
import torch.nn.functional as F

# our modules
from log_utils import AverageMeter, time_string
from utils import obtain_accuracy


def simple_KD_train(
    xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger
):
    loss, acc1, acc5 = procedure(
        xloader,
        teacher,
        network,
        criterion,
        scheduler,
        optimizer,
        "train",
        optim_config,
        extra_info,
        print_freq,
        logger,
    )
    return loss, acc1, acc5


def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
    with torch.no_grad():
        loss, acc1, acc5 = procedure(
            xloader, teacher, network, criterion, None, None, "valid", optim_config, extra_info, print_freq, logger
        )
    return loss, acc1, acc5


def loss_KD_fn(
    criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature
):
    basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
    log_student = F.log_softmax(student_logits / temperature, dim=1)
    sof_teacher = F.softmax(teacher_logits / temperature, dim=1)
    KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (alpha * temperature * temperature)
    return basic_loss + KD_loss


def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )
    Ttop1, Ttop5 = AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    teacher.eval()

    logger.log(
        "[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format(
            mode, config.auxiliary if hasattr(config, "auxiliary") else -1, config.KD_alpha, config.KD_temperature
        )
    )
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        targets = targets.cuda(non_blocking=True)

        if mode == "train":
            optimizer.zero_grad()

        student_f, logits = network(inputs)
        if isinstance(logits, list):
            assert len(logits) == 2, "logits must have {:} items instead of {:}".format(2, len(logits))
            logits, logits_aux = logits
        else:
            logits, logits_aux = logits, None
        with torch.no_grad():
            teacher_f, teacher_logits = teacher(inputs)

        loss = loss_KD_fn(
            criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature
        )
        if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
            loss_aux = criterion(logits_aux, targets)
            loss += config.auxiliary * loss_aux

        if mode == "train":
            loss.backward()
            optimizer.step()

        # record
        sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(sprec1.item(), inputs.size(0))
        top5.update(sprec5.item(), inputs.size(0))
        # teacher
        tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
        Ttop1.update(tprec1.item(), inputs.size(0))
        Ttop5.update(tprec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 or (i + 1) == len(xloader):
            Sstr = (
                " {:5s} ".format(mode.upper())
                + time_string()
                + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
            )
            if scheduler is not None:
                Sstr += " {:}".format(scheduler.get_min_info())
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=losses, top1=top1, top5=top5
            )
            Lstr += " Teacher : acc@1={:.2f}, acc@5={:.2f}".format(Ttop1.avg, Ttop5.avg)
            Istr = "Size={:}".format(list(inputs.size()))
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)

    logger.log(
        " **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}".format(
            mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg
        )
    )
    logger.log(
        " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
            mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
        )
    )
    return losses.avg, top1.avg, top5.avg
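As a rough illustration of loss_KD_fn (shapes and values are made up): the hard-label cross-entropy is weighted by 1 - alpha, while the KL term between the temperature-softened teacher and student distributions is rescaled by alpha * T^2 so its gradients keep a comparable magnitude.

import torch

criterion = torch.nn.CrossEntropyLoss()
student_logits, teacher_logits = torch.randn(8, 10), torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
# the two feature arguments are accepted but unused by loss_KD_fn
kd_loss = loss_KD_fn(criterion, student_logits, teacher_logits, None, None, targets, alpha=0.9, temperature=4.0)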

View File

@ -3,62 +3,71 @@
##################################################
import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile


def prepare_seed(rand_seed):
    random.seed(rand_seed)
    np.random.seed(rand_seed)
    torch.manual_seed(rand_seed)
    torch.cuda.manual_seed(rand_seed)
    torch.cuda.manual_seed_all(rand_seed)


def prepare_logger(xargs):
    args = copy.deepcopy(xargs)
    from log_utils import Logger

    logger = Logger(args.save_dir, args.rand_seed)
    logger.log("Main Function with logger : {:}".format(logger))
    logger.log("Arguments : -------------------------------")
    for name, value in args._get_kwargs():
        logger.log("{:16} : {:}".format(name, value))
    logger.log("Python Version : {:}".format(sys.version.replace("\n", " ")))
    logger.log("Pillow Version : {:}".format(PIL.__version__))
    logger.log("PyTorch Version : {:}".format(torch.__version__))
    logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
    logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
    logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
    logger.log(
        "CUDA_VISIBLE_DEVICES : {:}".format(
            os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "None"
        )
    )
    return logger


def get_machine_info():
    info = "Python Version : {:}".format(sys.version.replace("\n", " "))
    info += "\nPillow Version : {:}".format(PIL.__version__)
    info += "\nPyTorch Version : {:}".format(torch.__version__)
    info += "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
    info += "\nCUDA available : {:}".format(torch.cuda.is_available())
    info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ["CUDA_VISIBLE_DEVICES"])
    else:
        info += "\nDoes not set CUDA_VISIBLE_DEVICES"
    return info


def save_checkpoint(state, filename, logger):
    if osp.isfile(filename):
        if hasattr(logger, "log"):
            logger.log("Find {:} exist, delete it first before saving".format(filename))
        os.remove(filename)
    torch.save(state, filename)
    assert osp.isfile(filename), "save filename : {:} failed, which is not found.".format(filename)
    if hasattr(logger, "log"):
        logger.log("save checkpoint into {:}".format(filename))
    return filename


def copy_checkpoint(src, dst, logger):
    if osp.isfile(dst):
        if hasattr(logger, "log"):
            logger.log("Find {:} exist, delete it first before saving".format(dst))
        os.remove(dst)
    copyfile(src, dst)
    if hasattr(logger, "log"):
        logger.log("copy the file from {:} into {:}".format(src, dst))

View File

@ -1,7 +1,7 @@
from .evaluation_utils import obtain_accuracy
from .gpu_manager import GPUManager
from .flop_benchmark import get_model_infos, count_parameters, count_parameters_in_MB
from .affine_utils import normalize_points, denormalize_points
from .affine_utils import identity2affine, solve2theta, affine2image
from .hash_utils import get_md5_file
from .str_utils import split_str2indexes

View File

@ -1,125 +1,149 @@
# functions for affine transformation
import math
import torch
import numpy as np
import torch.nn.functional as F


def identity2affine(full=False):
    if not full:
        parameters = torch.zeros((2, 3))
        parameters[0, 0] = parameters[1, 1] = 1
    else:
        parameters = torch.zeros((3, 3))
        parameters[0, 0] = parameters[1, 1] = parameters[2, 2] = 1
    return parameters


def normalize_L(x, L):
    return -1.0 + 2.0 * x / (L - 1)


def denormalize_L(x, L):
    return (x + 1.0) / 2.0 * (L - 1)


def crop2affine(crop_box, W, H):
    assert len(crop_box) == 4, "Invalid crop-box : {:}".format(crop_box)
    parameters = torch.zeros(3, 3)
    x1, y1 = normalize_L(crop_box[0], W), normalize_L(crop_box[1], H)
    x2, y2 = normalize_L(crop_box[2], W), normalize_L(crop_box[3], H)
    parameters[0, 0] = (x2 - x1) / 2
    parameters[0, 2] = (x2 + x1) / 2
    parameters[1, 1] = (y2 - y1) / 2
    parameters[1, 2] = (y2 + y1) / 2
    parameters[2, 2] = 1
    return parameters


def scale2affine(scalex, scaley):
    parameters = torch.zeros(3, 3)
    parameters[0, 0] = scalex
    parameters[1, 1] = scaley
    parameters[2, 2] = 1
    return parameters


def offset2affine(offx, offy):
    parameters = torch.zeros(3, 3)
    parameters[0, 0] = parameters[1, 1] = parameters[2, 2] = 1
    parameters[0, 2] = offx
    parameters[1, 2] = offy
    return parameters


def horizontalmirror2affine():
    parameters = torch.zeros(3, 3)
    parameters[0, 0] = -1
    parameters[1, 1] = parameters[2, 2] = 1
    return parameters


# clockwise rotate image = counterclockwise rotate the rectangle
# degree is between [0, 360]
def rotate2affine(degree):
    assert degree >= 0 and degree <= 360, "Invalid degree : {:}".format(degree)
    degree = degree / 180 * math.pi
    parameters = torch.zeros(3, 3)
    parameters[0, 0] = math.cos(-degree)
    parameters[0, 1] = -math.sin(-degree)
    parameters[1, 0] = math.sin(-degree)
    parameters[1, 1] = math.cos(-degree)
    parameters[2, 2] = 1
    return parameters


# shape is a tuple [H, W]
def normalize_points(shape, points):
    assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, "invalid shape : {:}".format(
        shape
    )
    assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), "points are wrong : {:}".format(points.shape)
    (H, W), points = shape, points.clone()
    points[0, :] = normalize_L(points[0, :], W)
    points[1, :] = normalize_L(points[1, :], H)
    return points


# shape is a tuple [H, W]
def normalize_points_batch(shape, points):
    assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, "invalid shape : {:}".format(
        shape
    )
    assert isinstance(points, torch.Tensor) and (points.size(-1) == 2), "points are wrong : {:}".format(points.shape)
    (H, W), points = shape, points.clone()
    x = normalize_L(points[..., 0], W)
    y = normalize_L(points[..., 1], H)
    return torch.stack((x, y), dim=-1)


# shape is a tuple [H, W]
def denormalize_points(shape, points):
    assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, "invalid shape : {:}".format(
        shape
    )
    assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), "points are wrong : {:}".format(points.shape)
    (H, W), points = shape, points.clone()
    points[0, :] = denormalize_L(points[0, :], W)
    points[1, :] = denormalize_L(points[1, :], H)
    return points


# shape is a tuple [H, W]
def denormalize_points_batch(shape, points):
    assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, "invalid shape : {:}".format(
        shape
    )
    assert isinstance(points, torch.Tensor) and (points.shape[-1] == 2), "points are wrong : {:}".format(points.shape)
    (H, W), points = shape, points.clone()
    x = denormalize_L(points[..., 0], W)
    y = denormalize_L(points[..., 1], H)
    return torch.stack((x, y), dim=-1)


# make target * theta = source
def solve2theta(source, target):
    source, target = source.clone(), target.clone()
    oks = source[2, :] == 1
    assert torch.sum(oks).item() >= 3, "valid points : {:} is short".format(oks)
    if target.size(0) == 2:
        target = torch.cat((target, oks.unsqueeze(0).float()), dim=0)
    source, target = source[:, oks], target[:, oks]
    source, target = source.transpose(1, 0), target.transpose(1, 0)
    assert source.size(1) == target.size(1) == 3
    # X, residual, rank, s = np.linalg.lstsq(target.numpy(), source.numpy())
    # theta = torch.Tensor(X.T[:2, :])
    X_, qr = torch.gels(source, target)
    theta = X_[:3, :2].transpose(1, 0)
    return theta


# shape = [H,W]
def affine2image(image, theta, shape):
    C, H, W = image.size()
    theta = theta[:2, :].unsqueeze(0)
    grid_size = torch.Size([1, C, shape[0], shape[1]])
    grid = F.affine_grid(theta, grid_size)
    affI = F.grid_sample(image.unsqueeze(0), grid, mode="bilinear", padding_mode="border")
    return affI.squeeze(0)
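A small usage sketch for the affine helpers (sizes are illustrative): build a 30-degree rotation in the normalized [-1, 1] coordinate frame and resample an image tensor with it.

import torch

image = torch.rand(3, 64, 64)  # C x H x W
theta = rotate2affine(30)  # 3x3 transform in normalized coordinates
rotated = affine2image(image, theta, (64, 64))  # resampled back to 3 x 64 x 64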

View File

@ -1,16 +1,17 @@
import torch


def obtain_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
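For example, on random logits (batch of 32, 100 classes) obtain_accuracy returns the top-1 and top-5 precision as percentages, each wrapped in a one-element tensor:

import torch

logits = torch.randn(32, 100)
labels = torch.randint(0, 100, (32,))
top1, top5 = obtain_accuracy(logits, labels, topk=(1, 5))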

View File

@ -4,191 +4,199 @@ import numpy as np
def count_parameters_in_MB(model): def count_parameters_in_MB(model):
return count_parameters(model, "mb") return count_parameters(model, "mb")
def count_parameters(model_or_parameters, unit="mb"): def count_parameters(model_or_parameters, unit="mb"):
if isinstance(model_or_parameters, nn.Module): if isinstance(model_or_parameters, nn.Module):
counts = np.sum(np.prod(v.size()) for v in model_or_parameters.parameters()) counts = np.sum(np.prod(v.size()) for v in model_or_parameters.parameters())
else: else:
counts = np.sum(np.prod(v.size()) for v in model_or_parameters) counts = np.sum(np.prod(v.size()) for v in model_or_parameters)
if unit.lower() == "mb": if unit.lower() == "mb":
counts /= 1e6 counts /= 1e6
elif unit.lower() == "kb": elif unit.lower() == "kb":
counts /= 1e3 counts /= 1e3
elif unit.lower() == "gb": elif unit.lower() == "gb":
counts /= 1e9 counts /= 1e9
elif unit is not None: elif unit is not None:
raise ValueError("Unknow unit: {:}".format(unit)) raise ValueError("Unknow unit: {:}".format(unit))
return counts return counts
def get_model_infos(model, shape): def get_model_infos(model, shape):
#model = copy.deepcopy( model ) # model = copy.deepcopy( model )
model = add_flops_counting_methods(model) model = add_flops_counting_methods(model)
#model = model.cuda() # model = model.cuda()
model.eval() model.eval()
#cache_inputs = torch.zeros(*shape).cuda() # cache_inputs = torch.zeros(*shape).cuda()
#cache_inputs = torch.zeros(*shape) # cache_inputs = torch.zeros(*shape)
cache_inputs = torch.rand(*shape) cache_inputs = torch.rand(*shape)
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda() if next(model.parameters()).is_cuda:
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log) cache_inputs = cache_inputs.cuda()
with torch.no_grad(): # print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
_____ = model(cache_inputs) with torch.no_grad():
FLOPs = compute_average_flops_cost( model ) / 1e6 _____ = model(cache_inputs)
Param = count_parameters_in_MB(model) FLOPs = compute_average_flops_cost(model) / 1e6
Param = count_parameters_in_MB(model)
if hasattr(model, 'auxiliary_param'): if hasattr(model, "auxiliary_param"):
aux_params = count_parameters_in_MB(model.auxiliary_param()) aux_params = count_parameters_in_MB(model.auxiliary_param())
print ('The auxiliary params of this model is : {:}'.format(aux_params)) print("The auxiliary params of this model is : {:}".format(aux_params))
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param)) print("We remove the auxiliary params from the total params ({:}) when counting".format(Param))
Param = Param - aux_params Param = Param - aux_params
#print_log('FLOPs : {:} MB'.format(FLOPs), log) # print_log('FLOPs : {:} MB'.format(FLOPs), log)
torch.cuda.empty_cache() torch.cuda.empty_cache()
model.apply( remove_hook_function ) model.apply(remove_hook_function)
return FLOPs, Param return FLOPs, Param
# ---- Public functions # ---- Public functions
def add_flops_counting_methods( model ): def add_flops_counting_methods(model):
model.__batch_counter__ = 0 model.__batch_counter__ = 0
add_batch_counter_hook_function( model ) add_batch_counter_hook_function(model)
model.apply( add_flops_counter_variable_or_reset ) model.apply(add_flops_counter_variable_or_reset)
model.apply( add_flops_counter_hook_function ) model.apply(add_flops_counter_hook_function)
return model return model
def compute_average_flops_cost(model): def compute_average_flops_cost(model):
""" """
A method that will be available after add_flops_counting_methods() is called on a desired net object. A method that will be available after add_flops_counting_methods() is called on a desired net object.
Returns current mean flops consumption per image. Returns current mean flops consumption per image.
""" """
batches_count = model.__batch_counter__ batches_count = model.__batch_counter__
flops_sum = 0 flops_sum = 0
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ # or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
for module in model.modules(): for module in model.modules():
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ if (
or isinstance(module, torch.nn.Conv1d) \ isinstance(module, torch.nn.Conv2d)
or hasattr(module, 'calculate_flop_self'): or isinstance(module, torch.nn.Linear)
flops_sum += module.__flops__ or isinstance(module, torch.nn.Conv1d)
return flops_sum / batches_count or hasattr(module, "calculate_flop_self")
):
flops_sum += module.__flops__
return flops_sum / batches_count
# ---- Internal functions # ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output): def pool_flops_counter_hook(pool_module, inputs, output):
batch_size = inputs[0].size(0) batch_size = inputs[0].size(0)
kernel_size = pool_module.kernel_size kernel_size = pool_module.kernel_size
out_C, output_height, output_width = output.shape[1:] out_C, output_height, output_width = output.shape[1:]
assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size()) assert out_C == inputs[0].size(1), "{:} vs. {:}".format(out_C, inputs[0].size())
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
pool_module.__flops__ += overall_flops pool_module.__flops__ += overall_flops
def self_calculate_flops_counter_hook(self_module, inputs, output): def self_calculate_flops_counter_hook(self_module, inputs, output):
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape) overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
self_module.__flops__ += overall_flops self_module.__flops__ += overall_flops
def fc_flops_counter_hook(fc_module, inputs, output): def fc_flops_counter_hook(fc_module, inputs, output):
batch_size = inputs[0].size(0) batch_size = inputs[0].size(0)
xin, xout = fc_module.in_features, fc_module.out_features xin, xout = fc_module.in_features, fc_module.out_features
assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout) assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(xin, xout)
overall_flops = batch_size * xin * xout overall_flops = batch_size * xin * xout
if fc_module.bias is not None: if fc_module.bias is not None:
overall_flops += batch_size * xout overall_flops += batch_size * xout
fc_module.__flops__ += overall_flops fc_module.__flops__ += overall_flops
def conv1d_flops_counter_hook(conv_module, inputs, outputs): def conv1d_flops_counter_hook(conv_module, inputs, outputs):
batch_size = inputs[0].size(0) batch_size = inputs[0].size(0)
outL = outputs.shape[-1] outL = outputs.shape[-1]
[kernel] = conv_module.kernel_size [kernel] = conv_module.kernel_size
in_channels = conv_module.in_channels in_channels = conv_module.in_channels
out_channels = conv_module.out_channels out_channels = conv_module.out_channels
groups = conv_module.groups groups = conv_module.groups
conv_per_position_flops = kernel * in_channels * out_channels / groups conv_per_position_flops = kernel * in_channels * out_channels / groups
active_elements_count = batch_size * outL
overall_flops = conv_per_position_flops * active_elements_count
    active_elements_count = batch_size * outL
    overall_flops = conv_per_position_flops * active_elements_count
    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops


def conv2d_flops_counter_hook(conv_module, inputs, output):
    batch_size = inputs[0].size(0)
    output_height, output_width = output.shape[2:]
    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
    active_elements_count = batch_size * output_height * output_width
    overall_flops = conv_per_position_flops * active_elements_count
    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops


def batch_counter_hook(module, inputs, output):
    # Can have multiple inputs, getting the first one
    inputs = inputs[0]
    batch_size = inputs.shape[0]
    module.__batch_counter__ += batch_size


def add_batch_counter_hook_function(module):
    if not hasattr(module, "__batch_counter_handle__"):
        handle = module.register_forward_hook(batch_counter_hook)
        module.__batch_counter_handle__ = handle


def add_flops_counter_variable_or_reset(module):
    if (
        isinstance(module, torch.nn.Conv2d)
        or isinstance(module, torch.nn.Linear)
        or isinstance(module, torch.nn.Conv1d)
        or isinstance(module, torch.nn.AvgPool2d)
        or isinstance(module, torch.nn.MaxPool2d)
        or hasattr(module, "calculate_flop_self")
    ):
        module.__flops__ = 0


def add_flops_counter_hook_function(module):
    if isinstance(module, torch.nn.Conv2d):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(conv2d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Conv1d):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(conv1d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Linear):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(fc_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(pool_flops_counter_hook)
            module.__flops_handle__ = handle
    elif hasattr(module, "calculate_flop_self"):  # self-defined module
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(self_calculate_flops_counter_hook)
            module.__flops_handle__ = handle


def remove_hook_function(module):
    hookers = ["__batch_counter_handle__", "__flops_handle__"]
    for hooker in hookers:
        if hasattr(module, hooker):
            handle = getattr(module, hooker)
            handle.remove()
    keys = ["__flops__", "__batch_counter__", "__flops__"] + hookers
    for ckey in keys:
        if hasattr(module, ckey):
            delattr(module, ckey)
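Taken together, these hooks implement the usual convolution FLOP bookkeeping: kernel_height * kernel_width * in_channels * out_channels / groups multiply-adds per output position, scaled by batch_size times the number of output positions, plus one extra add per output element when a bias is present. Below is a minimal driver sketch for wiring the hooks onto a model; it assumes the helper functions from this file (including the fc/pool hooks defined earlier in the file) are in scope, and the names prepare_flops_counter / average_flops_per_sample are illustrative, not part of this commit.

import torch
import torch.nn as nn

def prepare_flops_counter(model):
    # Illustrative driver (not part of this commit): count samples once on the root module,
    # reset every per-module __flops__ counter, then register the per-layer hooks.
    model.__batch_counter__ = 0
    add_batch_counter_hook_function(model)
    model.apply(add_flops_counter_variable_or_reset)
    model.apply(add_flops_counter_hook_function)
    return model

def average_flops_per_sample(model):
    # Sum the per-module counters and normalize by the number of forwarded samples.
    total = sum(getattr(m, "__flops__", 0) for m in model.modules())
    return total / max(model.__batch_counter__, 1)

net = prepare_flops_counter(nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8 * 30 * 30, 10)))
with torch.no_grad():
    net(torch.randn(2, 3, 32, 32))  # one dummy forward pass populates the counters
print("FLOPs per sample:", average_flops_per_sample(net))
net.apply(remove_hook_function)  # detach the hooks and delete the counter attributes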
View File
@ -1,65 +1,69 @@
import os


class GPUManager:
    queries = ("index", "gpu_name", "memory.free", "memory.used", "memory.total", "power.draw", "power.limit")

    def __init__(self):
        all_gpus = self.query_gpu(False)

    def get_info(self, ctype):
        cmd = "nvidia-smi --query-gpu={} --format=csv,noheader".format(ctype)
        lines = os.popen(cmd).readlines()
        lines = [line.strip("\n") for line in lines]
        return lines

    def query_gpu(self, show=True):
        num_gpus = len(self.get_info("index"))
        all_gpus = [{} for i in range(num_gpus)]
        for query in self.queries:
            infos = self.get_info(query)
            for idx, info in enumerate(infos):
                all_gpus[idx][query] = info
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            CUDA_VISIBLE_DEVICES = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
            selected_gpus = []
            for idx, CUDA_VISIBLE_DEVICE in enumerate(CUDA_VISIBLE_DEVICES):
                find = False
                for gpu in all_gpus:
                    if gpu["index"] == CUDA_VISIBLE_DEVICE:
                        assert not find, "Duplicate cuda device index : {}".format(CUDA_VISIBLE_DEVICE)
                        find = True
                        selected_gpus.append(gpu.copy())
                        selected_gpus[-1]["index"] = "{}".format(idx)
                assert find, "Does not find the device : {}".format(CUDA_VISIBLE_DEVICE)
            all_gpus = selected_gpus
        if show:
            allstrings = ""
            for gpu in all_gpus:
                string = "| "
                for query in self.queries:
                    if query.find("memory") == 0:
                        xinfo = "{:>9}".format(gpu[query])
                    else:
                        xinfo = gpu[query]
                    string = string + query + " : " + xinfo + " | "
                allstrings = allstrings + string + "\n"
            return allstrings
        else:
            return all_gpus

    def select_by_memory(self, numbers=1):
        all_gpus = self.query_gpu(False)
        assert numbers <= len(all_gpus), "Require {} gpus more than you have".format(numbers)
        alls = []
        for idx, gpu in enumerate(all_gpus):
            free_memory = gpu["memory.free"]
            free_memory = free_memory.split(" ")[0]
            free_memory = int(free_memory)
            index = gpu["index"]
            alls.append((free_memory, index))
        alls.sort(reverse=True)
        alls = [int(alls[i][1]) for i in range(numbers)]
        return sorted(alls)


"""
if __name__ == '__main__':
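For reference, a small usage sketch of the class above; it shells out to nvidia-smi, so an NVIDIA driver on PATH is assumed, and the two-GPU selection is purely illustrative.

manager = GPUManager()
print(manager.query_gpu(show=True))           # one formatted "| index : ... |" row per visible GPU

# Pick the two GPUs with the most free memory and expose only those to this process.
chosen = manager.select_by_memory(numbers=2)  # e.g. [0, 3], returned sorted by index
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in chosen)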
View File
@ -1,16 +1,17 @@
import os
import hashlib


def get_md5_file(file_path, post_truncated=5):
    md5_hash = hashlib.md5()
    if os.path.exists(file_path):
        xfile = open(file_path, "rb")
        content = xfile.read()
        md5_hash.update(content)
        digest = md5_hash.hexdigest()
    else:
        raise ValueError("[get_md5_file] {:} does not exist".format(file_path))
    if post_truncated is None:
        return digest
    else:
        return digest[-post_truncated:]
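A quick, hypothetical illustration of the truncation behaviour (the checkpoint path below is made up):

short_tag = get_md5_file("./output/checkpoints/seed-0-best.pth")        # last 5 hex characters of the digest
full_hash = get_md5_file("./output/checkpoints/seed-0-best.pth", None)  # full 32-character MD5 digest
assert full_hash.endswith(short_tag)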
View File
@ -10,48 +10,58 @@ from log_utils import time_string
def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
    print(
        "This is an old version of codes to use NAS-Bench-API, and should be modified to align with the new version. Please contact me for more details if you use this function."
    )
    weights = deepcopy(model.state_dict())
    model.train(cal_mode)
    with torch.no_grad():
        logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
        archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
        probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
        loader_iter = iter(xloader)
        random.seed(seed)
        random.shuffle(archs)
        for idx, arch in enumerate(archs):
            arch_index = api.query_index_by_arch(arch)
            metrics = api.get_more_info(arch_index, "cifar10-valid", None, False, False)
            gt_accs_10_valid.append(metrics["valid-accuracy"])
            metrics = api.get_more_info(arch_index, "cifar10", None, False, False)
            gt_accs_10_test.append(metrics["test-accuracy"])
            select_logits = []
            for i, node_info in enumerate(arch.nodes):
                for op, xin in node_info:
                    node_str = "{:}<-{:}".format(i + 1, xin)
                    op_index = model.op_names.index(op)
                    select_logits.append(logits[model.edge2index[node_str], op_index])
            cur_prob = sum(select_logits).item()
            probs.append(cur_prob)
        cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0, 1]
        cor_prob_test = np.corrcoef(probs, gt_accs_10_test)[0, 1]
        print(
            "{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test".format(
                time_string(), cor_prob_valid, cor_prob_test
            )
        )
        for idx, arch in enumerate(archs):
            model.set_cal_mode("dynamic", arch)
            try:
                inputs, targets = next(loader_iter)
            except:
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)
            _, logits = model(inputs.cuda())
            _, preds = torch.max(logits, dim=-1)
            correct = (preds == targets.cuda()).float()
            accuracies.append(correct.mean().item())
            if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
                cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[: idx + 1])[0, 1]
                cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[: idx + 1])[0, 1]
                print(
                    "{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.".format(
                        time_string(), idx, len(archs), "Train" if cal_mode else "Eval", cor_accs_valid, cor_accs_test
                    )
                )
    model.load_state_dict(weights)
    return archs, probs, accuracies
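As a sketch of how this function is typically driven: the benchmark API object, the weight-sharing super-network, and the data loader are all built elsewhere in the repo, so everything below is illustrative and left commented out.

# api = NASBench201API("NAS-Bench-201-v1_1-096897.pth")     # hypothetical benchmark lookup table
# model = build_one_shot_model(search_space_config).cuda()  # hypothetical super-network constructor
# archs, probs, accs = evaluate_one_shot(model, valid_loader, api, cal_mode=False)
# probs : summed architecture-parameter log-probabilities per sampled cell
# accs  : one-batch shared-weight accuracies; both are correlated against the
#         ground-truth CIFAR-10 accuracies queried from the api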
View File
@ -1,18 +1,17 @@
def split_str2indexes(string: str, max_check: int, length_limit=5):
    if not isinstance(string, str):
        raise ValueError("Invalid scheme for {:}".format(string))
    srangestr = "".join(string.split())
    indexes = set()
    for srange in srangestr.split(","):
        srange = srange.split("-")
        if len(srange) != 2:
            raise ValueError("invalid srange : {:}".format(srange))
        if length_limit is not None:
            assert len(srange[0]) == len(srange[1]) == length_limit, "invalid srange : {:}".format(srange)
        srange = (int(srange[0]), int(srange[1]))
        if not (0 <= srange[0] <= srange[1] < max_check):
            raise ValueError("{:} vs {:} vs {:}".format(srange[0], srange[1], max_check))
        for i in range(srange[0], srange[1] + 1):
            indexes.add(i)
    return indexes
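A worked example with the default 5-digit length_limit (values chosen for illustration):

# Two inclusive ranges; whitespace is stripped before parsing and max_check bounds every index.
indexes = split_str2indexes("00000-00002, 00005-00005", max_check=10)
assert indexes == {0, 1, 2, 5}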
View File
@ -11,309 +11,350 @@ from sklearn.decomposition import TruncatedSVD
def available_module_types():
    return (nn.Conv2d, nn.Linear)


def get_conv2D_Wmats(tensor: np.ndarray) -> List[np.ndarray]:
    """
    Extract W slices from a 4 index conv2D tensor of shape: (N,M,i,j) or (M,N,i,j).
    Return ij (N x M) matrices
    """
    mats = []
    N, M, imax, jmax = tensor.shape
    assert N + M >= imax + jmax, "invalid tensor shape detected: {}x{} (NxM), {}x{} (i,j)".format(N, M, imax, jmax)
    for i in range(imax):
        for j in range(jmax):
            w = tensor[:, :, i, j]
            if N < M:
                w = w.T
            mats.append(w)
    return mats


def glorot_norm_check(W, N, M, rf_size, lower=0.5, upper=1.5):
    """Check if this layer needs Glorot Normalization Fix"""
    kappa = np.sqrt(2 / ((N + M) * rf_size))
    norm = np.linalg.norm(W)
    check1 = norm / np.sqrt(N * M)
    check2 = norm / (kappa * np.sqrt(N * M))
    if (rf_size > 1) and (check2 > lower) and (check2 < upper):
        return check2, True
    elif (check1 > lower) & (check1 < upper):
        return check1, True
    else:
        if rf_size > 1:
            return check2, False
        else:
            return check1, False


def glorot_norm_fix(w, n, m, rf_size):
    """Apply Glorot Normalization Fix."""
    kappa = np.sqrt(2 / ((n + m) * rf_size))
    w = w / kappa
    return w


def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix):
    results = OrderedDict()
    count = len(weights)
    if count == 0:
        return results
    for i, weight in enumerate(weights):
        M, N = np.min(weight.shape), np.max(weight.shape)
        Q = N / M
        results[i] = cur_res = OrderedDict(N=N, M=M, Q=Q)
        check, checkTF = glorot_norm_check(weight, N, M, count)
        cur_res["check"] = check
        cur_res["checkTF"] = checkTF
        # assume receptive field size is count
        if glorot_fix:
            weight = glorot_norm_fix(weight, N, M, count)
        else:
            # probably never needed since we always fix for glorot
            weight = weight * np.sqrt(count / 2.0)
        if spectralnorms:  # spectralnorm is the max eigenvalues
            svd = TruncatedSVD(n_components=1, n_iter=7, random_state=10)
            svd.fit(weight)
            sv = svd.singular_values_
            sv_max = np.max(sv)
            if normalize:
                evals = sv * sv / N
            else:
                evals = sv * sv
            lambda0 = evals[0]
            cur_res["spectralnorm"] = lambda0
            cur_res["logspectralnorm"] = np.log10(lambda0)
        else:
            lambda0 = None
        if M < min_size:
            summary = "Weight matrix {}/{} ({},{}): Skipping: too small (<{})".format(i + 1, count, M, N, min_size)
            cur_res["summary"] = summary
            continue
        elif max_size > 0 and M > max_size:
            summary = "Weight matrix {}/{} ({},{}): Skipping: too big (testing) (>{})".format(
                i + 1, count, M, N, max_size
            )
            cur_res["summary"] = summary
            continue
        else:
            summary = []
        if alphas:
            import powerlaw

            svd = TruncatedSVD(n_components=M - 1, n_iter=7, random_state=10)
            svd.fit(weight.astype(float))
            sv = svd.singular_values_
            if normalize:
                evals = sv * sv / N
            else:
                evals = sv * sv

            lambda_max = np.max(evals)
            fit = powerlaw.Fit(evals, xmax=lambda_max, verbose=False)
            alpha = fit.alpha
            cur_res["alpha"] = alpha
            D = fit.D
            cur_res["D"] = D
            cur_res["lambda_min"] = np.min(evals)
            cur_res["lambda_max"] = lambda_max
            alpha_weighted = alpha * np.log10(lambda_max)
            cur_res["alpha_weighted"] = alpha_weighted
            tolerance = lambda_max * M * np.finfo(np.max(sv)).eps
            cur_res["rank_loss"] = np.count_nonzero(sv > tolerance, axis=-1)

            logpnorm = np.log10(np.sum([ev ** alpha for ev in evals]))
            cur_res["logpnorm"] = logpnorm

            summary.append(
                "Weight matrix {}/{} ({},{}): Alpha: {}, Alpha Weighted: {}, D: {}, pNorm {}".format(
                    i + 1, count, M, N, alpha, alpha_weighted, D, logpnorm
                )
            )

        if lognorms:
            norm = np.linalg.norm(weight)  # Frobenius Norm
            cur_res["norm"] = norm
            lognorm = np.log10(norm)
            cur_res["lognorm"] = lognorm

            X = np.dot(weight.T, weight)
            if normalize:
                X = X / N
            normX = np.linalg.norm(X)  # Frobenius Norm
            cur_res["normX"] = normX
            lognormX = np.log10(normX)
            cur_res["lognormX"] = lognormX

            summary.append(
                "Weight matrix {}/{} ({},{}): LogNorm: {} ; LogNormX: {}".format(i + 1, count, M, N, lognorm, lognormX)
            )

            if softranks:
                softrank = norm ** 2 / sv_max ** 2
                softranklog = np.log10(softrank)
                softranklogratio = lognorm / np.log10(sv_max)
                cur_res["softrank"] = softrank
                cur_res["softranklog"] = softranklog
                cur_res["softranklogratio"] = softranklogratio
                summary += "{}. Softrank: {}. Softrank log: {}. Softrank log ratio: {}".format(
                    summary, softrank, softranklog, softranklogratio
                )
        cur_res["summary"] = "\n".join(summary)
    return results


def compute_details(results):
    """
    Return a pandas data frame.
    """
    final_summary = OrderedDict()

    metrics = {
        # key in "results" : pretty print name
        "check": "Check",
        "checkTF": "CheckTF",
        "norm": "Norm",
        "lognorm": "LogNorm",
        "normX": "Norm X",
        "lognormX": "LogNorm X",
        "alpha": "Alpha",
        "alpha_weighted": "Alpha Weighted",
        "spectralnorm": "Spectral Norm",
        "logspectralnorm": "Log Spectral Norm",
        "softrank": "Softrank",
        "softranklog": "Softrank Log",
        "softranklogratio": "Softrank Log Ratio",
        "sigma_mp": "Marchenko-Pastur (MP) fit sigma",
        "numofSpikes": "Number of spikes per MP fit",
        "ratio_numofSpikes": "aka, percent_mass, Number of spikes / total number of evals",
        "softrank_mp": "Softrank for MP fit",
        "logpnorm": "alpha pNorm",
    }

    metrics_stats = []
    for metric in metrics:
        metrics_stats.append("{}_min".format(metric))
        metrics_stats.append("{}_max".format(metric))
        metrics_stats.append("{}_avg".format(metric))

        metrics_stats.append("{}_compound_min".format(metric))
        metrics_stats.append("{}_compound_max".format(metric))
        metrics_stats.append("{}_compound_avg".format(metric))

    columns = (
        ["layer_id", "layer_type", "N", "M", "layer_count", "slice", "slice_count", "level", "comment"]
        + [*metrics]
        + metrics_stats
    )

    metrics_values = {}
    metrics_values_compound = {}

    for metric in metrics:
        metrics_values[metric] = []
        metrics_values_compound[metric] = []

    layer_count = 0
    for layer_id, result in results.items():
        layer_count += 1

        layer_type = np.NAN
        if "layer_type" in result:
            layer_type = str(result["layer_type"]).replace("LAYER_TYPE.", "")

        compounds = {}  # temp var
        for metric in metrics:
            compounds[metric] = []

        slice_count, Ntotal, Mtotal = 0, 0, 0
        for slice_id, summary in result.items():
            if not str(slice_id).isdigit():
                continue
            slice_count += 1

            N = np.NAN
            if "N" in summary:
                N = summary["N"]
                Ntotal += N

            M = np.NAN
            if "M" in summary:
                M = summary["M"]
                Mtotal += M

            data = {
                "layer_id": layer_id,
                "layer_type": layer_type,
                "N": N,
                "M": M,
                "slice": slice_id,
                "level": "SLICE",
                "comment": "Slice level",
            }

            for metric in metrics:
                if metric in summary:
                    value = summary[metric]
                    if value is not None:
                        metrics_values[metric].append(value)
                        compounds[metric].append(value)
                        data[metric] = value

        data = {
            "layer_id": layer_id,
            "layer_type": layer_type,
            "N": Ntotal,
            "M": Mtotal,
            "slice_count": slice_count,
            "level": "LAYER",
            "comment": "Layer level",
        }
        # Compute the compound value over the slices
        for metric, value in compounds.items():
            count = len(value)
            if count == 0:
                continue

            compound = np.mean(value)
            metrics_values_compound[metric].append(compound)
            data[metric] = compound

    data = {"layer_count": layer_count, "level": "NETWORK", "comment": "Network Level"}
    for metric, metric_name in metrics.items():
        if metric not in metrics_values or len(metrics_values[metric]) == 0:
            continue

        values = metrics_values[metric]
        minimum = min(values)
        maximum = max(values)
        avg = np.mean(values)
        final_summary[metric] = avg
        # print("{}: min: {}, max: {}, avg: {}".format(metric_name, minimum, maximum, avg))
        data["{}_min".format(metric)] = minimum
        data["{}_max".format(metric)] = maximum
        data["{}_avg".format(metric)] = avg

        values = metrics_values_compound[metric]
        minimum = min(values)
        maximum = max(values)
        avg = np.mean(values)
        final_summary["{}_compound".format(metric)] = avg
        # print("{} compound: min: {}, max: {}, avg: {}".format(metric_name, minimum, maximum, avg))
        data["{}_compound_min".format(metric)] = minimum
        data["{}_compound_max".format(metric)] = maximum
        data["{}_compound_avg".format(metric)] = avg

    return final_summary


def analyze(
    model: nn.Module,
    min_size=50,
    max_size=0,
    alphas: bool = False,
    lognorms: bool = True,
    spectralnorms: bool = False,
    softranks: bool = False,
    normalize: bool = False,
    glorot_fix: bool = False,
):
    """
    Analyze the weight matrices of a model.
    :param model: A PyTorch model
    :param min_size: The minimum weight matrix size to analyze.
    :param max_size: The maximum weight matrix size to analyze (0 = no limit).
    :param alphas: Compute the power laws (alpha) of the weight matrices.
                   Time consuming so disabled by default (use lognorm if you want speed)
    :param lognorms: Compute the log norms of the weight matrices.
    :param spectralnorms: Compute the spectral norm (max eigenvalue) of the weight matrices.
    :param softranks: Compute the soft norm (i.e. StableRank) of the weight matrices.
    :param normalize: Normalize or not.
    :param glorot_fix:
    :return: (a dict of all layers' results, a dict of the summarized info)
    """
    names, modules = [], []
    for name, module in model.named_modules():
        if isinstance(module, available_module_types()):
            names.append(name)
            modules.append(module)
    # print('There are {:} layers to be analyzed in this model.'.format(len(modules)))
    all_results = OrderedDict()
    for index, module in enumerate(modules):
        if isinstance(module, nn.Linear):
            weights = [module.weight.cpu().detach().numpy()]
        else:
            weights = get_conv2D_Wmats(module.weight.cpu().detach().numpy())
        results = analyze_weights(
            weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix
        )
        results["id"] = index
        results["type"] = type(module)
        all_results[index] = results
    summary = compute_details(all_results)
    return all_results, summary
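A hypothetical end-to-end call, using a torchvision model purely as an example (torchvision is an assumption here, not a dependency introduced by this commit):

import torchvision.models as tvm

net = tvm.resnet18(pretrained=False)
per_layer, summary = analyze(net, min_size=50, lognorms=True, alphas=False)
# per_layer[i] holds the per-slice statistics of the i-th Conv2d/Linear module;
# summary averages each metric (e.g. "lognorm", "lognorm_compound") over the whole network.
print(summary)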