Move to xautodl
@@ -1,20 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# general config related functions
from .config_utils import load_config, dict2config, configure2str

# the args setting for different experiments
from .basic_args import obtain_basic_args
from .attention_args import obtain_attention_args
from .random_baseline import obtain_RandomSearch_args
from .cls_kd_args import obtain_cls_kd_args
from .cls_init_args import obtain_cls_init_args
from .search_single_args import obtain_search_single_args
from .search_args import obtain_search_args

# for network pruning
from .pruning_args import obtain_pruning_args

# utils for args
from .args_utils import arg_str2bool
@@ -1,12 +0,0 @@
import argparse


def arg_str2bool(v):
    if isinstance(v, bool):
        return v
    elif v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
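For context, a minimal usage sketch (the `--use_cutout` flag below is hypothetical, not part of this commit): `arg_str2bool` is meant to be passed as the `type` of an argparse flag so that strings such as "yes", "t", or "0" parse into real booleans.

import argparse

# hypothetical parser and flag, for illustration only; uses arg_str2bool from the file above
parser = argparse.ArgumentParser()
parser.add_argument("--use_cutout", type=arg_str2bool, default=False)
args = parser.parse_args(["--use_cutout", "yes"])
print(args.use_cutout)  # -> True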
@@ -1,32 +0,0 @@
import random, argparse
from .share_args import add_shared_args


def obtain_attention_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument("--model_config", type=str, help="The path to the model configuration")
    parser.add_argument("--optim_config", type=str, help="The path to the optimizer configuration")
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--att_channel", type=int, help=".")
    parser.add_argument("--att_spatial", type=str, help=".")
    parser.add_argument("--att_active", type=str, help=".")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training.")
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args
@@ -1,44 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################
import random, argparse
from .share_args import add_shared_args


def obtain_basic_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument("--model_config", type=str, help="The path to the model configuration")
    parser.add_argument("--optim_config", type=str, help="The path to the optimizer configuration")
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--model_source", type=str, default="normal", help="The source of the model definition.")
    parser.add_argument("--extra_model_path", type=str, default=None, help="The extra model checkpoint file (helps to indicate the searched architecture).")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training.")
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args
@@ -1,32 +0,0 @@
import random, argparse
from .share_args import add_shared_args


def obtain_cls_init_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument("--model_config", type=str, help="The path to the model configuration")
    parser.add_argument("--optim_config", type=str, help="The path to the optimizer configuration")
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--init_checkpoint", type=str, help="The checkpoint path to the initial model.")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training.")
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args
@@ -1,43 +0,0 @@
import random, argparse
from .share_args import add_shared_args


def obtain_cls_kd_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument("--model_config", type=str, help="The path to the model configuration")
    parser.add_argument("--optim_config", type=str, help="The path to the optimizer configuration")
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--KD_checkpoint", type=str, help="The teacher checkpoint in knowledge distillation.")
    parser.add_argument("--KD_alpha", type=float, help="The alpha parameter in knowledge distillation.")
    parser.add_argument("--KD_temperature", type=float, help="The temperature parameter in knowledge distillation.")
    # parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.')
    add_shared_args(parser)
    # Optimization options
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training.")
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args
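For orientation, a minimal sketch of how KD_alpha and KD_temperature typically enter a knowledge-distillation loss. This is the standard soft/hard-target formulation and an assumption here, not code taken from this commit.

import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, KD_alpha, KD_temperature):
    # soft targets: KL divergence between temperature-softened distributions, scaled by T^2
    T = KD_temperature
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    # hard targets: ordinary cross-entropy with the ground-truth labels
    hard = F.cross_entropy(student_logits, labels)
    return KD_alpha * soft + (1 - KD_alpha) * hard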
@@ -1,135 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple

support_types = ("str", "int", "bool", "float", "none")


def convert_param(original_lists):
    assert isinstance(original_lists, list), "The type is not right : {:}".format(original_lists)
    ctype, value = original_lists[0], original_lists[1]
    assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types)
    is_list = isinstance(value, list)
    if not is_list:
        value = [value]
    outs = []
    for x in value:
        if ctype == "int":
            x = int(x)
        elif ctype == "str":
            x = str(x)
        elif ctype == "bool":
            x = bool(int(x))
        elif ctype == "float":
            x = float(x)
        elif ctype == "none":
            if x.lower() != "none":
                raise ValueError("For the none type, the value must be none instead of {:}".format(x))
            x = None
        else:
            raise TypeError("Does not know this type : {:}".format(ctype))
        outs.append(x)
    if not is_list:
        outs = outs[0]
    return outs


def load_config(path, extra, logger):
    path = str(path)
    if hasattr(logger, "log"):
        logger.log(path)
    assert os.path.exists(path), "Can not find {:}".format(path)
    # Reading data back
    with open(path, "r") as f:
        data = json.load(f)
    content = {k: convert_param(v) for k, v in data.items()}
    assert extra is None or isinstance(extra, dict), "invalid type of extra : {:}".format(extra)
    if isinstance(extra, dict):
        content = {**content, **extra}
    Arguments = namedtuple("Configure", " ".join(content.keys()))
    content = Arguments(**content)
    if hasattr(logger, "log"):
        logger.log("{:}".format(content))
    return content


def configure2str(config, xpath=None):
    if not isinstance(config, dict):
        config = config._asdict()

    def cstring(x):
        return '"{:}"'.format(x)

    def gtype(x):
        if isinstance(x, list):
            x = x[0]
        if isinstance(x, str):
            return "str"
        elif isinstance(x, bool):
            return "bool"
        elif isinstance(x, int):
            return "int"
        elif isinstance(x, float):
            return "float"
        elif x is None:
            return "none"
        else:
            raise ValueError("invalid : {:}".format(x))

    def cvalue(x, xtype):
        if isinstance(x, list):
            is_list = True
        else:
            is_list, x = False, [x]
        temps = []
        for temp in x:
            if xtype == "bool":
                temp = cstring(int(temp))
            elif xtype == "none":
                temp = cstring("None")
            else:
                temp = cstring(temp)
            temps.append(temp)
        if is_list:
            return "[{:}]".format(", ".join(temps))
        else:
            return temps[0]

    xstrings = []
    for key, value in config.items():
        xtype = gtype(value)
        string = " {:20s} : [{:8s}, {:}]".format(cstring(key), cstring(xtype), cvalue(value, xtype))
        xstrings.append(string)
    Fstring = "{\n" + ",\n".join(xstrings) + "\n}"
    if xpath is not None:
        parent = Path(xpath).resolve().parent
        parent.mkdir(parents=True, exist_ok=True)
        if osp.isfile(xpath):
            os.remove(xpath)
        with open(xpath, "w") as text_file:
            text_file.write("{:}".format(Fstring))
    return Fstring


def dict2config(xdict, logger):
    assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict))
    Arguments = namedtuple("Configure", " ".join(xdict.keys()))
    content = Arguments(**xdict)
    if hasattr(logger, "log"):
        logger.log("{:}".format(content))
    return content
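A short sketch of the JSON layout these helpers expect: every value is a `[type, value]` pair that convert_param casts according to support_types. The file name and keys below are illustrative only, not taken from this commit.

# contents of a hypothetical "optim.config" file (JSON):
# {
#   "scheduler": ["str", "cos"],
#   "LR": ["float", "0.1"],
#   "epochs": ["int", "300"],
#   "criterion": ["str", "Softmax"]
# }
config = load_config("optim.config", extra={"class_num": 10}, logger=None)
print(config.LR, config.epochs, config.class_num)  # namedtuple-style attribute access
xstring = configure2str(config)                    # serialize back to the same [type, value] format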
@@ -1,48 +0,0 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_pruning_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument("--model_config", type=str, help="The path to the model configuration")
    parser.add_argument("--optim_config", type=str, help="The path to the optimizer configuration")
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--keep_ratio", type=float, help="The ratio of channels kept, compared to the original network.")
    parser.add_argument("--model_version", type=str, help="The network version.")
    parser.add_argument("--KD_alpha", type=float, help="The alpha parameter in knowledge distillation.")
    parser.add_argument("--KD_temperature", type=float, help="The temperature parameter in knowledge distillation.")
    parser.add_argument("--Regular_W_feat", type=float, help="The .")
    parser.add_argument("--Regular_W_conv", type=float, help="The .")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training.")
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    assert args.keep_ratio > 0 and args.keep_ratio <= 1, "invalid keep ratio : {:}".format(args.keep_ratio)
    return args
@@ -1,44 +0,0 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_RandomSearch_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument("--expect_flop", type=float, help="The expected FLOP keep ratio.")
    parser.add_argument("--arch_nums", type=int, help="The maximum number of random architectures to generate.")
    parser.add_argument("--model_config", type=str, help="The path to the model configuration")
    parser.add_argument("--optim_config", type=str, help="The path to the optimizer configuration")
    parser.add_argument("--random_mode", type=str, choices=["random", "fix"], help="The mode of random architecture sampling.")
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training.")
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    # assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max)
    return args
@@ -1,53 +0,0 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_search_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--model_config", type=str, help="The path to the model configuration")
    parser.add_argument("--optim_config", type=str, help="The path to the optimizer configuration")
    parser.add_argument("--split_path", type=str, help="The split file path.")
    # parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
    parser.add_argument("--gumbel_tau_max", type=float, help="The maximum tau for Gumbel.")
    parser.add_argument("--gumbel_tau_min", type=float, help="The minimum tau for Gumbel.")
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
    parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
    parser.add_argument("--FLOP_tolerant", type=float, help="The tolerance range for FLOP.")
    # ablation studies
    parser.add_argument("--ablation_num_select", type=int, help="The number of randomly selected channels.")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training.")
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
    assert (
        args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
    ), "invalid FLOP_tolerant : {:}".format(args.FLOP_tolerant)
    # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
    # args.arch_para_pure = bool(args.arch_para_pure)
    return args
@@ -1,48 +0,0 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_search_single_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--model_config", type=str, help="The path to the model configuration")
    parser.add_argument("--optim_config", type=str, help="The path to the optimizer configuration")
    parser.add_argument("--split_path", type=str, help="The split file path.")
    parser.add_argument("--search_shape", type=str, help="The shape to be searched.")
    # parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
    parser.add_argument("--gumbel_tau_max", type=float, help="The maximum tau for Gumbel.")
    parser.add_argument("--gumbel_tau_min", type=float, help="The minimum tau for Gumbel.")
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
    parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
    parser.add_argument("--FLOP_tolerant", type=float, help="The tolerance range for FLOP.")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for training.")
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
    assert (
        args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
    ), "invalid FLOP_tolerant : {:}".format(args.FLOP_tolerant)
    # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
    # args.arch_para_pure = bool(args.arch_para_pure)
    return args
@@ -1,39 +0,0 @@
import os, sys, time, random, argparse


def add_shared_args(parser):
    # Data Generation
    parser.add_argument("--dataset", type=str, help="The dataset name.")
    parser.add_argument("--data_path", type=str, help="The path to the dataset.")
    parser.add_argument("--cutout_length", type=int, help="The cutout length; a negative value disables cutout.")
    # Printing
    parser.add_argument("--print_freq", type=int, default=100, help="print frequency (default: 100)")
    parser.add_argument("--print_freq_eval", type=int, default=100, help="print frequency (default: 100)")
    # Checkpoints
    parser.add_argument("--eval_frequency", type=int, default=1, help="evaluation frequency (default: 1)")
    parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
    # Acceleration
    parser.add_argument("--workers", type=int, default=8, help="number of data loading workers (default: 8)")
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
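For orientation, a hedged sketch of driving one of the obtain_* helpers above from a command line; every path below is a placeholder, not taken from this commit, and the import path is assumed.

import sys
from config_utils import obtain_basic_args  # assumed import path

# simulate a command line; all file paths are illustrative only
sys.argv = [
    "train.py",
    "--dataset", "cifar10",
    "--data_path", "./data/cifar.python",
    "--model_config", "./configs/model.config",
    "--optim_config", "./configs/optim.config",
    "--procedure", "basic",
    "--save_dir", "./output/basic-cifar10",
    "--batch_size", "256",
]
args = obtain_basic_args()
print(args.save_dir, args.rand_seed)  # rand_seed is replaced by a random value when negative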
@@ -1,148 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, hashlib, torch
import numpy as np
from PIL import Image
import torch.utils.data as data

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


def calculate_md5(fpath, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    else:
        return check_md5(fpath, md5)


class ImageNet16(data.Dataset):
    # http://image-net.org/download-images
    # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
    # https://arxiv.org/pdf/1707.08819.pdf

    train_list = [
        ["train_data_batch_1", "27846dcaa50de8e21a7d1a35f30f0e91"],
        ["train_data_batch_2", "c7254a054e0e795c69120a5727050e3f"],
        ["train_data_batch_3", "4333d3df2e5ffb114b05d2ffc19b1e87"],
        ["train_data_batch_4", "1620cdf193304f4a92677b695d70d10f"],
        ["train_data_batch_5", "348b3c2fdbb3940c4e9e834affd3b18d"],
        ["train_data_batch_6", "6e765307c242a1b3d7d5ef9139b48945"],
        ["train_data_batch_7", "564926d8cbf8fc4818ba23d2faac7564"],
        ["train_data_batch_8", "f4755871f718ccb653440b9dd0ebac66"],
        ["train_data_batch_9", "bb6dd660c38c58552125b1a92f86b5d4"],
        ["train_data_batch_10", "8f03f34ac4b42271a294f91bf480f29b"],
    ]
    valid_list = [
        ["val_data", "3410e3017fdaefba8d5073aaa65e4bd6"],
    ]

    def __init__(self, root, train, transform, use_num_of_class_only=None):
        self.root = root
        self.transform = transform
        self.train = train  # training set or valid set
        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted.")

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.valid_list
        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for i, (file_name, checksum) in enumerate(downloaded_list):
            file_path = os.path.join(self.root, file_name)
            # print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
            with open(file_path, "rb") as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding="latin1")
                self.data.append(entry["data"])
                self.targets.extend(entry["labels"])
        self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
        if use_num_of_class_only is not None:
            assert (
                isinstance(use_num_of_class_only, int)
                and use_num_of_class_only > 0
                and use_num_of_class_only < 1000
            ), "invalid use_num_of_class_only : {:}".format(use_num_of_class_only)
            new_data, new_targets = [], []
            for I, L in zip(self.data, self.targets):
                if 1 <= L <= use_num_of_class_only:
                    new_data.append(I)
                    new_targets.append(L)
            self.data = new_data
            self.targets = new_targets
        # self.mean.append(entry['mean'])
        # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
        # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
        # print ('Mean : {:}'.format(self.mean))
        # temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
        # std_data = np.std(temp, axis=0)
        # std_data = np.mean(np.mean(std_data, axis=0), axis=0)
        # print ('Std : {:}'.format(std_data))

    def __repr__(self):
        return "{name}({num} images, {classes} classes)".format(
            name=self.__class__.__name__,
            num=len(self.data),
            classes=len(set(self.targets)),
        )

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index] - 1

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in self.train_list + self.valid_list:
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, filename)
            if not check_integrity(fpath, md5):
                return False
        return True


"""
if __name__ == '__main__':
    train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None)
    valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None)

    print ( len(train) )
    print ( len(valid) )
    image, label = train[111]
    trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200)
    validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200)
    print ( len(trainX) )
    print ( len(validX) )
"""
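A small usage sketch with torchvision transforms and a DataLoader; the root path is a placeholder, and the normalization statistics follow the ImageNet16 values used elsewhere in this commit.

import torch
import torchvision.transforms as transforms

mean = [x / 255 for x in [122.68, 116.66, 104.01]]
std = [x / 255 for x in [63.22, 61.26, 65.09]]
xform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
# keep only the first 120 classes, as in the ImageNet16-120 split
train_data = ImageNet16("./data/ImageNet16", True, xform, 120)
loader = torch.utils.data.DataLoader(train_data, batch_size=256, shuffle=True, num_workers=4)
images, labels = next(iter(loader))  # images: [256, 3, 16, 16]; labels fall in [0, 119]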
@@ -1,301 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from os import path as osp
from copy import deepcopy as copy
from tqdm import tqdm
import warnings, time, random, numpy as np

from pts_utils import generate_label_map
from xvision import denormalize_points
from xvision import identity2affine, solve2theta, affine2image
from .dataset_utils import pil_loader
from .landmark_utils import PointMeta2V, apply_affine2point, apply_boundary
from .augmentation_utils import CutOut
import torch
import torch.utils.data as data


class LandmarkDataset(data.Dataset):
    def __init__(
        self,
        transform,
        sigma,
        downsample,
        heatmap_type,
        shape,
        use_gray,
        mean_file,
        data_indicator,
        cache_images=None,
    ):

        self.transform = transform
        self.sigma = sigma
        self.downsample = downsample
        self.heatmap_type = heatmap_type
        self.dataset_name = data_indicator
        self.shape = shape  # [H,W]
        self.use_gray = use_gray
        assert transform is not None, "transform : {:}".format(transform)
        self.mean_file = mean_file
        if mean_file is None:
            self.mean_data = None
            warnings.warn("LandmarkDataset initialized with mean_data = None")
        else:
            assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file)
            self.mean_data = torch.load(mean_file)
        self.reset()
        self.cutout = None
        self.cache_images = cache_images
        print("The general dataset initialization done : {:}".format(self))
        warnings.simplefilter("once")

    def __repr__(self):
        return "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def set_cutout(self, length):
        if length is not None and length >= 1:
            self.cutout = CutOut(int(length))
        else:
            self.cutout = None

    def reset(self, num_pts=-1, boxid="default", only_pts=False):
        self.NUM_PTS = num_pts
        if only_pts:
            return
        self.length = 0
        self.datas = []
        self.labels = []
        self.NormDistances = []
        self.BOXID = boxid
        if self.mean_data is None:
            self.mean_face = None
        else:
            self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
            assert (self.mean_face >= -1).all() and (
                self.mean_face <= 1
            ).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face)
        # assert self.dataset_name is not None, 'The dataset name is None'

    def __len__(self):
        assert len(self.datas) == self.length, "The length is not correct : {}".format(self.length)
        return self.length

    def append(self, data, label, distance):
        assert osp.isfile(data), "The image path is not a file : {:}".format(data)
        self.datas.append(data)
        self.labels.append(label)
        self.NormDistances.append(distance)
        self.length = self.length + 1

    def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
        if reset:
            self.reset(num_pts, boxindicator)
        else:
            assert (
                self.NUM_PTS == num_pts and self.BOXID == boxindicator
            ), "The number of points is inconsistent : {:} vs {:}".format(self.NUM_PTS, num_pts)
        if isinstance(file_lists, str):
            file_lists = [file_lists]
        samples = []
        for idx, file_path in enumerate(file_lists):
            print(":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path))
            xdata = torch.load(file_path)
            if isinstance(xdata, list):
                data = xdata  # image or video dataset list
            elif isinstance(xdata, dict):
                data = xdata["datas"]  # multi-view dataset list
            else:
                raise ValueError("Invalid Type Error : {:}".format(type(xdata)))
            samples = samples + data
        # samples is a dict, where the key is the image-path and the value is the annotation
        # each annotation is a dict, contains 'points' (3,num_pts), and various box
        print("GeneralDataset-V2 : {:} samples".format(len(samples)))

        # for index, annotation in enumerate(samples):
        for index in tqdm(range(len(samples))):
            annotation = samples[index]
            image_path = annotation["current_frame"]
            points, box = annotation["points"], annotation["box-{:}".format(boxindicator)]
            label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name)
            if normalizeL is None:
                normDistance = None
            else:
                normDistance = annotation["normalizeL-{:}".format(normalizeL)]
            self.append(image_path, label, normDistance)

        assert len(self.datas) == self.length, "The length and the data is not right {} vs {}".format(self.length, len(self.datas))
        assert len(self.labels) == self.length, "The length and the labels is not right {} vs {}".format(self.length, len(self.labels))
        assert len(self.NormDistances) == self.length, "The length and the NormDistances is not right {} vs {}".format(self.length, len(self.NormDistances))
        print("Load data done for LandmarkDataset, which has {:} images.".format(self.length))

    def __getitem__(self, index):
        assert index >= 0 and index < self.length, "Invalid index : {:}".format(index)
        if self.cache_images is not None and self.datas[index] in self.cache_images:
            image = self.cache_images[self.datas[index]].clone()
        else:
            image = pil_loader(self.datas[index], self.use_gray)
        target = self.labels[index].copy()
        return self._process_(image, target, index)

    def _process_(self, image, target, index):

        # transform the image and points
        image, target, theta = self.transform(image, target)
        (C, H, W), (height, width) = image.size(), self.shape

        # obtain the visible indicator vector
        if target.is_none():
            nopoints = True
        else:
            nopoints = False
        if index == -1:
            __path = None
        else:
            __path = self.datas[index]
        if isinstance(theta, list) or isinstance(theta, tuple):
            affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = [], [], [], [], [], []
            for _theta in theta:
                (
                    _affineImage,
                    _heatmaps,
                    _mask,
                    _norm_trans_points,
                    _theta,
                    _transpose_theta,
                ) = self.__process_affine(
                    image, target, _theta, nopoints, "P[{:}]@{:}".format(index, __path)
                )
                affineImage.append(_affineImage)
                heatmaps.append(_heatmaps)
                mask.append(_mask)
                norm_trans_points.append(_norm_trans_points)
                THETA.append(_theta)
                transpose_theta.append(_transpose_theta)
            affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
                torch.stack(affineImage),
                torch.stack(heatmaps),
                torch.stack(mask),
                torch.stack(norm_trans_points),
                torch.stack(THETA),
                torch.stack(transpose_theta),
            )
        else:
            (
                affineImage,
                heatmaps,
                mask,
                norm_trans_points,
                THETA,
                transpose_theta,
            ) = self.__process_affine(
                image, target, theta, nopoints, "S[{:}]@{:}".format(index, __path)
            )

        torch_index = torch.IntTensor([index])
        torch_nopoints = torch.ByteTensor([nopoints])
        torch_shape = torch.IntTensor([H, W])

        return (
            affineImage,
            heatmaps,
            mask,
            norm_trans_points,
            THETA,
            transpose_theta,
            torch_index,
            torch_nopoints,
            torch_shape,
        )

    def __process_affine(self, image, target, theta, nopoints, aux_info=None):
        image, target, theta = image.clone(), target.copy(), theta.clone()
        (C, H, W), (height, width) = image.size(), self.shape
        if nopoints:  # do not have label
            norm_trans_points = torch.zeros((3, self.NUM_PTS))
            heatmaps = torch.zeros(
                (self.NUM_PTS + 1, height // self.downsample, width // self.downsample)
            )
            mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
            transpose_theta = identity2affine(False)
        else:
            norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W))
            norm_trans_points = apply_boundary(norm_trans_points)
            real_trans_points = norm_trans_points.clone()
            real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2, :])
            heatmaps, mask = generate_label_map(
                real_trans_points.numpy(),
                height // self.downsample,
                width // self.downsample,
                self.sigma,
                self.downsample,
                nopoints,
                self.heatmap_type,
            )  # H*W*C
            heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor)
            mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
        if self.mean_face is None:
            # warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
            transpose_theta = identity2affine(False)
        else:
            if torch.sum(norm_trans_points[2, :] == 1) < 3:
                warnings.warn(
                    "In LandmarkDataset after transformation, no visible point, using identity instead. Aux: {:}".format(aux_info)
                )
                transpose_theta = identity2affine(False)
            else:
                transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone())

        affineImage = affine2image(image, theta, self.shape)
        if self.cutout is not None:
            affineImage = self.cutout(affineImage)

        return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta
@@ -1,54 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch, copy, random
import torch.utils.data as data


class SearchDataset(data.Dataset):
    def __init__(self, name, data, train_split, valid_split, check=True):
        self.datasetname = name
        if isinstance(data, (list, tuple)):  # new type of SearchDataset
            assert len(data) == 2, "invalid length: {:}".format(len(data))
            self.train_data = data[0]
            self.valid_data = data[1]
            self.train_split = train_split.copy()
            self.valid_split = valid_split.copy()
            self.mode_str = "V2"  # new mode
        else:
            self.mode_str = "V1"  # old mode
            self.data = data
            self.train_split = train_split.copy()
            self.valid_split = valid_split.copy()
            if check:
                intersection = set(train_split).intersection(set(valid_split))
                assert len(intersection) == 0, "the split train and validation sets should have no intersection"
        self.length = len(self.train_split)

    def __repr__(self):
        return "{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})".format(
            name=self.__class__.__name__,
            datasetname=self.datasetname,
            tr_L=len(self.train_split),
            val_L=len(self.valid_split),
            ver=self.mode_str,
        )

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        assert index >= 0 and index < self.length, "invalid index = {:}".format(index)
        train_index = self.train_split[index]
        valid_index = random.choice(self.valid_split)
        if self.mode_str == "V1":
            train_image, train_label = self.data[train_index]
            valid_image, valid_label = self.data[valid_index]
        elif self.mode_str == "V2":
            train_image, train_label = self.train_data[train_index]
            valid_image, valid_label = self.valid_data[valid_index]
        else:
            raise ValueError("invalid mode : {:}".format(self.mode_str))
        return train_image, train_label, valid_image, valid_label
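A minimal sketch of how SearchDataset pairs each training sample with a random validation sample; the CIFAR-10 loading and the 50/50 index split below are illustrative assumptions, and it uses the SearchDataset class defined above.

import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

data = dset.CIFAR10("./data", train=True, download=True, transform=transforms.ToTensor())
train_split, valid_split = list(range(0, 25000)), list(range(25000, 50000))
search_data = SearchDataset("cifar10", data, train_split, valid_split)
loader = torch.utils.data.DataLoader(search_data, batch_size=64, shuffle=True)
# each batch yields a training mini-batch and an independently sampled validation mini-batch
train_img, train_lab, valid_img, valid_lab = next(iter(loader))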
@@ -1,5 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset
@@ -1,362 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch
import os.path as osp
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
from copy import deepcopy
from PIL import Image

from .DownsampledImageNet import ImageNet16
from .SearchDatasetWrap import SearchDataset
from config_utils import load_config


Dataset2Class = {
    "cifar10": 10,
    "cifar100": 100,
    "imagenet-1k-s": 1000,
    "imagenet-1k": 1000,
    "ImageNet16": 1000,
    "ImageNet16-150": 150,
    "ImageNet16-120": 120,
    "ImageNet16-200": 200,
}


class CUTOUT(object):
    def __init__(self, length):
        self.length = length

    def __repr__(self):
        return "{name}(length={length})".format(name=self.__class__.__name__, **self.__dict__)

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1:y2, x1:x2] = 0.0
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img


imagenet_pca = {
    "eigval": np.asarray([0.2175, 0.0188, 0.0045]),
    "eigvec": np.asarray(
        [
            [-0.5675, 0.7192, 0.4009],
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203],
        ]
    ),
}


class Lighting(object):
    def __init__(self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"]):
        self.alphastd = alphastd
        assert eigval.shape == (3,)
        assert eigvec.shape == (3, 3)
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        if self.alphastd == 0.0:
            return img
        rnd = np.random.randn(3) * self.alphastd
        rnd = rnd.astype("float32")
        v = rnd
        old_dtype = np.asarray(img).dtype
        v = v * self.eigval
        v = v.reshape((3, 1))
        inc = np.dot(self.eigvec, v).reshape((3,))
        img = np.add(img, inc)
        if old_dtype == np.uint8:
            img = np.clip(img, 0, 255)
        img = Image.fromarray(img.astype(old_dtype), "RGB")
        return img

    def __repr__(self):
        return self.__class__.__name__ + "()"


def get_datasets(name, root, cutout):

    if name == "cifar10":
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif name == "cifar100":
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif name.startswith("imagenet-1k"):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif name.startswith("ImageNet16"):
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    # Data Augmentation
    if name == "cifar10" or name == "cifar100":
        lists = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        if cutout > 0:
            lists += [CUTOUT(cutout)]
        train_transform = transforms.Compose(lists)
        test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
        xshape = (1, 3, 32, 32)
    elif name.startswith("ImageNet16"):
        lists = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(16, padding=2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        if cutout > 0:
            lists += [CUTOUT(cutout)]
        train_transform = transforms.Compose(lists)
        test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
        xshape = (1, 3, 16, 16)
    elif name == "tiered":
        lists = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(80, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        if cutout > 0:
            lists += [CUTOUT(cutout)]
        train_transform = transforms.Compose(lists)
        test_transform = transforms.Compose(
            [transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)]
        )
        xshape = (1, 3, 32, 32)
    elif name.startswith("imagenet-1k"):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if name == "imagenet-1k":
            xlists = [transforms.RandomResizedCrop(224)]
            xlists.append(transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2))
            xlists.append(Lighting(0.1))
        elif name == "imagenet-1k-s":
            xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
        else:
            raise ValueError("invalid name : {:}".format(name))
        xlists.append(transforms.RandomHorizontalFlip(p=0.5))
        xlists.append(transforms.ToTensor())
        xlists.append(normalize)
        train_transform = transforms.Compose(xlists)
        test_transform = transforms.Compose(
            [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]
        )
        xshape = (1, 3, 224, 224)
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    if name == "cifar10":
        train_data = dset.CIFAR10(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == "cifar100":
        train_data = dset.CIFAR100(root, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(root, train=False, transform=test_transform, download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name.startswith("imagenet-1k"):
        train_data = dset.ImageFolder(osp.join(root, "train"), train_transform)
        test_data = dset.ImageFolder(osp.join(root, "val"), test_transform)
        assert (
            len(train_data) == 1281167 and len(test_data) == 50000
        ), "invalid number of images : {:} & {:} vs {:} & {:}".format(len(train_data), len(test_data), 1281167, 50000)
    elif name == "ImageNet16":
        train_data = ImageNet16(root, True, train_transform)
        test_data = ImageNet16(root, False, test_transform)
        assert len(train_data) == 1281167 and len(test_data) == 50000
    elif name == "ImageNet16-120":
        train_data = ImageNet16(root, True, train_transform, 120)
        test_data = ImageNet16(root, False, test_transform, 120)
        assert len(train_data) == 151700 and len(test_data) == 6000
    elif name == "ImageNet16-150":
        train_data = ImageNet16(root, True, train_transform, 150)
        test_data = ImageNet16(root, False, test_transform, 150)
        assert len(train_data) == 190272 and len(test_data) == 7500
    elif name == "ImageNet16-200":
        train_data = ImageNet16(root, True, train_transform, 200)
        test_data = ImageNet16(root, False, test_transform, 200)
        assert len(train_data) == 254775 and len(test_data) == 10000
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    class_num = Dataset2Class[name]
    return train_data, test_data, xshape, class_num


def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers):
    if isinstance(batch_size, (list, tuple)):
        batch, test_batch = batch_size
    else:
        batch, test_batch = batch_size, batch_size
    if dataset == "cifar10":
        # split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
        cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None)
        train_split, valid_split = (
            cifar_split.train,
            cifar_split.valid,
        )  # search over the proposed training and validation set
        # logger.log('Load split file from {:}'.format(split_Fpath))  # they are two disjoint groups in the original CIFAR-10 training set
        # To split data
        xvalid_data = deepcopy(train_data)
        if hasattr(xvalid_data, "transforms"):  # to avoid a print issue
            xvalid_data.transforms = valid_data.transform
        xvalid_data.transform = deepcopy(valid_data.transform)
        search_data = SearchDataset(dataset, train_data, train_split, valid_split)
        # data loader
        search_loader = torch.utils.data.DataLoader(
            search_data, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True
        )
        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
            num_workers=workers,
            pin_memory=True,
        )
        valid_loader = torch.utils.data.DataLoader(
            xvalid_data,
            batch_size=test_batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
            num_workers=workers,
            pin_memory=True,
        )
    elif dataset == "cifar100":
        cifar100_test_split = load_config("{:}/cifar100-test-split.txt".format(config_root), None, None)
        search_train_data = train_data
        search_valid_data = deepcopy(valid_data)
        search_valid_data.transform = train_data.transform
        search_data = SearchDataset(
            dataset,
            [search_train_data, search_valid_data],
            list(range(len(search_train_data))),
            cifar100_test_split.xvalid,
        )
        search_loader = torch.utils.data.DataLoader(
            search_data, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True
        )
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_data,
            batch_size=test_batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid),
            num_workers=workers,
            pin_memory=True,
        )
    elif dataset == "ImageNet16-120":
        imagenet_test_split = load_config("{:}/imagenet-16-120-test-split.txt".format(config_root), None, None)
        search_train_data = train_data
        search_valid_data = deepcopy(valid_data)
        search_valid_data.transform = train_data.transform
        search_data = SearchDataset(
            dataset,
            [search_train_data, search_valid_data],
            list(range(len(search_train_data))),
            imagenet_test_split.xvalid,
        )
        search_loader = torch.utils.data.DataLoader(
            search_data, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True
        )
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_data,
            batch_size=test_batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid),
            num_workers=workers,
            pin_memory=True,
        )
    else:
        raise ValueError("invalid dataset : {:}".format(dataset))
    return search_loader, train_loader, valid_loader


# if __name__ == '__main__':
#     train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
#     import pdb; pdb.set_trace()
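A short usage sketch of the two functions above; the data root is a placeholder, and the "configs/nas-benchmark" directory follows the commented split_Fpath hint but is an assumption here.

# uses get_datasets / get_nas_search_loaders defined above; paths are illustrative only
train_data, valid_data, xshape, class_num = get_datasets("cifar10", "./data/cifar.python", cutout=-1)
search_loader, train_loader, valid_loader = get_nas_search_loaders(
    train_data, valid_data, "cifar10", "configs/nas-benchmark", batch_size=64, workers=4
)
for base_x, base_y, arch_x, arch_y in search_loader:
    break  # each batch pairs a training mini-batch with a validation mini-batch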
@@ -1 +0,0 @@
from .point_meta import PointMeta2V, apply_affine2point, apply_boundary
@@ -1,219 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
import copy, math, torch, numpy as np
|
||||
from xvision import normalize_points
|
||||
from xvision import denormalize_points
|
||||
|
||||
|
||||
class PointMeta:
|
||||
# points : 3 x num_pts (x, y, oculusion)
|
||||
# image_size: original [width, height]
|
||||
def __init__(self, num_point, points, box, image_path, dataset_name):
|
||||
|
||||
self.num_point = num_point
|
||||
if box is not None:
|
||||
assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4
|
||||
self.box = torch.Tensor(box)
|
||||
else:
|
||||
self.box = None
|
||||
if points is None:
|
||||
self.points = points
|
||||
else:
|
||||
assert (
|
||||
len(points.shape) == 2
|
||||
and points.shape[0] == 3
|
||||
and points.shape[1] == self.num_point
|
||||
), "The shape of point is not right : {}".format(points)
|
||||
self.points = torch.Tensor(points.copy())
|
||||
self.image_path = image_path
|
||||
self.datasets = dataset_name
|
||||
|
||||
def __repr__(self):
|
||||
if self.box is None:
|
||||
boxstr = "None"
|
||||
else:
|
||||
boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist())
|
||||
return (
|
||||
"{name}(points={num_point}, ".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
+ boxstr
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def get_box(self, return_diagonal=False):
|
||||
if self.box is None:
|
||||
return None
|
||||
if not return_diagonal:
|
||||
return self.box.clone()
|
||||
else:
|
||||
W = (self.box[2] - self.box[0]).item()
|
||||
H = (self.box[3] - self.box[1]).item()
|
||||
return math.sqrt(H * H + W * W)
|
||||
|
||||
def get_points(self, ignore_indicator=False):
|
||||
if ignore_indicator:
|
||||
last = 2
|
||||
else:
|
||||
last = 3
|
||||
if self.points is not None:
|
||||
return self.points.clone()[:last, :]
|
||||
else:
|
||||
return torch.zeros((last, self.num_point))
|
||||
|
||||
def is_none(self):
|
||||
# assert self.box is not None, 'The box should not be None'
|
||||
return self.points is None
|
||||
# if self.box is None: return True
|
||||
# else : return self.points is None
|
||||
|
||||
def copy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def visiable_pts_num(self):
|
||||
with torch.no_grad():
|
||||
ans = self.points[2, :] > 0
|
||||
ans = torch.sum(ans)
|
||||
ans = ans.item()
|
||||
return ans
|
||||
|
||||
def special_fun(self, indicator):
|
||||
if (
|
||||
indicator == "68to49"
|
||||
): # For 300W or 300VW, convert the default 68 points to 49 points.
|
||||
assert self.num_point == 68, "num-point must be 68 vs. {:}".format(
|
||||
self.num_point
|
||||
)
|
||||
self.num_point = 49
|
||||
out = torch.ones((68), dtype=torch.uint8)
|
||||
out[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 60, 64]] = 0
|
||||
if self.points is not None:
|
||||
self.points = self.points.clone()[:, out]
|
||||
else:
|
||||
raise ValueError("Invalid indicator : {:}".format(indicator))
|
||||
|
||||
def apply_horizontal_flip(self):
|
||||
# self.points[0, :] = width - self.points[0, :] - 1
|
||||
# Mugsy spefic or Synthetic
|
||||
if self.datasets.startswith("HandsyROT"):
|
||||
ori = np.array(list(range(0, 42)))
|
||||
pos = np.array(list(range(21, 42)) + list(range(0, 21)))
|
||||
self.points[:, pos] = self.points[:, ori]
|
||||
elif self.datasets.startswith("face68"):
|
||||
ori = np.array(list(range(0, 68)))
|
||||
pos = (
|
||||
np.array(
|
||||
[
|
||||
17,
|
||||
16,
|
||||
15,
|
||||
14,
|
||||
13,
|
||||
12,
|
||||
11,
|
||||
10,
|
||||
9,
|
||||
8,
|
||||
7,
|
||||
6,
|
||||
5,
|
||||
4,
|
||||
3,
|
||||
2,
|
||||
1,
|
||||
27,
|
||||
26,
|
||||
25,
|
||||
24,
|
||||
23,
|
||||
22,
|
||||
21,
|
||||
20,
|
||||
19,
|
||||
18,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
36,
|
||||
35,
|
||||
34,
|
||||
33,
|
||||
32,
|
||||
46,
|
||||
45,
|
||||
44,
|
||||
43,
|
||||
48,
|
||||
47,
|
||||
40,
|
||||
39,
|
||||
38,
|
||||
37,
|
||||
42,
|
||||
41,
|
||||
55,
|
||||
54,
|
||||
53,
|
||||
52,
|
||||
51,
|
||||
50,
|
||||
49,
|
||||
60,
|
||||
59,
|
||||
58,
|
||||
57,
|
||||
56,
|
||||
65,
|
||||
64,
|
||||
63,
|
||||
62,
|
||||
61,
|
||||
68,
|
||||
67,
|
||||
66,
|
||||
]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
self.points[:, ori] = self.points[:, pos]
|
||||
else:
|
||||
raise ValueError("Does not support {:}".format(self.datasets))
|
||||
|
||||
|
||||
# shape = (H,W)
|
||||
def apply_affine2point(points, theta, shape):
|
||||
assert points.size(0) == 3, "invalid points shape : {:}".format(points.size())
|
||||
with torch.no_grad():
|
||||
ok_points = points[2, :] == 1
|
||||
assert torch.sum(ok_points).item() > 0, "there is no visible point"
|
||||
points[:2, :] = normalize_points(shape, points[:2, :])
|
||||
|
||||
norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()
|
||||
|
||||
trans_points = torch.linalg.solve(theta, points[:, ok_points])  # torch.gesv is deprecated; this solves theta @ X = points
|
||||
|
||||
norm_trans_points[:, ok_points] = trans_points
|
||||
|
||||
return norm_trans_points
|
||||
|
||||
|
||||
def apply_boundary(norm_trans_points):
|
||||
with torch.no_grad():
|
||||
norm_trans_points = norm_trans_points.clone()
|
||||
oks = torch.stack(
|
||||
(
|
||||
norm_trans_points[0] > -1,
|
||||
norm_trans_points[0] < 1,
|
||||
norm_trans_points[1] > -1,
|
||||
norm_trans_points[1] < 1,
|
||||
norm_trans_points[2] > 0,
|
||||
)
|
||||
)
|
||||
oks = torch.sum(oks, dim=0) == 5
|
||||
norm_trans_points[2, :] = oks
|
||||
return norm_trans_points
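# Usage sketch (illustration only, not part of the original file). It assumes this module is
# imported with the package available so that xvision.normalize_points resolves, and uses an
# identity affine matrix purely as a hypothetical example.
if __name__ == "__main__":
    demo_points = np.stack(
        [np.linspace(10, 50, 5), np.linspace(20, 60, 5), np.ones(5)]
    )  # 3 x num_pts: x, y, visibility
    meta = PointMeta(5, demo_points, (0, 0, 100, 100), "demo.jpg", "demo")
    theta = torch.eye(3)  # identity transform keeps the normalized coordinates unchanged
    norm_pts = apply_affine2point(meta.get_points(), theta, (100, 100))
    norm_pts = apply_boundary(norm_pts)
    print(meta)
    print(norm_pts.shape)  # torch.Size([3, 5])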
|
||||
@@ -1,100 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
import math
|
||||
import abc
|
||||
import copy
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
from .math_base_funcs import FitFunc
|
||||
from .math_base_funcs import QuadraticFunc
|
||||
from .math_base_funcs import QuarticFunc
|
||||
|
||||
|
||||
class ConstantFunc(FitFunc):
|
||||
"""The constant function: f(x) = c."""
|
||||
|
||||
def __init__(self, constant=None):
|
||||
param = dict()
|
||||
param[0] = constant
|
||||
super(ConstantFunc, self).__init__(0, None, param)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0]
|
||||
|
||||
def fit(self, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
|
||||
|
||||
|
||||
class ComposedSinFunc(FitFunc):
|
||||
"""The composed sin function that outputs:
|
||||
f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
|
||||
- the amplitude scale is a quadratic function of x
|
||||
- the period-phase-shift is another quadratic function of x
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(ComposedSinFunc, self).__init__(0, None)
|
||||
self.fit(**kwargs)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
scale = self._params["amplitude_scale"](x)
|
||||
period_phase = self._params["period_phase_shift"](x)
|
||||
return scale * math.sin(period_phase)
|
||||
|
||||
def fit(self, **kwargs):
|
||||
num_sin_phase = kwargs.get("num_sin_phase", 7)
|
||||
sin_speed_use_power = kwargs.get("sin_speed_use_power", True)
|
||||
min_amplitude = kwargs.get("min_amplitude", 1)
|
||||
max_amplitude = kwargs.get("max_amplitude", 4)
|
||||
phase_shift = kwargs.get("phase_shift", 0.0)
|
||||
# create parameters
|
||||
if kwargs.get("amplitude_scale", None) is None:
|
||||
amplitude_scale = QuadraticFunc(
|
||||
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
|
||||
)
|
||||
else:
|
||||
amplitude_scale = kwargs.get("amplitude_scale")
|
||||
if kwargs.get("period_phase_shift", None) is None:
|
||||
fitting_data = []
|
||||
if sin_speed_use_power:
|
||||
temp_max_scalar = 2 ** (num_sin_phase - 1)
|
||||
else:
|
||||
temp_max_scalar = num_sin_phase - 1
|
||||
for i in range(num_sin_phase):
|
||||
if sin_speed_use_power:
|
||||
value = (2 ** i) / temp_max_scalar
|
||||
next_value = (2 ** (i + 1)) / temp_max_scalar
|
||||
else:
|
||||
value = i / temp_max_scalar
|
||||
next_value = (i + 1) / temp_max_scalar
|
||||
for _phase in (0, 0.25, 0.5, 0.75):
|
||||
inter_value = value + (next_value - value) * _phase
|
||||
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
|
||||
period_phase_shift = QuarticFunc(fitting_data)
|
||||
else:
|
||||
period_phase_shift = kwargs.get("period_phase_shift")
|
||||
self.set(
|
||||
dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift)
|
||||
)
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({amplitude_scale} * sin({period_phase_shift}))".format(
|
||||
name=self.__class__.__name__,
|
||||
amplitude_scale=self._params["amplitude_scale"],
|
||||
period_phase_shift=self._params["period_phase_shift"],
|
||||
)
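# Usage sketch (illustration only, not part of the original file). Run with the package
# importable so that the relative import of math_base_funcs above resolves.
if __name__ == "__main__":
    const = ConstantFunc(constant=0.9)
    wave = ComposedSinFunc(num_sin_phase=4, min_amplitude=1, max_amplitude=2)
    for x in (0.0, 0.25, 0.5, 1.0):
        # the constant ignores x; the composed function is a sin-like wave whose amplitude
        # and speed vary with x (the inner curves are fitted, so values are approximate)
        print("f_const({:.2f})={:.2f}, f_sin({:.2f})={:.4f}".format(x, const(x), x, wave(x)))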
|
||||
@@ -1,210 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
import math
|
||||
import abc
|
||||
import copy
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
|
||||
class FitFunc(abc.ABC):
|
||||
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
|
||||
def __init__(self, freedom: int, list_of_points=None, params=None):
|
||||
self._params = dict()
|
||||
for i in range(freedom):
|
||||
self._params[i] = None
|
||||
self._freedom = freedom
|
||||
if list_of_points is not None and params is not None:
|
||||
raise ValueError("list_of_points and params can not be set simultaneously")
|
||||
if list_of_points is not None:
|
||||
self.fit(list_of_points=list_of_points)
|
||||
if params is not None:
|
||||
self.set(params)
|
||||
|
||||
def set(self, params):
|
||||
self._params = copy.deepcopy(params)
|
||||
|
||||
def check_valid(self):
|
||||
for key, value in self._params.items():
|
||||
if value is None:
|
||||
raise ValueError("The {:} is None".format(key))
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def noise_call(self, x, std=0.1):
|
||||
clean_y = self.__call__(x)
|
||||
if isinstance(clean_y, np.ndarray):
|
||||
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
|
||||
else:
|
||||
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
|
||||
return noise_y
|
||||
|
||||
@abc.abstractmethod
|
||||
def _getitem(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def fit(self, **kwargs):
|
||||
list_of_points = kwargs["list_of_points"]
|
||||
max_iter, lr_max, verbose = (
|
||||
kwargs.get("max_iter", 900),
|
||||
kwargs.get("lr_max", 1.0),
|
||||
kwargs.get("verbose", False),
|
||||
)
|
||||
with torch.no_grad():
|
||||
data = torch.Tensor(list_of_points).type(torch.float32)
|
||||
assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format(
|
||||
data.shape
|
||||
)
|
||||
x, y = data[:, 0], data[:, 1]
|
||||
weights = torch.nn.Parameter(torch.Tensor(self._freedom))
|
||||
torch.nn.init.normal_(weights, mean=0.0, std=1.0)
|
||||
optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True)
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
milestones=[
|
||||
int(max_iter * 0.25),
|
||||
int(max_iter * 0.5),
|
||||
int(max_iter * 0.75),
|
||||
],
|
||||
gamma=0.1,
|
||||
)
|
||||
if verbose:
|
||||
print("The optimizer: {:}".format(optimizer))
|
||||
|
||||
best_loss = None
|
||||
for _iter in range(max_iter):
|
||||
y_hat = self._getitem(x, weights)
|
||||
loss = torch.mean(torch.abs(y - y_hat))
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
if verbose:
|
||||
print(
|
||||
"In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format(
|
||||
_iter, max_iter, loss.item()
|
||||
)
|
||||
)
|
||||
# Update the params
|
||||
if best_loss is None or best_loss > loss.item():
|
||||
best_loss = loss.item()
|
||||
for i in range(self._freedom):
|
||||
self._params[i] = weights[i].item()
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(freedom={freedom})".format(
|
||||
name=self.__class__.__name__, freedom=self._freedom
|
||||
)
|
||||
|
||||
|
||||
class LinearFunc(FitFunc):
|
||||
"""The linear function that outputs f(x) = a * x + b."""
|
||||
|
||||
def __init__(self, list_of_points=None, params=None):
|
||||
super(LinearFunc, self).__init__(2, list_of_points, params)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0] * x + self._params[1]
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x + weights[1]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x + {b})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
)
|
||||
|
||||
|
||||
class QuadraticFunc(FitFunc):
|
||||
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
|
||||
def __init__(self, list_of_points=None, params=None):
|
||||
super(QuadraticFunc, self).__init__(3, list_of_points, params)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return self._params[0] * x * x + self._params[1] * x + self._params[2]
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x * x + weights[1] * x + weights[2]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^2 + {b} * x + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
)
|
||||
|
||||
|
||||
class CubicFunc(FitFunc):
|
||||
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
|
||||
|
||||
def __init__(self, list_of_points=None):
|
||||
super(CubicFunc, self).__init__(4, list_of_points)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return (
|
||||
self._params[0] * x ** 3
|
||||
+ self._params[1] * x ** 2
|
||||
+ self._params[2] * x
|
||||
+ self._params[3]
|
||||
)
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
d=self._params[3],
|
||||
)
|
||||
|
||||
|
||||
class QuarticFunc(FitFunc):
|
||||
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
|
||||
|
||||
def __init__(self, list_of_points=None):
|
||||
super(QuarticFunc, self).__init__(5, list_of_points)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
return (
|
||||
self._params[0] * x ** 4
|
||||
+ self._params[1] * x ** 3
|
||||
+ self._params[2] * x ** 2
|
||||
+ self._params[3] * x
|
||||
+ self._params[4]
|
||||
)
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
return (
|
||||
weights[0] * x ** 4
|
||||
+ weights[1] * x ** 3
|
||||
+ weights[2] * x ** 2
|
||||
+ weights[3] * x
|
||||
+ weights[4]
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
d=self._params[3],
|
||||
e=self._params[4],
|
||||
)
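# Usage sketch (illustration only, not part of the original file). Passing list_of_points
# triggers the gradient-based fit() defined above, so the recovered coefficients are approximate.
if __name__ == "__main__":
    xs = np.linspace(-1, 1, 21)
    sample_points = [(float(x), float(2 * x * x - 3 * x + 1)) for x in xs]
    quad = QuadraticFunc(sample_points)  # fit f(x) = 2x^2 - 3x + 1 from samples
    print(quad)       # shows the fitted a, b, c
    print(quad(0.5))  # should be close to 2*0.25 - 3*0.5 + 1 = 0.0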
|
||||
@@ -1,8 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
|
||||
#####################################################
|
||||
from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
|
||||
from .math_dynamic_funcs import DynamicLinearFunc
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc
|
||||
from .math_adv_funcs import ComposedSinFunc
|
||||
@@ -1,93 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
import math
|
||||
import abc
|
||||
import copy
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
from .math_base_funcs import FitFunc
|
||||
|
||||
|
||||
class DynamicFunc(FitFunc):
|
||||
"""The dynamic quadratic function, where each param is a function."""
|
||||
|
||||
def __init__(self, freedom: int, params=None):
|
||||
super(DynamicFunc, self).__init__(freedom, None, params)
|
||||
self._timestamp = None
|
||||
|
||||
def __call__(self, x, timestamp=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_timestamp(self, timestamp):
|
||||
self._timestamp = timestamp
|
||||
|
||||
def noise_call(self, x, timestamp=None, std=0.1):
|
||||
clean_y = self.__call__(x, timestamp)
|
||||
if isinstance(clean_y, np.ndarray):
|
||||
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
|
||||
else:
|
||||
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
|
||||
return noise_y
|
||||
|
||||
|
||||
class DynamicLinearFunc(DynamicFunc):
|
||||
"""The dynamic linear function that outputs f(x) = a * x + b.
|
||||
The coefficients a and b are functions of the timestamp.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None):
|
||||
super(DynamicLinearFunc, self).__init__(3, params)
|
||||
|
||||
def __call__(self, x, timestamp=None):
|
||||
self.check_valid()
|
||||
if timestamp is None:
|
||||
timestamp = self._timestamp
|
||||
a = self._params[0](timestamp)
|
||||
b = self._params[1](timestamp)
|
||||
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
|
||||
a, b = convert_fn(a), convert_fn(b)
|
||||
return a * x + b
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x + {b}, timestamp={timestamp})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
timestamp=self._timestamp,
|
||||
)
|
||||
|
||||
|
||||
class DynamicQuadraticFunc(DynamicFunc):
|
||||
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
|
||||
The coefficients a, b, and c are functions of the timestamp.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None):
|
||||
super(DynamicQuadraticFunc, self).__init__(3, params)
|
||||
|
||||
def __call__(self, x, timestamp=None):
|
||||
self.check_valid()
|
||||
if timestamp is None:
|
||||
timestamp = self._timestamp
|
||||
a = self._params[0](timestamp)
|
||||
b = self._params[1](timestamp)
|
||||
c = self._params[2](timestamp)
|
||||
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
|
||||
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
|
||||
return a * x * x + b * x + c
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^2 + {b} * x + {c}, timestamp={timestamp})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
timestamp=self._timestamp,
|
||||
)
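# Usage sketch (illustration only, not part of the original file; needs the package importable
# for the relative import above). Any callable of the timestamp can serve as a parameter; plain
# lambdas stand in here for the ConstantFunc / ComposedSinFunc functors used elsewhere.
if __name__ == "__main__":
    func = DynamicQuadraticFunc(
        params={0: lambda t: 1.0, 1: lambda t: 0.5, 2: lambda t: 0.1}
    )
    func.set_timestamp(0.3)
    print(func(2.0))  # 1.0 * 2^2 + 0.5 * 2 + 0.1 = 5.1 (the constant params ignore t=0.3)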
|
||||
@@ -1,58 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
|
||||
#####################################################
|
||||
from .synthetic_utils import TimeStamp
|
||||
from .synthetic_env import EnvSampler
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
from .math_core import LinearFunc
|
||||
from .math_core import DynamicLinearFunc
|
||||
from .math_core import DynamicQuadraticFunc
|
||||
from .math_core import ConstantFunc, ComposedSinFunc
|
||||
|
||||
|
||||
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
|
||||
|
||||
|
||||
def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, version="v1"):
|
||||
if version == "v1":
|
||||
mean_generator = ConstantFunc(0)
|
||||
std_generator = ConstantFunc(1)
|
||||
elif version == "v2":
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5)
|
||||
else:
|
||||
raise ValueError("Unknown version: {:}".format(version))
|
||||
dynamic_env = SyntheticDEnv(
|
||||
[mean_generator],
|
||||
[[std_generator]],
|
||||
num_per_task=num_per_task,
|
||||
timestamp_config=dict(
|
||||
min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode
|
||||
),
|
||||
)
|
||||
if version == "v1":
|
||||
function = DynamicLinearFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
amplitude_scale=ConstantFunc(3.0),
|
||||
num_sin_phase=9,
|
||||
sin_speed_use_power=False,
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
elif version == "v2":
|
||||
function = DynamicQuadraticFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
function_param[2] = ComposedSinFunc(
|
||||
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown version: {:}".format(version))
|
||||
|
||||
function.set(function_param)
|
||||
# dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||
dynamic_env.set_oracle_map(function)
|
||||
return dynamic_env
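# Usage sketch (illustration only, not part of the original file). It needs the package
# importable so the relative imports above resolve; the shapes follow the v1 configuration.
if __name__ == "__main__":
    env = get_synthetic_env(total_timestamp=50, num_per_task=100, mode="train", version="v1")
    print(env)
    timestamp, (xs, ys) = env[0]
    print(timestamp.shape, xs.shape, ys.shape)  # [1], [100, 1], [100, 1]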
|
||||
@@ -1,180 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
from typing import List, Optional, Dict
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
from .synthetic_utils import TimeStamp
|
||||
|
||||
|
||||
def is_list_tuple(x):
|
||||
return isinstance(x, (tuple, list))
|
||||
|
||||
|
||||
def zip_sequence(sequence):
|
||||
def _combine(*alist):
|
||||
if is_list_tuple(alist[0]):
|
||||
return [_combine(*xlist) for xlist in zip(*alist)]
|
||||
else:
|
||||
return torch.cat(alist, dim=0)
|
||||
|
||||
def unsqueeze(a):
|
||||
if is_list_tuple(a):
|
||||
return [unsqueeze(x) for x in a]
|
||||
else:
|
||||
return a.unsqueeze(dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
sequence = [unsqueeze(a) for a in sequence]
|
||||
return _combine(*sequence)
|
||||
|
||||
|
||||
class SyntheticDEnv(data.Dataset):
|
||||
"""The synethtic dynamic environment."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mean_functors: List[data.Dataset],
|
||||
cov_functors: List[List[data.Dataset]],
|
||||
num_per_task: int = 5000,
|
||||
timestamp_config: Optional[Dict] = None,
|
||||
mode: Optional[str] = None,
|
||||
timestamp_noise_scale: float = 0.3,
|
||||
):
|
||||
self._ndim = len(mean_functors)
|
||||
assert self._ndim == len(
|
||||
cov_functors
|
||||
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors))
|
||||
for cov_functor in cov_functors:
|
||||
assert self._ndim == len(
|
||||
cov_functor
|
||||
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
|
||||
self._num_per_task = num_per_task
|
||||
if timestamp_config is None:
|
||||
timestamp_config = dict(mode=mode)
|
||||
elif "mode" not in timestamp_config:
|
||||
timestamp_config["mode"] = mode
|
||||
|
||||
self._timestamp_generator = TimeStamp(**timestamp_config)
|
||||
self._timestamp_noise_scale = timestamp_noise_scale
|
||||
|
||||
self._mean_functors = mean_functors
|
||||
self._cov_functors = cov_functors
|
||||
|
||||
self._oracle_map = None
|
||||
self._seq_length = None
|
||||
|
||||
@property
|
||||
def min_timestamp(self):
|
||||
return self._timestamp_generator.min_timestamp
|
||||
|
||||
@property
|
||||
def max_timestamp(self):
|
||||
return self._timestamp_generator.max_timestamp
|
||||
|
||||
@property
|
||||
def timestamp_interval(self):
|
||||
return self._timestamp_generator.interval
|
||||
|
||||
def random_timestamp(self):
|
||||
return (
|
||||
random.random() * (self.max_timestamp - self.min_timestamp)
|
||||
+ self.min_timestamp
|
||||
)
|
||||
|
||||
def reset_max_seq_length(self, seq_length):
|
||||
self._seq_length = seq_length
|
||||
|
||||
def get_timestamp(self, index):
|
||||
if index is None:
|
||||
timestamps = []
|
||||
for index in range(len(self._timestamp_generator)):
|
||||
timestamps.append(self._timestamp_generator[index][1])
|
||||
return tuple(timestamps)
|
||||
else:
|
||||
index, timestamp = self._timestamp_generator[index]
|
||||
return timestamp
|
||||
|
||||
def set_oracle_map(self, functor):
|
||||
self._oracle_map = functor
|
||||
|
||||
def __iter__(self):
|
||||
self._iter_num = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._iter_num >= len(self):
|
||||
raise StopIteration
|
||||
self._iter_num += 1
|
||||
return self.__getitem__(self._iter_num - 1)
|
||||
|
||||
def __getitem__(self, index):
|
||||
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
|
||||
index, timestamp = self._timestamp_generator[index]
|
||||
if self._seq_length is None:
|
||||
return self.__call__(timestamp)
|
||||
else:
|
||||
noise = (
|
||||
random.random() * self.timestamp_interval * self._timestamp_noise_scale
|
||||
)
|
||||
timestamps = [
|
||||
timestamp + i * self.timestamp_interval + noise
|
||||
for i in range(self._seq_length)
|
||||
]
|
||||
xdata = [self.__call__(timestamp) for timestamp in timestamps]
|
||||
return zip_sequence(xdata)
|
||||
|
||||
def __call__(self, timestamp):
|
||||
mean_list = [functor(timestamp) for functor in self._mean_functors]
|
||||
cov_matrix = [
|
||||
[abs(cov_gen(timestamp)) for cov_gen in cov_functor]
|
||||
for cov_functor in self._cov_functors
|
||||
]
|
||||
|
||||
dataset = np.random.multivariate_normal(
|
||||
mean_list, cov_matrix, size=self._num_per_task
|
||||
)
|
||||
if self._oracle_map is None:
|
||||
return torch.Tensor([timestamp]), torch.Tensor(dataset)
|
||||
else:
|
||||
targets = self._oracle_map.noise_call(dataset, timestamp)
|
||||
return torch.Tensor([timestamp]), (
|
||||
torch.Tensor(dataset),
|
||||
torch.Tensor(targets),
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._timestamp_generator)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format(
|
||||
name=self.__class__.__name__,
|
||||
cur_num=len(self),
|
||||
total=len(self._timestamp_generator),
|
||||
ndim=self._ndim,
|
||||
num_per_task=self._num_per_task,
|
||||
xrange_min=self.min_timestamp,
|
||||
xrange_max=self.max_timestamp,
|
||||
mode=self._timestamp_generator.mode,
|
||||
)
|
||||
|
||||
|
||||
class EnvSampler:
|
||||
def __init__(self, env, batch, enlarge):
|
||||
indexes = list(range(len(env)))
|
||||
self._indexes = indexes * enlarge
|
||||
self._batch = batch
|
||||
self._iterations = len(self._indexes) // self._batch
|
||||
|
||||
def __iter__(self):
|
||||
random.shuffle(self._indexes)
|
||||
for it in range(self._iterations):
|
||||
indexes = self._indexes[it * self._batch : (it + 1) * self._batch]
|
||||
yield indexes
|
||||
|
||||
def __len__(self):
|
||||
return self._iterations
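# Usage sketch (illustration only, not part of the original file; needs the package importable
# for the relative import above). Plain lambdas stand in for the mean / covariance functors;
# EnvSampler yields lists of environment indexes and can also serve as a DataLoader batch_sampler.
if __name__ == "__main__":
    env = SyntheticDEnv(
        [lambda t: 0.0],    # 1-dim mean functor
        [[lambda t: 1.0]],  # 1x1 covariance functor
        num_per_task=32,
        timestamp_config=dict(min_timestamp=0.0, max_timestamp=1.0, num=20),
    )
    sampler = EnvSampler(env, batch=4, enlarge=2)
    for batch_indexes in sampler:
        batch = [env[i] for i in batch_indexes]
        print(len(batch_indexes), batch[0][0].item())  # 4, then the first sampled timestamp
        break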
|
||||
@@ -1,72 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
import copy
|
||||
|
||||
from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
|
||||
|
||||
def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"):
|
||||
if indicator == "v1":
|
||||
return create_example_v1(timestamp_config, num_per_task)
|
||||
elif indicator == "v2":
|
||||
return create_example_v2(timestamp_config, num_per_task)
|
||||
else:
|
||||
raise ValueError("Unkonwn indicator: {:}".format(indicator))
|
||||
|
||||
|
||||
def create_example_v1(
|
||||
timestamp_config=None,
|
||||
num_per_task=5000,
|
||||
):
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
||||
|
||||
dynamic_env = SyntheticDEnv(
|
||||
[mean_generator],
|
||||
[[std_generator]],
|
||||
num_per_task=num_per_task,
|
||||
timestamp_config=timestamp_config,
|
||||
)
|
||||
|
||||
function = DynamicQuadraticFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
function_param[2] = ComposedSinFunc(
|
||||
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
|
||||
)
|
||||
function.set(function_param)
|
||||
|
||||
dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||
return dynamic_env, function
|
||||
|
||||
|
||||
def create_example_v2(
|
||||
timestamp_config=None,
|
||||
num_per_task=5000,
|
||||
):
|
||||
mean_generator = ConstantFunc(0)
|
||||
std_generator = ConstantFunc(1)
|
||||
|
||||
dynamic_env = SyntheticDEnv(
|
||||
[mean_generator],
|
||||
[[std_generator]],
|
||||
num_per_task=num_per_task,
|
||||
timestamp_config=timestamp_config,
|
||||
)
|
||||
|
||||
function = DynamicLinearFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0)
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
function.set(function_param)
|
||||
|
||||
dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||
return dynamic_env, function
|
||||
@@ -1,93 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
import math
|
||||
import abc
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
|
||||
class UnifiedSplit:
|
||||
"""A class to unify the split strategy."""
|
||||
|
||||
def __init__(self, total_num, mode):
|
||||
# Training Set 60%
|
||||
num_of_train = int(total_num * 0.6)
|
||||
# Validation Set 20%
|
||||
num_of_valid = int(total_num * 0.2)
|
||||
# Test Set 20%
|
||||
num_of_test = total_num - num_of_train - num_of_valid  # the remaining samples form the test set
|
||||
all_indexes = list(range(total_num))
|
||||
if mode is None:
|
||||
self._indexes = all_indexes
|
||||
elif mode.lower() in ("train", "training"):
|
||||
self._indexes = all_indexes[:num_of_train]
|
||||
elif mode.lower() in ("valid", "validation"):
|
||||
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
|
||||
elif mode.lower() in ("test", "testing"):
|
||||
self._indexes = all_indexes[num_of_train + num_of_valid :]
|
||||
else:
|
||||
raise ValueError("Unkonwn mode of {:}".format(mode))
|
||||
self._all_indexes = all_indexes
|
||||
self._mode = mode
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self._mode
|
||||
|
||||
|
||||
class TimeStamp(UnifiedSplit, data.Dataset):
|
||||
"""The timestamp dataset."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_timestamp: float = 0.0,
|
||||
max_timestamp: float = 1.0,
|
||||
num: int = 100,
|
||||
mode: Optional[str] = None,
|
||||
):
|
||||
self._min_timestamp = min_timestamp
|
||||
self._max_timestamp = max_timestamp
|
||||
self._interval = (max_timestamp - min_timestamp) / (float(num) - 1)
|
||||
self._total_num = num
|
||||
UnifiedSplit.__init__(self, self._total_num, mode)
|
||||
|
||||
@property
|
||||
def min_timestamp(self):
|
||||
return self._min_timestamp + self._interval * min(self._indexes)
|
||||
|
||||
@property
|
||||
def max_timestamp(self):
|
||||
return self._min_timestamp + self._interval * max(self._indexes)
|
||||
|
||||
@property
|
||||
def interval(self):
|
||||
return self._interval
|
||||
|
||||
def __iter__(self):
|
||||
self._iter_num = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._iter_num >= len(self):
|
||||
raise StopIteration
|
||||
self._iter_num += 1
|
||||
return self.__getitem__(self._iter_num - 1)
|
||||
|
||||
def __getitem__(self, index):
|
||||
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
|
||||
index = self._indexes[index]
|
||||
timestamp = self._min_timestamp + self._interval * index
|
||||
return index, timestamp
|
||||
|
||||
def __len__(self):
|
||||
return len(self._indexes)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({cur_num:}/{total} elements)".format(
|
||||
name=self.__class__.__name__,
|
||||
cur_num=len(self),
|
||||
total=self._total_num,
|
||||
)
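# Usage sketch (illustration only, not part of the original file). The same timestamp grid is
# split 60% / 20% / 20% into train / valid / test by UnifiedSplit above.
if __name__ == "__main__":
    train_ts = TimeStamp(0.0, 1.0, num=10, mode="train")
    valid_ts = TimeStamp(0.0, 1.0, num=10, mode="valid")
    print(train_ts, valid_ts)  # 6/10 and 2/10 elements
    for index, timestamp in train_ts:
        print(index, round(timestamp, 3))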
|
||||
@@ -1,24 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os
|
||||
|
||||
|
||||
def test_imagenet_data(imagenet):
|
||||
total_length = len(imagenet)
|
||||
assert (
|
||||
total_length == 1281166 or total_length == 50000
|
||||
), "The length of ImageNet is wrong : {}".format(total_length)
|
||||
map_id = {}
|
||||
for index in range(total_length):
|
||||
path, target = imagenet.imgs[index]
|
||||
folder, image_name = os.path.split(path)
|
||||
_, folder = os.path.split(folder)
|
||||
if folder not in map_id:
|
||||
map_id[folder] = target
|
||||
else:
|
||||
assert map_id[folder] == target, "Class : {} is not {}".format(
|
||||
folder, target
|
||||
)
|
||||
assert image_name.find(folder) == 0, "{} is wrong.".format(path)
|
||||
print("Check ImageNet Dataset OK")
|
||||
@@ -1,16 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# none of the modules in this package rely on pytorch or tensorflow
|
||||
# I tried to list all dependencies here: os, sys, time, numpy, (possibly) matplotlib
|
||||
##################################################
|
||||
from .logger import Logger, PrintLogger
|
||||
from .meter import AverageMeter
|
||||
from .time_utils import (
|
||||
time_for_file,
|
||||
time_string,
|
||||
time_string_short,
|
||||
time_print,
|
||||
convert_secs2time,
|
||||
)
|
||||
from .pickle_wrap import pickle_save, pickle_load
|
||||
@@ -1,173 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from pathlib import Path
|
||||
import importlib.util, warnings  # importlib.util must be imported explicitly for find_spec
|
||||
import os, sys, time, numpy as np
|
||||
|
||||
if sys.version_info.major == 2: # Python 2.x
|
||||
from StringIO import StringIO as BIO
|
||||
else: # Python 3.x
|
||||
from io import BytesIO as BIO
|
||||
|
||||
if importlib.util.find_spec("tensorflow"):
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class PrintLogger(object):
|
||||
def __init__(self):
|
||||
"""Create a summary writer logging to log_dir."""
|
||||
self.name = "PrintLogger"
|
||||
|
||||
def log(self, string):
|
||||
print(string)
|
||||
|
||||
def close(self):
|
||||
print("-" * 30 + " close printer " + "-" * 30)
|
||||
|
||||
|
||||
class Logger(object):
|
||||
def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
|
||||
"""Create a summary writer logging to log_dir."""
|
||||
self.seed = int(seed)
|
||||
self.log_dir = Path(log_dir)
|
||||
self.model_dir = Path(log_dir) / "checkpoint"
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
if create_model_dir:
|
||||
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||
# self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
|
||||
|
||||
self.use_tf = bool(use_tf)
|
||||
self.tensorboard_dir = self.log_dir / (
|
||||
"tensorboard-{:}".format(time.strftime("%d-%h", time.gmtime(time.time())))
|
||||
)
|
||||
# self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
|
||||
self.logger_path = self.log_dir / "seed-{:}-T-{:}.log".format(
|
||||
self.seed, time.strftime("%d-%h-at-%H-%M-%S", time.gmtime(time.time()))
|
||||
)
|
||||
self.logger_file = open(self.logger_path, "w")
|
||||
|
||||
if self.use_tf:
|
||||
self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
|
||||
self.writer = tf.summary.FileWriter(str(self.tensorboard_dir))
|
||||
else:
|
||||
self.writer = None
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def path(self, mode):
|
||||
valids = ("model", "best", "info", "log", None)
|
||||
if mode is None:
|
||||
return self.log_dir
|
||||
elif mode == "model":
|
||||
return self.model_dir / "seed-{:}-basic.pth".format(self.seed)
|
||||
elif mode == "best":
|
||||
return self.model_dir / "seed-{:}-best.pth".format(self.seed)
|
||||
elif mode == "info":
|
||||
return self.log_dir / "seed-{:}-last-info.pth".format(self.seed)
|
||||
elif mode == "log":
|
||||
return self.log_dir
|
||||
else:
|
||||
raise TypeError("Unknow mode = {:}, valid modes = {:}".format(mode, valids))
|
||||
|
||||
def extract_log(self):
|
||||
return self.logger_file
|
||||
|
||||
def close(self):
|
||||
self.logger_file.close()
|
||||
if self.writer is not None:
|
||||
self.writer.close()
|
||||
|
||||
def log(self, string, save=True, stdout=False):
|
||||
if stdout:
|
||||
sys.stdout.write(string)
|
||||
sys.stdout.flush()
|
||||
else:
|
||||
print(string)
|
||||
if save:
|
||||
self.logger_file.write("{:}\n".format(string))
|
||||
self.logger_file.flush()
|
||||
|
||||
def scalar_summary(self, tags, values, step):
|
||||
"""Log a scalar variable."""
|
||||
if not self.use_tf:
|
||||
warnings.warn("Do set use-tensorflow installed but call scalar_summary")
|
||||
else:
|
||||
assert isinstance(tags, list) == isinstance(
|
||||
values, list
|
||||
), "Type : {:} vs {:}".format(type(tags), type(values))
|
||||
if not isinstance(tags, list):
|
||||
tags, values = [tags], [values]
|
||||
for tag, value in zip(tags, values):
|
||||
summary = tf.Summary(
|
||||
value=[tf.Summary.Value(tag=tag, simple_value=value)]
|
||||
)
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
||||
|
||||
def image_summary(self, tag, images, step):
|
||||
"""Log a list of images."""
|
||||
import scipy
|
||||
|
||||
if not self.use_tf:
|
||||
warnings.warn("Do set use-tensorflow installed but call scalar_summary")
|
||||
return
|
||||
|
||||
img_summaries = []
|
||||
for i, img in enumerate(images):
|
||||
# Write the image to a string
|
||||
try:
|
||||
s = BIO()  # BIO is StringIO on Python 2 and BytesIO on Python 3 (see the imports above)
|
||||
except:
|
||||
s = BIO()
|
||||
scipy.misc.toimage(img).save(s, format="png")
|
||||
|
||||
# Create an Image object
|
||||
img_sum = tf.Summary.Image(
|
||||
encoded_image_string=s.getvalue(),
|
||||
height=img.shape[0],
|
||||
width=img.shape[1],
|
||||
)
|
||||
# Create a Summary value
|
||||
img_summaries.append(
|
||||
tf.Summary.Value(tag="{}/{}".format(tag, i), image=img_sum)
|
||||
)
|
||||
|
||||
# Create and write Summary
|
||||
summary = tf.Summary(value=img_summaries)
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
||||
|
||||
def histo_summary(self, tag, values, step, bins=1000):
|
||||
"""Log a histogram of the tensor of values."""
|
||||
if not self.use_tf:
|
||||
raise ValueError("Do not have tensorflow")
|
||||
import tensorflow as tf
|
||||
|
||||
# Create a histogram using numpy
|
||||
counts, bin_edges = np.histogram(values, bins=bins)
|
||||
|
||||
# Fill the fields of the histogram proto
|
||||
hist = tf.HistogramProto()
|
||||
hist.min = float(np.min(values))
|
||||
hist.max = float(np.max(values))
|
||||
hist.num = int(np.prod(values.shape))
|
||||
hist.sum = float(np.sum(values))
|
||||
hist.sum_squares = float(np.sum(values ** 2))
|
||||
|
||||
# Drop the start of the first bin
|
||||
bin_edges = bin_edges[1:]
|
||||
|
||||
# Add bin edges and counts
|
||||
for edge in bin_edges:
|
||||
hist.bucket_limit.append(edge)
|
||||
for c in counts:
|
||||
hist.bucket.append(c)
|
||||
|
||||
# Create and write Summary
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
||||
@@ -1,120 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0.0
|
||||
self.avg = 0.0
|
||||
self.sum = 0.0
|
||||
self.count = 0.0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(val={val}, avg={avg}, count={count})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
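# Usage sketch (illustration only, not part of the original file).
if __name__ == "__main__":
    losses = AverageMeter()
    for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
        losses.update(batch_loss, batch_size)  # weight each batch by its size
    print(losses)  # avg is the sample-weighted mean: (0.9*32 + 0.7*32 + 0.5*16) / 80 = 0.74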
|
||||
|
||||
|
||||
class RecorderMeter(object):
|
||||
"""Computes and stores the minimum loss value and its epoch index"""
|
||||
|
||||
def __init__(self, total_epoch):
|
||||
self.reset(total_epoch)
|
||||
|
||||
def reset(self, total_epoch):
|
||||
assert total_epoch > 0, "total_epoch should be greater than 0 vs {:}".format(
|
||||
total_epoch
|
||||
)
|
||||
self.total_epoch = total_epoch
|
||||
self.current_epoch = 0
|
||||
self.epoch_losses = np.zeros(
|
||||
(self.total_epoch, 2), dtype=np.float32
|
||||
) # [epoch, train/val]
|
||||
self.epoch_losses = self.epoch_losses - 1
|
||||
self.epoch_accuracy = np.zeros(
|
||||
(self.total_epoch, 2), dtype=np.float32
|
||||
) # [epoch, train/val]
|
||||
self.epoch_accuracy = self.epoch_accuracy
|
||||
|
||||
def update(self, idx, train_loss, train_acc, val_loss, val_acc):
|
||||
assert (
|
||||
idx >= 0 and idx < self.total_epoch
|
||||
), "total_epoch : {} , but update with the {} index".format(
|
||||
self.total_epoch, idx
|
||||
)
|
||||
self.epoch_losses[idx, 0] = train_loss
|
||||
self.epoch_losses[idx, 1] = val_loss
|
||||
self.epoch_accuracy[idx, 0] = train_acc
|
||||
self.epoch_accuracy[idx, 1] = val_acc
|
||||
self.current_epoch = idx + 1
|
||||
return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]
|
||||
|
||||
def max_accuracy(self, istrain):
|
||||
if self.current_epoch <= 0:
|
||||
return 0
|
||||
if istrain:
|
||||
return self.epoch_accuracy[: self.current_epoch, 0].max()
|
||||
else:
|
||||
return self.epoch_accuracy[: self.current_epoch, 1].max()
|
||||
|
||||
def plot_curve(self, save_path):
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
title = "the accuracy/loss curve of train/val"
|
||||
dpi = 100
|
||||
width, height = 1600, 1000
|
||||
legend_fontsize = 10
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
|
||||
y_axis = np.zeros(self.total_epoch)
|
||||
|
||||
plt.xlim(0, self.total_epoch)
|
||||
plt.ylim(0, 100)
|
||||
interval_y = 5
|
||||
interval_x = 5
|
||||
plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
|
||||
plt.yticks(np.arange(0, 100 + interval_y, interval_y))
|
||||
plt.grid()
|
||||
plt.title(title, fontsize=20)
|
||||
plt.xlabel("the training epoch", fontsize=16)
|
||||
plt.ylabel("accuracy", fontsize=16)
|
||||
|
||||
y_axis[:] = self.epoch_accuracy[:, 0]
|
||||
plt.plot(x_axis, y_axis, color="g", linestyle="-", label="train-accuracy", lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
y_axis[:] = self.epoch_accuracy[:, 1]
|
||||
plt.plot(x_axis, y_axis, color="y", linestyle="-", label="valid-accuracy", lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
y_axis[:] = self.epoch_losses[:, 0]
|
||||
plt.plot(
|
||||
x_axis, y_axis * 50, color="g", linestyle=":", label="train-loss-x50", lw=2
|
||||
)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
y_axis[:] = self.epoch_losses[:, 1]
|
||||
plt.plot(
|
||||
x_axis, y_axis * 50, color="y", linestyle=":", label="valid-loss-x50", lw=2
|
||||
)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
if save_path is not None:
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
|
||||
print("---- save figure {} into {}".format(title, save_path))
|
||||
plt.close(fig)
|
||||
@@ -1,21 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def pickle_save(obj, path):
|
||||
file_path = Path(path)
|
||||
file_dir = file_path.parent
|
||||
file_dir.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("wb") as f:
|
||||
pickle.dump(obj, f)
|
||||
|
||||
|
||||
def pickle_load(path):
|
||||
if not Path(path).exists():
|
||||
raise ValueError("{:} does not exists".format(path))
|
||||
with Path(path).open("rb") as f:
|
||||
data = pickle.load(f)
|
||||
return data
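# Usage sketch (illustration only, not part of the original file).
if __name__ == "__main__":
    import os, tempfile
    demo_path = os.path.join(tempfile.mkdtemp(), "sub-dir", "demo.pkl")
    pickle_save({"acc": 0.93, "epoch": 10}, demo_path)  # parent directories are created as needed
    print(pickle_load(demo_path))  # {'acc': 0.93, 'epoch': 10}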
|
||||
@@ -1,49 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import time, sys
|
||||
import numpy as np
|
||||
|
||||
|
||||
def time_for_file():
|
||||
ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S"
|
||||
return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
|
||||
|
||||
|
||||
def time_string():
|
||||
ISOTIMEFORMAT = "%Y-%m-%d %X"
|
||||
string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
|
||||
return string
|
||||
|
||||
|
||||
def time_string_short():
|
||||
ISOTIMEFORMAT = "%Y%m%d"
|
||||
string = "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
|
||||
return string
|
||||
|
||||
|
||||
def time_print(string, is_print=True):
|
||||
if is_print:
|
||||
print("{} : {}".format(time_string(), string))
|
||||
|
||||
|
||||
def convert_secs2time(epoch_time, return_str=False):
|
||||
need_hour = int(epoch_time / 3600)
|
||||
need_mins = int((epoch_time - 3600 * need_hour) / 60)
|
||||
need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
|
||||
if return_str:
|
||||
str = "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs)
|
||||
return time_str
|
||||
else:
|
||||
return need_hour, need_mins, need_secs
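# Usage sketch (illustration only, not part of the original file).
if __name__ == "__main__":
    print(convert_secs2time(3750))                   # (1, 2, 30)
    print(convert_secs2time(3750, return_str=True))  # "[01:02:30]"
    time_print("an example message")                 # prefixed with the current time string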
|
||||
|
||||
|
||||
def print_log(print_string, log):
|
||||
# if isinstance(log, Logger): log.log('{:}'.format(print_string))
|
||||
if hasattr(log, "log"):
|
||||
log.log("{:}".format(print_string))
|
||||
else:
|
||||
print("{:}".format(print_string))
|
||||
if log is not None:
|
||||
log.write("{:}\n".format(print_string))
|
||||
log.flush()
|
||||
@@ -1,117 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .initialization import initialize_resnet
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(self, nChannels, growthRate):
|
||||
super(Bottleneck, self).__init__()
|
||||
interChannels = 4 * growthRate
|
||||
self.bn1 = nn.BatchNorm2d(nChannels)
|
||||
self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(interChannels)
|
||||
self.conv2 = nn.Conv2d(
|
||||
interChannels, growthRate, kernel_size=3, padding=1, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(F.relu(self.bn1(x)))
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out = torch.cat((x, out), 1)
|
||||
return out
|
||||
|
||||
|
||||
class SingleLayer(nn.Module):
|
||||
def __init__(self, nChannels, growthRate):
|
||||
super(SingleLayer, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(nChannels)
|
||||
self.conv1 = nn.Conv2d(
|
||||
nChannels, growthRate, kernel_size=3, padding=1, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(F.relu(self.bn1(x)))
|
||||
out = torch.cat((x, out), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Transition(nn.Module):
|
||||
def __init__(self, nChannels, nOutChannels):
|
||||
super(Transition, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(nChannels)
|
||||
self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(F.relu(self.bn1(x)))
|
||||
out = F.avg_pool2d(out, 2)
|
||||
return out
|
||||
|
||||
|
||||
class DenseNet(nn.Module):
|
||||
def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
|
||||
super(DenseNet, self).__init__()
|
||||
|
||||
if bottleneck:
|
||||
nDenseBlocks = int((depth - 4) / 6)
|
||||
else:
|
||||
nDenseBlocks = int((depth - 4) / 3)
|
||||
|
||||
self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format(
|
||||
"bottleneck" if bottleneck else "basic",
|
||||
depth,
|
||||
reduction,
|
||||
growthRate,
|
||||
nClasses,
|
||||
)
|
||||
|
||||
nChannels = 2 * growthRate
|
||||
self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
|
||||
nChannels += nDenseBlocks * growthRate
|
||||
nOutChannels = int(math.floor(nChannels * reduction))
|
||||
self.trans1 = Transition(nChannels, nOutChannels)
|
||||
|
||||
nChannels = nOutChannels
|
||||
self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
|
||||
nChannels += nDenseBlocks * growthRate
|
||||
nOutChannels = int(math.floor(nChannels * reduction))
|
||||
self.trans2 = Transition(nChannels, nOutChannels)
|
||||
|
||||
nChannels = nOutChannels
|
||||
self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
|
||||
nChannels += nDenseBlocks * growthRate
|
||||
|
||||
self.act = nn.Sequential(
|
||||
nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8)
|
||||
)
|
||||
self.fc = nn.Linear(nChannels, nClasses)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
|
||||
layers = []
|
||||
for i in range(int(nDenseBlocks)):
|
||||
if bottleneck:
|
||||
layers.append(Bottleneck(nChannels, growthRate))
|
||||
else:
|
||||
layers.append(SingleLayer(nChannels, growthRate))
|
||||
nChannels += growthRate
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv1(inputs)
|
||||
out = self.trans1(self.dense1(out))
|
||||
out = self.trans2(self.dense2(out))
|
||||
out = self.dense3(out)
|
||||
features = self.act(out)
|
||||
features = features.view(features.size(0), -1)
|
||||
out = self.fc(features)
|
||||
return features, out
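# Usage sketch (illustration only, not part of the original file). It needs the package context
# for the relative import of initialize_resnet above; the shapes shown are for DenseNet-BC-100.
if __name__ == "__main__":
    net = DenseNet(growthRate=12, depth=100, reduction=0.5, nClasses=10, bottleneck=True)
    print(net.get_message())
    features, logits = net(torch.randn(2, 3, 32, 32))
    print(features.shape, logits.shape)  # [2, 342], [2, 10]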
|
||||
@@ -1,180 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .initialization import initialize_resnet
|
||||
from .SharedUtils import additive_func
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, nIn, nOut, stride):
|
||||
super(Downsample, self).__init__()
|
||||
assert stride == 2 and nOut == 2 * nIn, "stride:{} IO:{},{}".format(
|
||||
stride, nIn, nOut
|
||||
)
|
||||
self.in_dim = nIn
|
||||
self.out_dim = nOut
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.avg(x)
|
||||
out = self.conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias
|
||||
)
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
if relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
self.out_dim = nOut
|
||||
self.num_conv = 1
|
||||
|
||||
def forward(self, x):
|
||||
conv = self.conv(x)
|
||||
bn = self.bn(conv)
|
||||
if self.relu:
|
||||
return self.relu(bn)
|
||||
else:
|
||||
return bn
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True)
|
||||
self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False)
|
||||
if stride == 2:
|
||||
self.downsample = Downsample(inplanes, planes, stride)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
self.num_conv = 2
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, basicblock)
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True)
|
||||
self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes, planes * self.expansion, 1, 1, 0, False, False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = Downsample(inplanes, planes * self.expansion, stride)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes, planes * self.expansion, 1, 1, 0, False, False
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
self.num_conv = 3
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, bottleneck)
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class CifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, num_classes, zero_init_residual):
|
||||
super(CifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be 9n+2, e.g., 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = "CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}".format(
|
||||
block_name, depth, layer_blocks
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.channels = [16]
|
||||
self.layers = nn.ModuleList([ConvBNReLU(3, 16, 3, 1, 1, False, True)])
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iC,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(module.out_dim, num_classes)
|
||||
assert (
|
||||
sum(x.num_conv for x in self.layers) + 1 == depth
|
||||
), "invalid depth check {:} vs {:}".format(
|
||||
sum(x.num_conv for x in self.layers) + 1, depth
|
||||
)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
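# Usage sketch (illustration only, not part of the original file). It needs the package context
# for the relative imports above; depth must be 6n+2 for basic blocks (e.g., 20 for ResNet-20).
if __name__ == "__main__":
    net = CifarResNet("ResNetBasicblock", depth=20, num_classes=10, zero_init_residual=False)
    features, logits = net(torch.randn(2, 3, 32, 32))
    print(features.shape, logits.shape)  # [2, 64], [2, 10]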
|
||||
@@ -1,115 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .initialization import initialize_resnet
|
||||
|
||||
|
||||
class WideBasicblock(nn.Module):
|
||||
def __init__(self, inplanes, planes, stride, dropout=False):
|
||||
super(WideBasicblock, self).__init__()
|
||||
|
||||
self.bn_a = nn.BatchNorm2d(inplanes)
|
||||
self.conv_a = nn.Conv2d(
|
||||
inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False
|
||||
)
|
||||
|
||||
self.bn_b = nn.BatchNorm2d(planes)
|
||||
if dropout:
|
||||
self.dropout = nn.Dropout2d(p=0.5, inplace=True)
|
||||
else:
|
||||
self.dropout = None
|
||||
self.conv_b = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, stride=1, padding=1, bias=False
|
||||
)
|
||||
|
||||
if inplanes != planes:
|
||||
self.downsample = nn.Conv2d(
|
||||
inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
basicblock = self.bn_a(x)
|
||||
basicblock = F.relu(basicblock)
|
||||
basicblock = self.conv_a(basicblock)
|
||||
|
||||
basicblock = self.bn_b(basicblock)
|
||||
basicblock = F.relu(basicblock)
|
||||
if self.dropout is not None:
|
||||
basicblock = self.dropout(basicblock)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return x + basicblock
|
||||
|
||||
|
||||
class CifarWideResNet(nn.Module):
|
||||
"""
|
||||
Wide ResNet optimized for the CIFAR dataset; the residual block design follows
|
||||
https://arxiv.org/abs/1512.03385.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, depth, widen_factor, num_classes, dropout):
|
||||
super(CifarWideResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
assert (depth - 4) % 6 == 0, "depth should be 6n+4, e.g., 16, 22, 28, 40"
|
||||
layer_blocks = (depth - 4) // 6
|
||||
print(
|
||||
"CifarPreResNet : Depth : {} , Layers for each block : {}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.dropout = dropout
|
||||
self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
|
||||
self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format(
|
||||
depth, widen_factor, num_classes
|
||||
)
|
||||
self.inplanes = 16
|
||||
self.stage_1 = self._make_layer(
|
||||
WideBasicblock, 16 * widen_factor, layer_blocks, 1
|
||||
)
|
||||
self.stage_2 = self._make_layer(
|
||||
WideBasicblock, 32 * widen_factor, layer_blocks, 2
|
||||
)
|
||||
self.stage_3 = self._make_layer(
|
||||
WideBasicblock, 64 * widen_factor, layer_blocks, 2
|
||||
)
|
||||
self.lastact = nn.Sequential(
|
||||
nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True)
|
||||
)
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(64 * widen_factor, num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride):
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, self.dropout))
|
||||
self.inplanes = planes
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, 1, self.dropout))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_3x3(x)
|
||||
x = self.stage_1(x)
|
||||
x = self.stage_2(x)
|
||||
x = self.stage_3(x)
|
||||
x = self.lastact(x)
|
||||
x = self.avgpool(x)
|
||||
features = x.view(x.size(0), -1)
|
||||
outs = self.classifier(features)
|
||||
return features, outs
|
||||
@@ -1,117 +0,0 @@
|
||||
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
|
||||
from torch import nn
|
||||
from .initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.conv = nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
)
|
||||
self.bn = nn.BatchNorm2d(out_planes)
|
||||
self.relu = nn.ReLU6(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
out = self.bn(out)
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend(
|
||||
[
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
]
|
||||
)
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(
|
||||
self, num_classes, width_mult, input_channel, last_channel, block_name, dropout
|
||||
):
|
||||
super(MobileNetV2, self).__init__()
|
||||
if block_name == "InvertedResidual":
|
||||
block = InvertedResidual
|
||||
else:
|
||||
raise ValueError("invalid block name : {:}".format(block_name))
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
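# Illustrative note (not part of the original file): each row above is
# (expansion t, output channels c, repeat count n, first-block stride s).
# With width_mult = 0.5, for example, the row [6, 32, 3, 2] builds three
# InvertedResidual blocks whose output width is int(32 * 0.5) = 16, the
# first with stride 2 and the remaining two with stride 1, matching the
# construction loop below.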
|
||||
|
||||
# building first layer
|
||||
input_channel = int(input_channel * width_mult)
|
||||
self.last_channel = int(last_channel * max(1.0, width_mult))
|
||||
features = [ConvBNReLU(3, input_channel, stride=2)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = int(c * width_mult)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(
|
||||
block(input_channel, output_channel, stride, expand_ratio=t)
|
||||
)
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(self.last_channel, num_classes),
|
||||
)
|
||||
self.message = "MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}".format(
|
||||
width_mult, input_channel, last_channel, block_name, dropout
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
features = self.features(inputs)
|
||||
vectors = features.mean([2, 3])
|
||||
predicts = self.classifier(vectors)
|
||||
return features, predicts
|
||||
@@ -1,217 +0,0 @@
|
||||
# Deep Residual Learning for Image Recognition, CVPR 2016
|
||||
import torch.nn as nn
|
||||
from .initialization import initialize_resnet
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1):
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(
|
||||
self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64
|
||||
):
|
||||
super(BasicBlock, self).__init__()
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(
|
||||
self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64
|
||||
):
|
||||
super(Bottleneck, self).__init__()
|
||||
width = int(planes * (base_width / 64.0)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = nn.BatchNorm2d(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups)
|
||||
self.bn2 = nn.BatchNorm2d(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block_name,
|
||||
layers,
|
||||
deep_stem,
|
||||
num_classes,
|
||||
zero_init_residual,
|
||||
groups,
|
||||
width_per_group,
|
||||
):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
# planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
|
||||
if block_name == "BasicBlock":
|
||||
block = BasicBlock
|
||||
elif block_name == "Bottleneck":
|
||||
block = Bottleneck
|
||||
else:
|
||||
raise ValueError("invalid block-name : {:}".format(block_name))
|
||||
|
||||
if not deep_stem:
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.inplanes = 64
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(
|
||||
block, 64, layers[0], stride=1, groups=groups, base_width=width_per_group
|
||||
)
|
||||
self.layer2 = self._make_layer(
|
||||
block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group
|
||||
)
|
||||
self.layer3 = self._make_layer(
|
||||
block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group
|
||||
)
|
||||
self.layer4 = self._make_layer(
|
||||
block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group
|
||||
)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
self.message = (
|
||||
"block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}".format(
|
||||
block, layers, deep_stem, num_classes
|
||||
)
|
||||
)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride, groups, base_width):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
if stride == 2:
|
||||
downsample = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
|
||||
conv1x1(self.inplanes, planes * block.expansion, 1),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
elif stride == 1:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid stride [{:}] for downsample".format(stride))
|
||||
|
||||
layers = []
|
||||
layers.append(
|
||||
block(self.inplanes, planes, stride, downsample, groups, base_width)
|
||||
)
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, 1, None, groups, base_width))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.fc(features)
|
||||
|
||||
return features, logits
|
||||
@@ -1,37 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def additive_func(A, B):
|
||||
assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format(
|
||||
A.size(), B.size()
|
||||
)
|
||||
C = min(A.size(1), B.size(1))
|
||||
if A.size(1) == B.size(1):
|
||||
return A + B
|
||||
elif A.size(1) < B.size(1):
|
||||
out = B.clone()
|
||||
out[:, :C] += A
|
||||
return out
|
||||
else:
|
||||
out = A.clone()
|
||||
out[:, :C] += B
|
||||
return out
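# A minimal usage sketch (illustrative, not part of the original file):
# additive_func adds two feature tensors whose channel counts may differ;
# the extra channels of the wider tensor pass through unchanged.
#
#   import torch
#   A = torch.randn(4, 16, 8, 8)
#   B = torch.randn(4, 24, 8, 8)
#   C = additive_func(A, B)  # shape (4, 24, 8, 8); the first 16 channels equal A + B[:, :16]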
|
||||
|
||||
|
||||
def change_key(key, value):
|
||||
def func(m):
|
||||
if hasattr(m, key):
|
||||
setattr(m, key, value)
|
||||
|
||||
return func
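# Illustrative usage (the attribute name here is a hypothetical example, not
# from the original file): change_key returns a callable suitable for
# nn.Module.apply, which sets an attribute on every sub-module that already
# defines it, e.g.
#
#   net.apply(change_key("search_mode", "urs"))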
|
||||
|
||||
|
||||
def parse_channel_info(xstring):
|
||||
blocks = xstring.split(" ")
|
||||
blocks = [x.split("-") for x in blocks]
|
||||
blocks = [[int(_) for _ in x] for x in blocks]
|
||||
return blocks
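# Example (illustrative): parse_channel_info("16-16 32-32-64") returns
# [[16, 16], [32, 32, 64]]; blocks are separated by spaces, channels by dashes.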
|
||||
@@ -1,326 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from os import path as osp
|
||||
from typing import List, Text
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"change_key",
|
||||
"get_cell_based_tiny_net",
|
||||
"get_search_spaces",
|
||||
"get_cifar_models",
|
||||
"get_imagenet_models",
|
||||
"obtain_model",
|
||||
"obtain_search_model",
|
||||
"load_net_from_checkpoint",
|
||||
"CellStructure",
|
||||
"CellArchitectures",
|
||||
]
|
||||
|
||||
# useful modules
|
||||
from config_utils import dict2config
|
||||
from models.SharedUtils import change_key
|
||||
from models.cell_searchs import CellStructure, CellArchitectures
|
||||
|
||||
|
||||
# Cell-based NAS Models
|
||||
def get_cell_based_tiny_net(config):
|
||||
if isinstance(config, dict):
|
||||
config = dict2config(config, None) # to support the argument being a dict
|
||||
super_type = getattr(config, "super_type", "basic")
|
||||
group_names = ["DARTS-V1", "DARTS-V2", "GDAS", "SETN", "ENAS", "RANDOM", "generic"]
|
||||
if super_type == "basic" and config.name in group_names:
|
||||
from .cell_searchs import nas201_super_nets as nas_super_nets
|
||||
|
||||
try:
|
||||
return nas_super_nets[config.name](
|
||||
config.C,
|
||||
config.N,
|
||||
config.max_nodes,
|
||||
config.num_classes,
|
||||
config.space,
|
||||
config.affine,
|
||||
config.track_running_stats,
|
||||
)
|
||||
except Exception:  # fall back to the older constructor signature without affine/track_running_stats
|
||||
return nas_super_nets[config.name](
|
||||
config.C, config.N, config.max_nodes, config.num_classes, config.space
|
||||
)
|
||||
elif super_type == "search-shape":
|
||||
from .shape_searchs import GenericNAS301Model
|
||||
|
||||
genotype = CellStructure.str2structure(config.genotype)
|
||||
return GenericNAS301Model(
|
||||
config.candidate_Cs,
|
||||
config.max_num_Cs,
|
||||
genotype,
|
||||
config.num_classes,
|
||||
config.affine,
|
||||
config.track_running_stats,
|
||||
)
|
||||
elif super_type == "nasnet-super":
|
||||
from .cell_searchs import nasnet_super_nets as nas_super_nets
|
||||
|
||||
return nas_super_nets[config.name](
|
||||
config.C,
|
||||
config.N,
|
||||
config.steps,
|
||||
config.multiplier,
|
||||
config.stem_multiplier,
|
||||
config.num_classes,
|
||||
config.space,
|
||||
config.affine,
|
||||
config.track_running_stats,
|
||||
)
|
||||
elif config.name == "infer.tiny":
|
||||
from .cell_infers import TinyNetwork
|
||||
|
||||
if hasattr(config, "genotype"):
|
||||
genotype = config.genotype
|
||||
elif hasattr(config, "arch_str"):
|
||||
genotype = CellStructure.str2structure(config.arch_str)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Can not find genotype from this config : {:}".format(config)
|
||||
)
|
||||
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
|
||||
elif config.name == "infer.shape.tiny":
|
||||
from .shape_infers import DynamicShapeTinyNet
|
||||
|
||||
if isinstance(config.channels, str):
|
||||
channels = tuple([int(x) for x in config.channels.split(":")])
|
||||
else:
|
||||
channels = config.channels
|
||||
genotype = CellStructure.str2structure(config.genotype)
|
||||
return DynamicShapeTinyNet(channels, genotype, config.num_classes)
|
||||
elif config.name == "infer.nasnet-cifar":
|
||||
from .cell_infers import NASNetonCIFAR
|
||||
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError("invalid network name : {:}".format(config.name))
|
||||
|
||||
|
||||
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
|
||||
def get_search_spaces(xtype, name) -> List[Text]:
|
||||
if xtype == "cell" or xtype == "tss": # The topology search space.
|
||||
from .cell_operations import SearchSpaceNames
|
||||
|
||||
assert name in SearchSpaceNames, "invalid name [{:}] in {:}".format(
|
||||
name, SearchSpaceNames.keys()
|
||||
)
|
||||
return SearchSpaceNames[name]
|
||||
elif xtype == "sss": # The size search space.
|
||||
if name in ["nats-bench", "nats-bench-size"]:
|
||||
return {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5}
|
||||
else:
|
||||
raise ValueError("Invalid name : {:}".format(name))
|
||||
else:
|
||||
raise ValueError("invalid search-space type is {:}".format(xtype))
|
||||
|
||||
|
||||
def get_cifar_models(config, extra_path=None):
|
||||
super_type = getattr(config, "super_type", "basic")
|
||||
if super_type == "basic":
|
||||
from .CifarResNet import CifarResNet
|
||||
from .CifarDenseNet import DenseNet
|
||||
from .CifarWideResNet import CifarWideResNet
|
||||
|
||||
if config.arch == "resnet":
|
||||
return CifarResNet(
|
||||
config.module, config.depth, config.class_num, config.zero_init_residual
|
||||
)
|
||||
elif config.arch == "densenet":
|
||||
return DenseNet(
|
||||
config.growthRate,
|
||||
config.depth,
|
||||
config.reduction,
|
||||
config.class_num,
|
||||
config.bottleneck,
|
||||
)
|
||||
elif config.arch == "wideresnet":
|
||||
return CifarWideResNet(
|
||||
config.depth, config.wide_factor, config.class_num, config.dropout
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid module type : {:}".format(config.arch))
|
||||
elif super_type.startswith("infer"):
|
||||
from .shape_infers import InferWidthCifarResNet
|
||||
from .shape_infers import InferDepthCifarResNet
|
||||
from .shape_infers import InferCifarResNet
|
||||
from .cell_infers import NASNetonCIFAR
|
||||
|
||||
assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format(
|
||||
super_type
|
||||
)
|
||||
infer_mode = super_type.split("-")[1]
|
||||
if infer_mode == "width":
|
||||
return InferWidthCifarResNet(
|
||||
config.module,
|
||||
config.depth,
|
||||
config.xchannels,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
)
|
||||
elif infer_mode == "depth":
|
||||
return InferDepthCifarResNet(
|
||||
config.module,
|
||||
config.depth,
|
||||
config.xblocks,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
)
|
||||
elif infer_mode == "shape":
|
||||
return InferCifarResNet(
|
||||
config.module,
|
||||
config.depth,
|
||||
config.xblocks,
|
||||
config.xchannels,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
)
|
||||
elif infer_mode == "nasnet.cifar":
|
||||
genotype = config.genotype
|
||||
if extra_path is not None: # reload genotype by extra_path
|
||||
if not osp.isfile(extra_path):
|
||||
raise ValueError("invalid extra_path : {:}".format(extra_path))
|
||||
xdata = torch.load(extra_path)
|
||||
current_epoch = xdata["epoch"]
|
||||
genotype = xdata["genotypes"][current_epoch - 1]
|
||||
C = config.C if hasattr(config, "C") else config.ichannel
|
||||
N = config.N if hasattr(config, "N") else config.layers
|
||||
return NASNetonCIFAR(
|
||||
C, N, config.stem_multi, config.class_num, genotype, config.auxiliary
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid infer-mode : {:}".format(infer_mode))
|
||||
else:
|
||||
raise ValueError("invalid super-type : {:}".format(super_type))
|
||||
|
||||
|
||||
def get_imagenet_models(config):
|
||||
super_type = getattr(config, "super_type", "basic")
|
||||
if super_type == "basic":
|
||||
from .ImageNet_ResNet import ResNet
|
||||
from .ImageNet_MobileNetV2 import MobileNetV2
|
||||
|
||||
if config.arch == "resnet":
|
||||
return ResNet(
|
||||
config.block_name,
|
||||
config.layers,
|
||||
config.deep_stem,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
config.groups,
|
||||
config.width_per_group,
|
||||
)
|
||||
elif config.arch == "mobilenet_v2":
|
||||
return MobileNetV2(
|
||||
config.class_num,
|
||||
config.width_multi,
|
||||
config.input_channel,
|
||||
config.last_channel,
|
||||
"InvertedResidual",
|
||||
config.dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid arch : {:}".format(config.arch))
|
||||
elif super_type.startswith("infer"): # NAS searched architecture
|
||||
assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format(
|
||||
super_type
|
||||
)
|
||||
infer_mode = super_type.split("-")[1]
|
||||
if infer_mode == "shape":
|
||||
from .shape_infers import InferImagenetResNet
|
||||
from .shape_infers import InferMobileNetV2
|
||||
|
||||
if config.arch == "resnet":
|
||||
return InferImagenetResNet(
|
||||
config.block_name,
|
||||
config.layers,
|
||||
config.xblocks,
|
||||
config.xchannels,
|
||||
config.deep_stem,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
)
|
||||
elif config.arch == "MobileNetV2":
|
||||
return InferMobileNetV2(
|
||||
config.class_num, config.xchannels, config.xblocks, config.dropout
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid arch-mode : {:}".format(config.arch))
|
||||
else:
|
||||
raise ValueError("invalid infer-mode : {:}".format(infer_mode))
|
||||
else:
|
||||
raise ValueError("invalid super-type : {:}".format(super_type))
|
||||
|
||||
|
||||
# Try to obtain the network by config.
|
||||
def obtain_model(config, extra_path=None):
|
||||
if config.dataset == "cifar":
|
||||
return get_cifar_models(config, extra_path)
|
||||
elif config.dataset == "imagenet":
|
||||
return get_imagenet_models(config)
|
||||
else:
|
||||
raise ValueError("invalid dataset in the model config : {:}".format(config))
|
||||
|
||||
|
||||
def obtain_search_model(config):
|
||||
if config.dataset == "cifar":
|
||||
if config.arch == "resnet":
|
||||
from .shape_searchs import SearchWidthCifarResNet
|
||||
from .shape_searchs import SearchDepthCifarResNet
|
||||
from .shape_searchs import SearchShapeCifarResNet
|
||||
|
||||
if config.search_mode == "width":
|
||||
return SearchWidthCifarResNet(
|
||||
config.module, config.depth, config.class_num
|
||||
)
|
||||
elif config.search_mode == "depth":
|
||||
return SearchDepthCifarResNet(
|
||||
config.module, config.depth, config.class_num
|
||||
)
|
||||
elif config.search_mode == "shape":
|
||||
return SearchShapeCifarResNet(
|
||||
config.module, config.depth, config.class_num
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid search mode : {:}".format(config.search_mode))
|
||||
elif config.arch == "simres":
|
||||
from .shape_searchs import SearchWidthSimResNet
|
||||
|
||||
if config.search_mode == "width":
|
||||
return SearchWidthSimResNet(config.depth, config.class_num)
|
||||
else:
|
||||
raise ValueError("invalid search mode : {:}".format(config.search_mode))
|
||||
else:
|
||||
raise ValueError(
|
||||
"invalid arch : {:} for dataset [{:}]".format(
|
||||
config.arch, config.dataset
|
||||
)
|
||||
)
|
||||
elif config.dataset == "imagenet":
|
||||
from .shape_searchs import SearchShapeImagenetResNet
|
||||
|
||||
assert config.search_mode == "shape", "invalid search-mode : {:}".format(
|
||||
config.search_mode
|
||||
)
|
||||
if config.arch == "resnet":
|
||||
return SearchShapeImagenetResNet(
|
||||
config.block_name, config.layers, config.deep_stem, config.class_num
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid model config : {:}".format(config))
|
||||
else:
|
||||
raise ValueError("invalid dataset in the model config : {:}".format(config))
|
||||
|
||||
|
||||
def load_net_from_checkpoint(checkpoint):
|
||||
assert osp.isfile(checkpoint), "checkpoint {:} does not exist".format(checkpoint)
|
||||
checkpoint = torch.load(checkpoint)
|
||||
model_config = dict2config(checkpoint["model-config"], None)
|
||||
model = obtain_model(model_config)
|
||||
model.load_state_dict(checkpoint["base-model"])
|
||||
return model
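# Illustrative usage (the checkpoint path is hypothetical); the checkpoint is
# expected to store both "model-config" and "base-model" entries as above:
#
#   net = load_net_from_checkpoint("./output/checkpoints/seed-777-best.pth")
#   net = net.cuda()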
|
||||
@@ -1,5 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork
from .nasnet_cifar import NASNetonCIFAR
@@ -1,155 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
|
||||
from models.cell_operations import OPS
|
||||
|
||||
|
||||
# Cell for NAS-Bench-201
|
||||
class InferCell(nn.Module):
|
||||
def __init__(
|
||||
self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True
|
||||
):
|
||||
super(InferCell, self).__init__()
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
self.node_IN = []
|
||||
self.node_IX = []
|
||||
self.genotype = deepcopy(genotype)
|
||||
for i in range(1, len(genotype)):
|
||||
node_info = genotype[i - 1]
|
||||
cur_index = []
|
||||
cur_innod = []
|
||||
for (op_name, op_in) in node_info:
|
||||
if op_in == 0:
|
||||
layer = OPS[op_name](
|
||||
C_in, C_out, stride, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats)
|
||||
cur_index.append(len(self.layers))
|
||||
cur_innod.append(op_in)
|
||||
self.layers.append(layer)
|
||||
self.node_IX.append(cur_index)
|
||||
self.node_IN.append(cur_innod)
|
||||
self.nodes = len(genotype)
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
|
||||
def extra_repr(self):
|
||||
string = "info :: nodes={nodes}, inC={in_dim}, outC={out_dim}".format(
|
||||
**self.__dict__
|
||||
)
|
||||
laystr = []
|
||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
|
||||
y = [
|
||||
"I{:}-L{:}".format(_ii, _il)
|
||||
for _il, _ii in zip(node_layers, node_innods)
|
||||
]
|
||||
x = "{:}<-({:})".format(i + 1, ",".join(y))
|
||||
laystr.append(x)
|
||||
return (
|
||||
string
|
||||
+ ", [{:}]".format(" | ".join(laystr))
|
||||
+ ", {:}".format(self.genotype.tostr())
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
nodes = [inputs]
|
||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
|
||||
node_feature = sum(
|
||||
self.layers[_il](nodes[_ii])
|
||||
for _il, _ii in zip(node_layers, node_innods)
|
||||
)
|
||||
nodes.append(node_feature)
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
class NASNetInferCell(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
genotype,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetInferCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
if reduction_prev:
|
||||
self.preprocess0 = OPS["skip_connect"](
|
||||
C_prev_prev, C, 2, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
self.preprocess0 = OPS["nor_conv_1x1"](
|
||||
C_prev_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
self.preprocess1 = OPS["nor_conv_1x1"](
|
||||
C_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
|
||||
if not reduction:
|
||||
nodes, concats = genotype["normal"], genotype["normal_concat"]
|
||||
else:
|
||||
nodes, concats = genotype["reduce"], genotype["reduce_concat"]
|
||||
self._multiplier = len(concats)
|
||||
self._concats = concats
|
||||
self._steps = len(nodes)
|
||||
self._nodes = nodes
|
||||
self.edges = nn.ModuleDict()
|
||||
for i, node in enumerate(nodes):
|
||||
for in_node in node:
|
||||
name, j = in_node[0], in_node[1]
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
node_str = "{:}<-{:}".format(i + 2, j)
|
||||
self.edges[node_str] = OPS[name](
|
||||
C, C, stride, affine, track_running_stats
|
||||
)
|
||||
|
||||
# [TODO] to support drop_prob in this function..
|
||||
def forward(self, s0, s1, unused_drop_prob):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i, node in enumerate(self._nodes):
|
||||
clist = []
|
||||
for in_node in node:
|
||||
name, j = in_node[0], in_node[1]
|
||||
node_str = "{:}<-{:}".format(i + 2, j)
|
||||
op = self.edges[node_str]
|
||||
clist.append(op(states[j]))
|
||||
states.append(sum(clist))
|
||||
return torch.cat([states[x] for x in self._concats], dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHeadCIFAR(nn.Module):
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 8x8"""
|
||||
super(AuxiliaryHeadCIFAR, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(
|
||||
5, stride=3, padding=0, count_include_pad=False
|
||||
), # image size = 2 x 2
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0), -1))
|
||||
return x
|
||||
@@ -1,117 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetonCIFAR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C,
|
||||
N,
|
||||
stem_multiplier,
|
||||
num_classes,
|
||||
genotype,
|
||||
auxiliary,
|
||||
affine=True,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(NASNetonCIFAR, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
self.auxiliary_index = None
|
||||
self.auxiliary_head = None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = InferCell(
|
||||
genotype,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = (
|
||||
C_prev,
|
||||
cell._multiplier * C_curr,
|
||||
reduction,
|
||||
)
|
||||
if reduction and C_curr == C * 4 and auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
|
||||
self.auxiliary_index = index
|
||||
self._Layer = len(self.cells)
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.drop_path_prob = -1
|
||||
|
||||
def update_drop_path(self, drop_path_prob):
|
||||
self.drop_path_prob = drop_path_prob
|
||||
|
||||
def auxiliary_param(self):
|
||||
if self.auxiliary_head is None:
|
||||
return []
|
||||
else:
|
||||
return list(self.auxiliary_head.parameters())
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
stem_feature, logits_aux = self.stem(inputs), None
|
||||
cell_results = [stem_feature, stem_feature]
|
||||
for i, cell in enumerate(self.cells):
|
||||
cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
|
||||
cell_results.append(cell_feature)
|
||||
if (
|
||||
self.auxiliary_index is not None
|
||||
and i == self.auxiliary_index
|
||||
and self.training
|
||||
):
|
||||
logits_aux = self.auxiliary_head(cell_results[-1])
|
||||
out = self.lastact(cell_results[-1])
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
if logits_aux is None:
|
||||
return out, logits
|
||||
else:
|
||||
return out, [logits, logits_aux]
|
||||
@@ -1,63 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .cells import InferCell
|
||||
|
||||
|
||||
# The macro structure for architectures in NAS-Bench-201
|
||||
class TinyNetwork(nn.Module):
|
||||
def __init__(self, C, N, genotype, num_classes):
|
||||
super(TinyNetwork, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
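# Note (illustrative, not in the original file): the two lists above describe the
# NAS-Bench-201 macro skeleton, i.e. N searched cells, a stride-2 residual block,
# N cells, another stride-2 residual block, and N final cells, giving 3 * N + 2
# entries in total, which is also the length of self.cells built below.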
|
||||
|
||||
C_prev = C
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
|
||||
else:
|
||||
cell = InferCell(genotype, C_prev, C_curr, 1)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self._Layer = len(self.cells)
|
||||
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -1,553 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ["OPS", "RAW_OP_CLASSES", "ResNetBasicblock", "SearchSpaceNames"]
|
||||
|
||||
OPS = {
|
||||
"none": lambda C_in, C_out, stride, affine, track_running_stats: Zero(
|
||||
C_in, C_out, stride
|
||||
),
|
||||
"avg_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING(
|
||||
C_in, C_out, stride, "avg", affine, track_running_stats
|
||||
),
|
||||
"max_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING(
|
||||
C_in, C_out, stride, "max", affine, track_running_stats
|
||||
),
|
||||
"nor_conv_7x7": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
|
||||
C_in,
|
||||
C_out,
|
||||
(7, 7),
|
||||
(stride, stride),
|
||||
(3, 3),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"nor_conv_3x3": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
|
||||
C_in,
|
||||
C_out,
|
||||
(3, 3),
|
||||
(stride, stride),
|
||||
(1, 1),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"nor_conv_1x1": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
|
||||
C_in,
|
||||
C_out,
|
||||
(1, 1),
|
||||
(stride, stride),
|
||||
(0, 0),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"dua_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(
|
||||
C_in,
|
||||
C_out,
|
||||
(3, 3),
|
||||
(stride, stride),
|
||||
(1, 1),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"dua_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(
|
||||
C_in,
|
||||
C_out,
|
||||
(5, 5),
|
||||
(stride, stride),
|
||||
(2, 2),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"dil_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: SepConv(
|
||||
C_in,
|
||||
C_out,
|
||||
(3, 3),
|
||||
(stride, stride),
|
||||
(2, 2),
|
||||
(2, 2),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"dil_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: SepConv(
|
||||
C_in,
|
||||
C_out,
|
||||
(5, 5),
|
||||
(stride, stride),
|
||||
(4, 4),
|
||||
(2, 2),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"skip_connect": lambda C_in, C_out, stride, affine, track_running_stats: Identity()
|
||||
if stride == 1 and C_in == C_out
|
||||
else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
|
||||
}
|
||||
|
||||
CONNECT_NAS_BENCHMARK = ["none", "skip_connect", "nor_conv_3x3"]
|
||||
NAS_BENCH_201 = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
|
||||
DARTS_SPACE = [
|
||||
"none",
|
||||
"skip_connect",
|
||||
"dua_sepc_3x3",
|
||||
"dua_sepc_5x5",
|
||||
"dil_sepc_3x3",
|
||||
"dil_sepc_5x5",
|
||||
"avg_pool_3x3",
|
||||
"max_pool_3x3",
|
||||
]
|
||||
|
||||
SearchSpaceNames = {
|
||||
"connect-nas": CONNECT_NAS_BENCHMARK,
|
||||
"nats-bench": NAS_BENCH_201,
|
||||
"nas-bench-201": NAS_BENCH_201,
|
||||
"darts": DARTS_SPACE,
|
||||
}
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C_in,
|
||||
C_out,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
affine,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in,
|
||||
C_out,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.BatchNorm2d(
|
||||
C_out, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C_in,
|
||||
C_out,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
affine,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in,
|
||||
C_in,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=C_in,
|
||||
bias=False,
|
||||
),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
|
||||
nn.BatchNorm2d(
|
||||
C_out, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DualSepConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C_in,
|
||||
C_out,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
affine,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(DualSepConv, self).__init__()
|
||||
self.op_a = SepConv(
|
||||
C_in,
|
||||
C_in,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
self.op_b = SepConv(
|
||||
C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.op_a(x)
|
||||
x = self.op_b(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_a = ReLUConvBN(
|
||||
inplanes, planes, 3, stride, 1, 1, affine, track_running_stats
|
||||
)
|
||||
self.conv_b = ReLUConvBN(
|
||||
planes, planes, 3, 1, 1, 1, affine, track_running_stats
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.Conv2d(
|
||||
inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False
|
||||
),
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ReLUConvBN(
|
||||
inplanes, planes, 1, 1, 0, 1, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.in_dim = inplanes
|
||||
self.out_dim = planes
|
||||
self.stride = stride
|
||||
self.num_conv = 2
|
||||
|
||||
def extra_repr(self):
|
||||
string = "{name}(inC={in_dim}, outC={out_dim}, stride={stride})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
return string
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
return residual + basicblock
|
||||
|
||||
|
||||
class POOLING(nn.Module):
|
||||
def __init__(
|
||||
self, C_in, C_out, stride, mode, affine=True, track_running_stats=True
|
||||
):
|
||||
super(POOLING, self).__init__()
|
||||
if C_in == C_out:
|
||||
self.preprocess = None
|
||||
else:
|
||||
self.preprocess = ReLUConvBN(
|
||||
C_in, C_out, 1, 1, 0, 1, affine, track_running_stats
|
||||
)
|
||||
if mode == "avg":
|
||||
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
|
||||
elif mode == "max":
|
||||
self.op = nn.MaxPool2d(3, stride=stride, padding=1)
|
||||
else:
|
||||
raise ValueError("Invalid mode={:} in POOLING".format(mode))
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.preprocess:
|
||||
x = self.preprocess(inputs)
|
||||
else:
|
||||
x = inputs
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
def __init__(self, C_in, C_out, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.stride = stride
|
||||
self.is_zero = True
|
||||
|
||||
def forward(self, x):
|
||||
if self.C_in == self.C_out:
|
||||
if self.stride == 1:
|
||||
return x.mul(0.0)
|
||||
else:
|
||||
return x[:, :, :: self.stride, :: self.stride].mul(0.0)
|
||||
else:
|
||||
shape = list(x.shape)
|
||||
shape[1] = self.C_out
|
||||
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
|
||||
return zeros
|
||||
|
||||
def extra_repr(self):
|
||||
return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
self.stride = stride
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
if stride == 2:
|
||||
# assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
|
||||
C_outs = [C_out // 2, C_out - C_out // 2]
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(2):
|
||||
self.convs.append(
|
||||
nn.Conv2d(
|
||||
C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine
|
||||
)
|
||||
)
|
||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||
elif stride == 1:
|
||||
self.conv = nn.Conv2d(
|
||||
C_in, C_out, 1, stride=stride, padding=0, bias=not affine
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid stride : {:}".format(stride))
|
||||
self.bn = nn.BatchNorm2d(
|
||||
C_out, affine=affine, track_running_stats=track_running_stats
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 2:
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
|
||||
else:
|
||||
out = self.conv(x)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
def extra_repr(self):
|
||||
return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__)
|
||||
|
||||
|
||||
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
|
||||
class PartAwareOp(nn.Module):
|
||||
def __init__(self, C_in, C_out, stride, part=4):
|
||||
super().__init__()
|
||||
self.part = part
|
||||
self.hidden = C_in // 3
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.local_conv_list = nn.ModuleList()
|
||||
for i in range(self.part):
|
||||
self.local_conv_list.append(
|
||||
nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(C_in, self.hidden, 1),
|
||||
nn.BatchNorm2d(self.hidden, affine=True),
|
||||
)
|
||||
)
|
||||
self.W_K = nn.Linear(self.hidden, self.hidden)
|
||||
self.W_Q = nn.Linear(self.hidden, self.hidden)
|
||||
|
||||
if stride == 2:
|
||||
self.last = FactorizedReduce(C_in + self.hidden, C_out, 2, True, True)
|
||||
elif stride == 1:
|
||||
self.last = FactorizedReduce(C_in + self.hidden, C_out, 1, True, True)
|
||||
else:
|
||||
raise ValueError("Invalid Stride : {:}".format(stride))
|
||||
|
||||
def forward(self, x):
|
||||
batch, C, H, W = x.size()
|
||||
assert H >= self.part, "input size too small : {:} vs {:}".format(
|
||||
x.shape, self.part
|
||||
)
|
||||
IHs = [0]
|
||||
for i in range(self.part):
|
||||
IHs.append(min(H, int((i + 1) * (float(H) / self.part))))
|
||||
local_feat_list = []
|
||||
for i in range(self.part):
|
||||
feature = x[:, :, IHs[i] : IHs[i + 1], :]
|
||||
xfeax = self.avg_pool(feature)
|
||||
xfea = self.local_conv_list[i](xfeax)
|
||||
local_feat_list.append(xfea)
|
||||
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
|
||||
part_feature = part_feature.transpose(1, 2).contiguous()
|
||||
part_K = self.W_K(part_feature)
|
||||
part_Q = self.W_Q(part_feature).transpose(1, 2).contiguous()
|
||||
weight_att = torch.bmm(part_K, part_Q)
|
||||
attention = torch.softmax(weight_att, dim=2)
|
||||
aggreateF = torch.bmm(attention, part_feature).transpose(1, 2).contiguous()
|
||||
features = []
|
||||
for i in range(self.part):
|
||||
feature = aggreateF[:, :, i : i + 1].expand(
|
||||
batch, self.hidden, IHs[i + 1] - IHs[i]
|
||||
)
|
||||
feature = feature.view(batch, self.hidden, IHs[i + 1] - IHs[i], 1)
|
||||
features.append(feature)
|
||||
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
|
||||
final_fea = torch.cat((x, features), dim=1)
|
||||
outputs = self.last(final_fea)
|
||||
return outputs
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.0:
|
||||
keep_prob = 1.0 - drop_prob
|
||||
mask = x.new_zeros(x.size(0), 1, 1, 1)
|
||||
mask = mask.bernoulli_(keep_prob)
|
||||
x = torch.div(x, keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
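# Illustrative note: drop_path zeroes the whole branch for a sample with
# probability drop_prob and rescales the surviving samples by 1 / keep_prob so
# the expected activation is unchanged; e.g. with drop_prob = 0.2, roughly 20%
# of the samples in the batch have their branch zeroed and the rest are scaled
# by 1 / 0.8 = 1.25.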
|
||||
|
||||
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours
|
||||
class GDAS_Reduction_Cell(nn.Module):
|
||||
def __init__(
|
||||
self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats
|
||||
):
|
||||
super(GDAS_Reduction_Cell, self).__init__()
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(
|
||||
C_prev_prev, C, 2, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(
|
||||
C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats
|
||||
)
|
||||
self.preprocess1 = ReLUConvBN(
|
||||
C_prev, C, 1, 1, 0, 1, affine, track_running_stats
|
||||
)
|
||||
|
||||
self.reduction = True
|
||||
self.ops1 = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C,
|
||||
C,
|
||||
(1, 3),
|
||||
stride=(1, 2),
|
||||
padding=(0, 1),
|
||||
groups=8,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.Conv2d(
|
||||
C,
|
||||
C,
|
||||
(3, 1),
|
||||
stride=(2, 1),
|
||||
padding=(1, 0),
|
||||
groups=8,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C,
|
||||
C,
|
||||
(1, 3),
|
||||
stride=(1, 2),
|
||||
padding=(0, 1),
|
||||
groups=8,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.Conv2d(
|
||||
C,
|
||||
C,
|
||||
(3, 1),
|
||||
stride=(2, 1),
|
||||
padding=(1, 0),
|
||||
groups=8,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.ops2 = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def multiplier(self):
|
||||
return 4
|
||||
|
||||
def forward(self, s0, s1, drop_prob=-1):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
X0 = self.ops1[0](s0)
|
||||
X1 = self.ops1[1](s1)
|
||||
if self.training and drop_prob > 0.0:
|
||||
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
|
||||
|
||||
# X2 = self.ops2[0] (X0+X1)
|
||||
X2 = self.ops2[0](s0)
|
||||
X3 = self.ops2[1](s1)
|
||||
if self.training and drop_prob > 0.0:
|
||||
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
|
||||
return torch.cat([X0, X1, X2, X3], dim=1)
|
||||
|
||||
|
||||
# To manage the useful classes in this file.
|
||||
RAW_OP_CLASSES = {"gdas_reduction": GDAS_Reduction_Cell}
|
||||
@@ -1,33 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# The macro structure is defined in NAS-Bench-201
|
||||
from .search_model_darts import TinyNetworkDarts
|
||||
from .search_model_gdas import TinyNetworkGDAS
|
||||
from .search_model_setn import TinyNetworkSETN
|
||||
from .search_model_enas import TinyNetworkENAS
|
||||
from .search_model_random import TinyNetworkRANDOM
|
||||
from .generic_model import GenericNAS201Model
|
||||
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
|
||||
|
||||
# NASNet-based macro structure
|
||||
from .search_model_gdas_nasnet import NASNetworkGDAS
|
||||
from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC
|
||||
from .search_model_darts_nasnet import NASNetworkDARTS
|
||||
|
||||
|
||||
nas201_super_nets = {
|
||||
"DARTS-V1": TinyNetworkDarts,
|
||||
"DARTS-V2": TinyNetworkDarts,
|
||||
"GDAS": TinyNetworkGDAS,
|
||||
"SETN": TinyNetworkSETN,
|
||||
"ENAS": TinyNetworkENAS,
|
||||
"RANDOM": TinyNetworkRANDOM,
|
||||
"generic": GenericNAS201Model,
|
||||
}
|
||||
|
||||
nasnet_super_nets = {
|
||||
"GDAS": NASNetworkGDAS,
|
||||
"GDAS_FRC": NASNetworkGDAS_FRC,
|
||||
"DARTS": NASNetworkDARTS,
|
||||
}
|
||||
@@ -1,14 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from search_model_enas_utils import Controller


def main():
    controller = Controller(6, 4)
    predictions = controller()


if __name__ == "__main__":
    main()
@@ -1,362 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
|
||||
#####################################################
|
||||
import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import Text
|
||||
from torch.distributions.categorical import Categorical
|
||||
|
||||
from ..cell_operations import ResNetBasicblock, drop_path
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class Controller(nn.Module):
|
||||
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
|
||||
def __init__(
|
||||
self,
|
||||
edge2index,
|
||||
op_names,
|
||||
max_nodes,
|
||||
lstm_size=32,
|
||||
lstm_num_layers=2,
|
||||
tanh_constant=2.5,
|
||||
temperature=5.0,
|
||||
):
|
||||
super(Controller, self).__init__()
|
||||
# assign the attributes
|
||||
self.max_nodes = max_nodes
|
||||
self.num_edge = len(edge2index)
|
||||
self.edge2index = edge2index
|
||||
self.num_ops = len(op_names)
|
||||
self.op_names = op_names
|
||||
self.lstm_size = lstm_size
|
||||
self.lstm_N = lstm_num_layers
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
# create parameters
|
||||
self.register_parameter(
|
||||
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
|
||||
)
|
||||
self.w_lstm = nn.LSTM(
|
||||
input_size=self.lstm_size,
|
||||
hidden_size=self.lstm_size,
|
||||
num_layers=self.lstm_N,
|
||||
)
|
||||
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
|
||||
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
|
||||
|
||||
nn.init.uniform_(self.input_vars, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
|
||||
|
||||
def convert_structure(self, _arch):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_index = _arch[self.edge2index[node_str]]
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
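    # A minimal usage sketch (the edge keys and op names are assumptions that
    # mirror the NAS-Bench-201 space defined elsewhere in this repository):
    #
    #   edge2index = {"1<-0": 0, "2<-0": 1, "2<-1": 2, "3<-0": 3, "3<-1": 4, "3<-2": 5}
    #   op_names = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
    #   controller = Controller(edge2index, op_names, max_nodes=4)
    #   log_prob, entropy, arch = controller()  # samples one architecture per call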
|
||||
|
||||
def forward(self):
|
||||
|
||||
inputs, h0 = self.input_vars, None
|
||||
log_probs, entropys, sampled_arch = [], [], []
|
||||
for iedge in range(self.num_edge):
|
||||
outputs, h0 = self.w_lstm(inputs, h0)
|
||||
|
||||
logits = self.w_pred(outputs)
|
||||
logits = logits / self.temperature
|
||||
logits = self.tanh_constant * torch.tanh(logits)
|
||||
# distribution
|
||||
op_distribution = Categorical(logits=logits)
|
||||
op_index = op_distribution.sample()
|
||||
sampled_arch.append(op_index.item())
|
||||
|
||||
op_log_prob = op_distribution.log_prob(op_index)
|
||||
log_probs.append(op_log_prob.view(-1))
|
||||
op_entropy = op_distribution.entropy()
|
||||
entropys.append(op_entropy.view(-1))
|
||||
|
||||
# obtain the input embedding for the next step
|
||||
inputs = self.w_embd(op_index)
|
||||
return (
|
||||
torch.sum(torch.cat(log_probs)),
|
||||
torch.sum(torch.cat(entropys)),
|
||||
self.convert_structure(sampled_arch),
|
||||
)
|
||||
|
||||
|
||||
class GenericNAS201Model(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(GenericNAS201Model, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._max_nodes = max_nodes
|
||||
self._stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
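# Macro skeleton: N search cells at width C, a ResNet reduction block (stride 2, width 2C), N cells, another reduction (4C), and N more cells.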
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self._cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self._cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self._op_names = deepcopy(search_space)
|
||||
self._Layer = len(self._cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(
|
||||
nn.BatchNorm2d(
|
||||
C_prev, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self._num_edge = num_edge
|
||||
# algorithm related
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self._mode = None
|
||||
self.dynamic_cell = None
|
||||
self._tau = None
|
||||
self._algo = None
|
||||
self._drop_path = None
|
||||
self.verbose = False
|
||||
|
||||
def set_algo(self, algo: Text):
|
||||
# used for searching
|
||||
assert self._algo is None, "This functioin can only be called once."
|
||||
self._algo = algo
|
||||
if algo == "enas":
|
||||
self.controller = Controller(
|
||||
self.edge2index, self._op_names, self._max_nodes
|
||||
)
|
||||
else:
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(self._num_edge, len(self._op_names))
|
||||
)
|
||||
if algo == "gdas":
|
||||
self._tau = 10
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
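# Calculation modes used by forward(): 'urs' (uniform random op per edge), 'select' (argmax), 'joint' (softmax-weighted sum), 'dynamic' (a fixed sampled cell), 'gdas' (Gumbel one-hot); 'enas' architectures are sampled by the Controller.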
assert mode in ["gdas", "enas", "urs", "joint", "select", "dynamic"]
|
||||
self._mode = mode
|
||||
if mode == "dynamic":
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_drop_path(self, progress, drop_path_rate):
|
||||
if drop_path_rate is None:
|
||||
self._drop_path = None
|
||||
elif progress is None:
|
||||
self._drop_path = drop_path_rate
|
||||
else:
|
||||
self._drop_path = progress * drop_path_rate
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self._mode
|
||||
|
||||
@property
|
||||
def drop_path(self):
|
||||
return self._drop_path
|
||||
|
||||
@property
|
||||
def weights(self):
|
||||
xlist = list(self._stem.parameters())
|
||||
xlist += list(self._cells.parameters())
|
||||
xlist += list(self.lastact.parameters())
|
||||
xlist += list(self.global_pooling.parameters())
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self._tau = tau
|
||||
|
||||
@property
|
||||
def tau(self):
|
||||
return self._tau
|
||||
|
||||
@property
|
||||
def alphas(self):
|
||||
if self._algo == "enas":
|
||||
return list(self.controller.parameters())
|
||||
else:
|
||||
return [self.arch_parameters]
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self._cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self._cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
if self._algo == "enas":
|
||||
return "w_pred :\n{:}".format(self.controller.w_pred.weight)
|
||||
else:
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
@property
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self._max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self._op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
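# Sample a discrete cell: each edge picks an op uniformly at random (use_random=True) or from the softmax over its arch parameters.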
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self._max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self._op_names)
|
||||
else:
|
||||
weights = alphas_cpu[self.edge2index[node_str]]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self._op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def get_log_prob(self, arch):
|
||||
with torch.no_grad():
|
||||
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
|
||||
select_logits = []
|
||||
for i, node_info in enumerate(arch.nodes):
|
||||
for op, xin in node_info:
|
||||
node_str = "{:}<-{:}".format(i + 1, xin)
|
||||
op_index = self._op_names.index(op)
|
||||
select_logits.append(logits[self.edge2index[node_str], op_index])
|
||||
return sum(select_logits).item()
|
||||
|
||||
def return_topK(self, K, use_random=False):
|
||||
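# Enumerate all architectures in the search space and return the K with the highest log-probability under the current arch parameters (or K random ones).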
archs = Structure.gen_all(self._op_names, self._max_nodes, False)
|
||||
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
|
||||
if K < 0 or K >= len(archs):
|
||||
K = len(archs)
|
||||
if use_random:
|
||||
return random.sample(archs, K)
|
||||
else:
|
||||
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
|
||||
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
|
||||
return return_pairs
|
||||
|
||||
def normalize_archp(self):
|
||||
if self.mode == "gdas":
|
||||
while True:
|
||||
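# Sample Gumbel(0, 1) noise as -log(Exponential(1)) and add it to the log-softmax of the arch parameters (Gumbel-Softmax).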
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
|
||||
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
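# Straight-through estimator: the forward pass uses the one-hot weights, while gradients flow through the soft probabilities.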
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
with torch.no_grad():
|
||||
hardwts_cpu = hardwts.detach().cpu()
|
||||
return hardwts, hardwts_cpu, index, "GUMBEL"
|
||||
else:
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
index = alphas.max(-1, keepdim=True)[1]
|
||||
with torch.no_grad():
|
||||
alphas_cpu = alphas.detach().cpu()
|
||||
return alphas, alphas_cpu, index, "SOFTMAX"
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas, alphas_cpu, index, verbose_str = self.normalize_archp()
|
||||
feature = self._stem(inputs)
|
||||
for i, cell in enumerate(self._cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
if self.mode == "urs":
|
||||
feature = cell.forward_urs(feature)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_urs"
|
||||
elif self.mode == "select":
|
||||
feature = cell.forward_select(feature, alphas_cpu)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_select"
|
||||
elif self.mode == "joint":
|
||||
feature = cell.forward_joint(feature, alphas)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_joint"
|
||||
elif self.mode == "dynamic":
|
||||
feature = cell.forward_dynamic(feature, self.dynamic_cell)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_dynamic"
|
||||
elif self.mode == "gdas":
|
||||
feature = cell.forward_gdas(feature, alphas, index)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_gdas"
|
||||
else:
|
||||
raise ValueError("invalid mode={:}".format(self.mode))
|
||||
else:
|
||||
feature = cell(feature)
|
||||
if self.drop_path is not None:
|
||||
feature = drop_path(feature, self.drop_path)
|
||||
if self.verbose and random.random() < 0.001:
|
||||
print(verbose_str)
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
return out, logits
|
||||
@@ -1,274 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from copy import deepcopy


def get_combination(space, num):
|
||||
combs = []
|
||||
for i in range(num):
|
||||
if i == 0:
|
||||
for func in space:
|
||||
combs.append([(func, i)])
|
||||
else:
|
||||
new_combs = []
|
||||
for string in combs:
|
||||
for func in space:
|
||||
xstring = string + [(func, i)]
|
||||
new_combs.append(xstring)
|
||||
combs = new_combs
|
||||
return combs
|
||||
|
||||
|
||||
class Structure:
|
||||
def __init__(self, genotype):
|
||||
assert isinstance(genotype, list) or isinstance(
|
||||
genotype, tuple
|
||||
), "invalid class of genotype : {:}".format(type(genotype))
|
||||
self.node_num = len(genotype) + 1
|
||||
self.nodes = []
|
||||
self.node_N = []
|
||||
for idx, node_info in enumerate(genotype):
|
||||
assert isinstance(node_info, list) or isinstance(
|
||||
node_info, tuple
|
||||
), "invalid class of node_info : {:}".format(type(node_info))
|
||||
assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info))
|
||||
for node_in in node_info:
|
||||
assert isinstance(node_in, list) or isinstance(
|
||||
node_in, tuple
|
||||
), "invalid class of in-node : {:}".format(type(node_in))
|
||||
assert (
|
||||
len(node_in) == 2 and node_in[1] <= idx
|
||||
), "invalid in-node : {:}".format(node_in)
|
||||
self.node_N.append(len(node_info))
|
||||
self.nodes.append(tuple(deepcopy(node_info)))
|
||||
|
||||
def tolist(self, remove_str):
|
||||
# convert this class into a list; if remove_str is 'none', drop the 'none' operation.
|
||||
# note that the input nodes are re-ordered in this function
|
||||
# return the genotype list and a success flag (False means the cell is not a valid connectivity)
|
||||
genotypes = []
|
||||
for node_info in self.nodes:
|
||||
node_info = list(node_info)
|
||||
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
|
||||
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
|
||||
if len(node_info) == 0:
|
||||
return None, False
|
||||
genotypes.append(node_info)
|
||||
return genotypes, True
|
||||
|
||||
def node(self, index):
|
||||
assert index > 0 and index <= len(self), "invalid index={:}, expected in [1, {:}]".format(
|
||||
index, len(self)
|
||||
)
|
||||
return self.nodes[index]
|
||||
|
||||
def tostr(self):
|
||||
strings = []
|
||||
for node_info in self.nodes:
|
||||
string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info])
|
||||
string = "|{:}|".format(string)
|
||||
strings.append(string)
|
||||
return "+".join(strings)
|
||||
|
||||
def check_valid(self):
|
||||
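# A cell is valid only if the output node is reachable from the input through edges whose op is not 'none'.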
nodes = {0: True}
|
||||
for i, node_info in enumerate(self.nodes):
|
||||
sums = []
|
||||
for op, xin in node_info:
|
||||
if op == "none" or nodes[xin] is False:
|
||||
x = False
|
||||
else:
|
||||
x = True
|
||||
sums.append(x)
|
||||
nodes[i + 1] = sum(sums) > 0
|
||||
return nodes[len(self.nodes)]
|
||||
|
||||
def to_unique_str(self, consider_zero=False):
|
||||
# this is used to identify isomorphic cells, which requires prior knowledge of the operations
|
||||
# two operations are special, i.e., none and skip_connect
|
||||
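# Build a canonical string bottom-up; sibling inputs are sorted so that isomorphic cells map to the same string, with 'none' and 'skip_connect' handled specially according to consider_zero.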
nodes = {0: "0"}
|
||||
for i_node, node_info in enumerate(self.nodes):
|
||||
cur_node = []
|
||||
for op, xin in node_info:
|
||||
if consider_zero is None:
|
||||
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
|
||||
elif consider_zero:
|
||||
if op == "none" or nodes[xin] == "#":
|
||||
x = "#" # zero
|
||||
elif op == "skip_connect":
|
||||
x = nodes[xin]
|
||||
else:
|
||||
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
|
||||
else:
|
||||
if op == "skip_connect":
|
||||
x = nodes[xin]
|
||||
else:
|
||||
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
|
||||
cur_node.append(x)
|
||||
nodes[i_node + 1] = "+".join(sorted(cur_node))
|
||||
return nodes[len(self.nodes)]
|
||||
|
||||
def check_valid_op(self, op_names):
|
||||
for node_info in self.nodes:
|
||||
for inode_edge in node_info:
|
||||
# assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
|
||||
if inode_edge[0] not in op_names:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({node_num} nodes with {node_info})".format(
|
||||
name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.nodes) + 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.nodes[index]
|
||||
|
||||
@staticmethod
|
||||
def str2structure(xstr):
|
||||
if isinstance(xstr, Structure):
|
||||
return xstr
|
||||
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
|
||||
type(xstr)
|
||||
)
|
||||
nodestrs = xstr.split("+")
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != "", node_str.split("|")))
|
||||
for xinput in inputs:
|
||||
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
|
||||
xinput
|
||||
)
|
||||
inputs = (xi.split("~") for xi in inputs)
|
||||
input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
|
||||
genotypes.append(input_infos)
|
||||
return Structure(genotypes)
|
||||
|
||||
@staticmethod
|
||||
def str2fullstructure(xstr, default_name="none"):
|
||||
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
|
||||
type(xstr)
|
||||
)
|
||||
nodestrs = xstr.split("+")
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != "", node_str.split("|")))
|
||||
for xinput in inputs:
|
||||
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
|
||||
xinput
|
||||
)
|
||||
inputs = (xi.split("~") for xi in inputs)
|
||||
input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
|
||||
all_in_nodes = list(x[1] for x in input_infos)
|
||||
for j in range(i):
|
||||
if j not in all_in_nodes:
|
||||
input_infos.append((default_name, j))
|
||||
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
|
||||
genotypes.append(tuple(node_info))
|
||||
return Structure(genotypes)
|
||||
|
||||
@staticmethod
|
||||
def gen_all(search_space, num, return_ori):
|
||||
assert isinstance(search_space, list) or isinstance(
|
||||
search_space, tuple
|
||||
), "invalid class of search-space : {:}".format(type(search_space))
|
||||
assert (
|
||||
num >= 2
|
||||
), "There should be at least two nodes in a neural cell instead of {:}".format(
|
||||
num
|
||||
)
|
||||
all_archs = get_combination(search_space, 1)
|
||||
for i, arch in enumerate(all_archs):
|
||||
all_archs[i] = [tuple(arch)]
|
||||
|
||||
for inode in range(2, num):
|
||||
cur_nodes = get_combination(search_space, inode)
|
||||
new_all_archs = []
|
||||
for previous_arch in all_archs:
|
||||
for cur_node in cur_nodes:
|
||||
new_all_archs.append(previous_arch + [tuple(cur_node)])
|
||||
all_archs = new_all_archs
|
||||
if return_ori:
|
||||
return all_archs
|
||||
else:
|
||||
return [Structure(x) for x in all_archs]
|
||||
|
||||
|
||||
ResNet_CODE = Structure(
|
||||
[
|
||||
(("nor_conv_3x3", 0),), # node-1
|
||||
(("nor_conv_3x3", 1),), # node-2
|
||||
(("skip_connect", 0), ("skip_connect", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllConv3x3_CODE = Structure(
|
||||
[
|
||||
(("nor_conv_3x3", 0),), # node-1
|
||||
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1)), # node-2
|
||||
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllFull_CODE = Structure(
|
||||
[
|
||||
(
|
||||
("skip_connect", 0),
|
||||
("nor_conv_1x1", 0),
|
||||
("nor_conv_3x3", 0),
|
||||
("avg_pool_3x3", 0),
|
||||
), # node-1
|
||||
(
|
||||
("skip_connect", 0),
|
||||
("nor_conv_1x1", 0),
|
||||
("nor_conv_3x3", 0),
|
||||
("avg_pool_3x3", 0),
|
||||
("skip_connect", 1),
|
||||
("nor_conv_1x1", 1),
|
||||
("nor_conv_3x3", 1),
|
||||
("avg_pool_3x3", 1),
|
||||
), # node-2
|
||||
(
|
||||
("skip_connect", 0),
|
||||
("nor_conv_1x1", 0),
|
||||
("nor_conv_3x3", 0),
|
||||
("avg_pool_3x3", 0),
|
||||
("skip_connect", 1),
|
||||
("nor_conv_1x1", 1),
|
||||
("nor_conv_3x3", 1),
|
||||
("avg_pool_3x3", 1),
|
||||
("skip_connect", 2),
|
||||
("nor_conv_1x1", 2),
|
||||
("nor_conv_3x3", 2),
|
||||
("avg_pool_3x3", 2),
|
||||
),  # node-3
]
|
||||
)
|
||||
|
||||
AllConv1x1_CODE = Structure(
|
||||
[
|
||||
(("nor_conv_1x1", 0),), # node-1
|
||||
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1)), # node-2
|
||||
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1), ("nor_conv_1x1", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllIdentity_CODE = Structure(
|
||||
[
|
||||
(("skip_connect", 0),), # node-1
|
||||
(("skip_connect", 0), ("skip_connect", 1)), # node-2
|
||||
(("skip_connect", 0), ("skip_connect", 1), ("skip_connect", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
architectures = {
|
||||
"resnet": ResNet_CODE,
|
||||
"all_c3x3": AllConv3x3_CODE,
|
||||
"all_c1x1": AllConv1x1_CODE,
|
||||
"all_idnt": AllIdentity_CODE,
|
||||
"all_full": AllFull_CODE,
|
||||
}
|
||||
@@ -1,251 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, random, torch
|
||||
import warnings
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
|
||||
# This module is used for NAS-Bench-201; it represents a small search space with a complete DAG.
|
||||
class NAS201SearchCell(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C_in,
|
||||
C_out,
|
||||
stride,
|
||||
max_nodes,
|
||||
op_names,
|
||||
affine=False,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(NAS201SearchCell, self).__init__()
|
||||
|
||||
self.op_names = deepcopy(op_names)
|
||||
self.edges = nn.ModuleDict()
|
||||
self.max_nodes = max_nodes
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if j == 0:
|
||||
xlists = [
|
||||
OPS[op_name](C_in, C_out, stride, affine, track_running_stats)
|
||||
for op_name in op_names
|
||||
]
|
||||
else:
|
||||
xlists = [
|
||||
OPS[op_name](C_in, C_out, 1, affine, track_running_stats)
|
||||
for op_name in op_names
|
||||
]
|
||||
self.edges[node_str] = nn.ModuleList(xlists)
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edges)
|
||||
|
||||
def extra_repr(self):
|
||||
string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format(
|
||||
**self.__dict__
|
||||
)
|
||||
return string
|
||||
|
||||
def forward(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
inter_nodes.append(
|
||||
sum(
|
||||
layer(nodes[j]) * w
|
||||
for layer, w in zip(self.edges[node_str], weights)
|
||||
)
|
||||
)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# GDAS
|
||||
def forward_gdas(self, inputs, hardwts, index):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = hardwts[self.edge2index[node_str]]
|
||||
argmaxs = index[self.edge2index[node_str]].item()
|
||||
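# GDAS: only the argmax op is evaluated; the remaining hard weights (zero-valued but differentiable) are added so gradients still reach every architecture weight.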
weigsum = sum(
|
||||
weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie]
|
||||
for _ie, edge in enumerate(self.edges[node_str])
|
||||
)
|
||||
inter_nodes.append(weigsum)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# joint
|
||||
def forward_joint(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
# aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
|
||||
aggregation = sum(
|
||||
layer(nodes[j]) * w
|
||||
for layer, w in zip(self.edges[node_str], weights)
|
||||
)
|
||||
inter_nodes.append(aggregation)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# uniform random sampling per iteration, SETN
|
||||
def forward_urs(self, inputs):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
while True:  # re-sample to avoid selecting the zero op for every incoming edge
|
||||
sops, has_non_zero = [], False
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
candidates = self.edges[node_str]
|
||||
select_op = random.choice(candidates)
|
||||
sops.append(select_op)
|
||||
if not hasattr(select_op, "is_zero") or select_op.is_zero is False:
|
||||
has_non_zero = True
|
||||
if has_non_zero:
|
||||
break
|
||||
inter_nodes = []
|
||||
for j, select_op in enumerate(sops):
|
||||
inter_nodes.append(select_op(nodes[j]))
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# select the argmax
|
||||
def forward_select(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
inter_nodes.append(
|
||||
self.edges[node_str][weights.argmax().item()](nodes[j])
|
||||
)
|
||||
# inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# forward with a specific structure
|
||||
def forward_dynamic(self, inputs, structure):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
cur_op_node = structure.nodes[i - 1]
|
||||
inter_nodes = []
|
||||
for op_name, j in cur_op_node:
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_index = self.op_names.index(op_name)
|
||||
inter_nodes.append(self.edges[node_str][op_index](nodes[j]))
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
def __init__(self, space, C, stride, affine, track_running_stats):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in space:
|
||||
op = OPS[primitive](C, C, stride, affine, track_running_stats)
|
||||
self._ops.append(op)
|
||||
|
||||
def forward_gdas(self, x, weights, index):
|
||||
return self._ops[index](x) * weights[index]
|
||||
|
||||
def forward_darts(self, x, weights):
|
||||
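# DARTS continuous relaxation: the output is the architecture-weighted sum over all candidate operations.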
return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
class NASNetSearchCell(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetSearchCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.op_names = deepcopy(space)
|
||||
if reduction_prev:
|
||||
self.preprocess0 = OPS["skip_connect"](
|
||||
C_prev_prev, C, 2, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
self.preprocess0 = OPS["nor_conv_1x1"](
|
||||
C_prev_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
self.preprocess1 = OPS["nor_conv_1x1"](
|
||||
C_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
self.edges = nn.ModuleDict()
|
||||
for i in range(self._steps):
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(
|
||||
i, j
|
||||
) # indicate the edge from node-(j) to node-(i+2)
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(space, C, stride, affine, track_running_stats)
|
||||
self.edges[node_str] = op
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edges)
|
||||
|
||||
@property
|
||||
def multiplier(self):
|
||||
return self._multiplier
|
||||
|
||||
def forward_gdas(self, s0, s1, weightss, indexs):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op = self.edges[node_str]
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
index = indexs[self.edge2index[node_str]].item()
|
||||
clist.append(op.forward_gdas(h, weights, index))
|
||||
states.append(sum(clist))
|
||||
|
||||
return torch.cat(states[-self._multiplier :], dim=1)
|
||||
|
||||
def forward_darts(self, s0, s1, weightss):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op = self.edges[node_str]
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
clist.append(op.forward_darts(h, weights))
|
||||
states.append(sum(clist))
|
||||
|
||||
return torch.cat(states[-self._multiplier :], dim=1)
|
||||
@@ -1,122 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
########################################################
|
||||
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
||||
########################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkDarts(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkDarts, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell(feature, alphas)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -1,178 +0,0 @@
|
||||
####################
|
||||
# DARTS, ICLR 2019 #
|
||||
####################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import List, Text, Dict
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkDARTS(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C: int,
|
||||
N: int,
|
||||
steps: int,
|
||||
multiplier: int,
|
||||
stem_multiplier: int,
|
||||
num_classes: int,
|
||||
search_space: List[Text],
|
||||
affine: bool,
|
||||
track_running_stats: bool,
|
||||
):
|
||||
super(NASNetworkDARTS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.arch_reduce_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
|
||||
def get_weights(self) -> List[torch.nn.Parameter]:
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self) -> List[torch.nn.Parameter]:
|
||||
return [self.arch_normal_parameters, self.arch_reduce_parameters]
|
||||
|
||||
def show_alphas(self) -> Text:
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
|
||||
)
|
||||
B = "arch-reduce-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}\n{:}".format(A, B)
|
||||
|
||||
def get_message(self) -> Text:
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self) -> Text:
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self) -> Dict[Text, List]:
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
# (TODO) xuanyidong:
|
||||
# Here the selected two edges might come from the same input node.
|
||||
# This case can be a problem because the two edges would collapse into a single one
|
||||
# due to our assumption -- at most one edge from an input node during evaluation.
|
||||
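# DARTS discretization: keep the two incoming edges with the largest weights for each node.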
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
gene_reduce = _parse(
|
||||
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
"reduce": gene_reduce,
|
||||
"reduce_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
|
||||
reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
ww = reduce_w
|
||||
else:
|
||||
ww = normal_w
|
||||
s0, s1 = s1, cell.forward_darts(s0, s1, ww)
|
||||
out = self.lastact(s1)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -1,114 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##########################################################################
|
||||
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
|
||||
##########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
from .search_model_enas_utils import Controller
|
||||
|
||||
|
||||
class TinyNetworkENAS(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkENAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
# to maintain the sampled architecture
|
||||
self.sampled_arch = None
|
||||
|
||||
def update_arch(self, _arch):
|
||||
if _arch is None:
|
||||
self.sampled_arch = None
|
||||
elif isinstance(_arch, Structure):
|
||||
self.sampled_arch = _arch
|
||||
elif isinstance(_arch, (list, tuple)):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_index = _arch[self.edge2index[node_str]]
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
self.sampled_arch = Structure(genotypes)
|
||||
else:
|
||||
raise ValueError("invalid type of input architecture : {:}".format(_arch))
|
||||
return self.sampled_arch
|
||||
|
||||
def create_controller(self):
|
||||
return Controller(len(self.edge2index), len(self.op_names))
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell.forward_dynamic(feature, self.sampled_arch)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -1,74 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##########################################################################
|
||||
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
|
||||
##########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions.categorical import Categorical
|
||||
|
||||
|
||||
class Controller(nn.Module):
|
||||
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
|
||||
def __init__(
|
||||
self,
|
||||
num_edge,
|
||||
num_ops,
|
||||
lstm_size=32,
|
||||
lstm_num_layers=2,
|
||||
tanh_constant=2.5,
|
||||
temperature=5.0,
|
||||
):
|
||||
super(Controller, self).__init__()
|
||||
# assign the attributes
|
||||
self.num_edge = num_edge
|
||||
self.num_ops = num_ops
|
||||
self.lstm_size = lstm_size
|
||||
self.lstm_N = lstm_num_layers
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
# create parameters
|
||||
self.register_parameter(
|
||||
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
|
||||
)
|
||||
self.w_lstm = nn.LSTM(
|
||||
input_size=self.lstm_size,
|
||||
hidden_size=self.lstm_size,
|
||||
num_layers=self.lstm_N,
|
||||
)
|
||||
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
|
||||
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
|
||||
|
||||
nn.init.uniform_(self.input_vars, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
|
||||
|
||||
def forward(self):
|
||||
|
||||
inputs, h0 = self.input_vars, None
|
||||
log_probs, entropys, sampled_arch = [], [], []
|
||||
for iedge in range(self.num_edge):
|
||||
outputs, h0 = self.w_lstm(inputs, h0)
|
||||
|
||||
logits = self.w_pred(outputs)
|
||||
logits = logits / self.temperature
|
||||
logits = self.tanh_constant * torch.tanh(logits)
|
||||
# distribution
|
||||
op_distribution = Categorical(logits=logits)
|
||||
op_index = op_distribution.sample()
|
||||
sampled_arch.append(op_index.item())
|
||||
|
||||
op_log_prob = op_distribution.log_prob(op_index)
|
||||
log_probs.append(op_log_prob.view(-1))
|
||||
op_entropy = op_distribution.entropy()
|
||||
entropys.append(op_entropy.view(-1))
|
||||
|
||||
# obtain the input embedding for the next step
|
||||
inputs = self.w_embd(op_index)
|
||||
return (
|
||||
torch.sum(torch.cat(log_probs)),
|
||||
torch.sum(torch.cat(entropys)),
|
||||
sampled_arch,
|
||||
)
|
||||
@@ -1,142 +0,0 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkGDAS(nn.Module):
|
||||
|
||||
# def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkGDAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.tau = 10
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def forward(self, inputs):
|
||||
while True:
|
||||
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
|
||||
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell.forward_gdas(feature, hardwts, index)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -1,199 +0,0 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from models.cell_searchs.search_cells import NASNetSearchCell as SearchCell
|
||||
from models.cell_operations import RAW_OP_CLASSES
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkGDAS_FRC(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C,
|
||||
N,
|
||||
steps,
|
||||
multiplier,
|
||||
stem_multiplier,
|
||||
num_classes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetworkGDAS_FRC, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = RAW_OP_CLASSES["gdas_reduction"](
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
reduction
|
||||
or num_edge == cell.num_edges
|
||||
and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = (
|
||||
C_prev,
|
||||
cell.multiplier * C_curr,
|
||||
reduction,
|
||||
)
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.tau = 10
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}".format(A)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
def forward(self, inputs):
|
||||
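# Gumbel-Softmax sampling with a straight-through one-hot, re-sampled whenever the noise or probabilities become inf/nan.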
def get_gumbel_prob(xins):
|
||||
while True:
|
||||
gumbels = -torch.empty_like(xins).exponential_().log()
|
||||
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
return hardwts, index
|
||||
|
||||
hardwts, index = get_gumbel_prob(self.arch_parameters)
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
s0, s1 = s1, cell(s0, s1)
|
||||
else:
|
||||
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
|
||||
out = self.lastact(s1)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -1,197 +0,0 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import torch
import torch.nn as nn
from copy import deepcopy
from .search_cells import NASNetSearchCell as SearchCell


# The macro structure is based on NASNet
class NASNetworkGDAS(nn.Module):
    def __init__(
        self,
        C,
        N,
        steps,
        multiplier,
        stem_multiplier,
        num_classes,
        search_space,
        affine,
        track_running_stats,
    ):
        super(NASNetworkGDAS, self).__init__()
        self._C = C
        self._layerN = N
        self._steps = steps
        self._multiplier = multiplier
        self.stem = nn.Sequential(
            nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C * stem_multiplier),
        )

        # config for each layer
        layer_channels = (
            [C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
        )
        layer_reductions = (
            [False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
        )

        num_edge, edge2index = None, None
        C_prev_prev, C_prev, C_curr, reduction_prev = (
            C * stem_multiplier,
            C * stem_multiplier,
            C,
            False,
        )

        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            cell = SearchCell(
                search_space,
                steps,
                multiplier,
                C_prev_prev,
                C_prev,
                C_curr,
                reduction,
                reduction_prev,
                affine,
                track_running_stats,
            )
            if num_edge is None:
                num_edge, edge2index = cell.num_edges, cell.edge2index
            else:
                assert (
                    num_edge == cell.num_edges and edge2index == cell.edge2index
                ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        self.arch_normal_parameters = nn.Parameter(
            1e-3 * torch.randn(num_edge, len(search_space))
        )
        self.arch_reduce_parameters = nn.Parameter(
            1e-3 * torch.randn(num_edge, len(search_space))
        )
        self.tau = 10

    def get_weights(self):
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(
            self.global_pooling.parameters()
        )
        xlist += list(self.classifier.parameters())
        return xlist

    def set_tau(self, tau):
        self.tau = tau

    def get_tau(self):
        return self.tau

    def get_alphas(self):
        return [self.arch_normal_parameters, self.arch_reduce_parameters]

    def show_alphas(self):
        with torch.no_grad():
            A = "arch-normal-parameters :\n{:}".format(
                nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
            )
            B = "arch-reduce-parameters :\n{:}".format(
                nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
            )
        return "{:}\n{:}".format(A, B)

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
            )
        return string

    def extra_repr(self):
        return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def genotype(self):
        def _parse(weights):
            gene = []
            for i in range(self._steps):
                edges = []
                for j in range(2 + i):
                    node_str = "{:}<-{:}".format(i, j)
                    ws = weights[self.edge2index[node_str]]
                    for k, op_name in enumerate(self.op_names):
                        if op_name == "none":
                            continue
                        edges.append((op_name, j, ws[k]))
                edges = sorted(edges, key=lambda x: -x[-1])
                selected_edges = edges[:2]
                gene.append(tuple(selected_edges))
            return gene

        with torch.no_grad():
            gene_normal = _parse(
                torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
            )
            gene_reduce = _parse(
                torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
            )
        return {
            "normal": gene_normal,
            "normal_concat": list(
                range(2 + self._steps - self._multiplier, self._steps + 2)
            ),
            "reduce": gene_reduce,
            "reduce_concat": list(
                range(2 + self._steps - self._multiplier, self._steps + 2)
            ),
        }

    def forward(self, inputs):
        def get_gumbel_prob(xins):
            while True:
                gumbels = -torch.empty_like(xins).exponential_().log()
                logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
                probs = nn.functional.softmax(logits, dim=1)
                index = probs.max(-1, keepdim=True)[1]
                one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
                hardwts = one_h - probs.detach() + probs
                if (
                    (torch.isinf(gumbels).any())
                    or (torch.isinf(probs).any())
                    or (torch.isnan(probs).any())
                ):
                    continue
                else:
                    break
            return hardwts, index

        normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
        reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)

        s0 = s1 = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                hardwts, index = reduce_hardwts, reduce_index
            else:
                hardwts, index = normal_hardwts, normal_index
            s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
        out = self.lastact(s1)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
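During search, GDAS anneals the Gumbel temperature tau from a large to a small value so sampling gradually sharpens toward a discrete architecture. A hedged sketch of the weight-update part of such a loop (the schedule constants, `network`, and `search_loader` are illustrative placeholders; the alternating architecture-parameter step is omitted):

import torch
import torch.nn as nn

def train_shared_weights(network, search_loader, total_epochs=250,
                         tau_max=10.0, tau_min=0.1, lr=0.025):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(network.get_weights(), lr=lr, momentum=0.9)
    for epoch in range(total_epochs):
        # Linearly decay the Gumbel-Softmax temperature over the search.
        network.set_tau(tau_max - (tau_max - tau_min) * epoch / (total_epochs - 1))
        for inputs, targets in search_loader:
            optimizer.zero_grad()
            _, logits = network(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()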
@@ -1,102 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure


class TinyNetworkRANDOM(nn.Module):
    def __init__(
        self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
    ):
        super(TinyNetworkRANDOM, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
        )

        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(
                    C_prev,
                    C_curr,
                    1,
                    max_nodes,
                    search_space,
                    affine,
                    track_running_stats,
                )
                if num_edge is None:
                    num_edge, edge2index = cell.num_edges, cell.edge2index
                else:
                    assert (
                        num_edge == cell.num_edges and edge2index == cell.edge2index
                    ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        self.arch_cache = None

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
            )
        return string

    def extra_repr(self):
        return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def random_genotype(self, set_cache):
        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = "{:}<-{:}".format(i, j)
                op_name = random.choice(self.op_names)
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        arch = Structure(genotypes)
        if set_cache:
            self.arch_cache = arch
        return arch

    def forward(self, inputs):

        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if isinstance(cell, SearchCell):
                feature = cell.forward_dynamic(feature, self.arch_cache)
            else:
                feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return out, logits
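Random search with this model samples an architecture, caches it via random_genotype(set_cache=True), and then scores the shared weights under that cache. A hedged sketch of the candidate-ranking step (`network` and `valid_loader` are placeholder names):

import torch

def rank_random_candidates(network, valid_loader, num_candidates=100):
    best_acc, best_arch = -1.0, None
    network.eval()
    with torch.no_grad():
        for _ in range(num_candidates):
            arch = network.random_genotype(set_cache=True)  # cached arch drives forward()
            correct, total = 0, 0
            for inputs, targets in valid_loader:
                _, logits = network(inputs)
                correct += (logits.argmax(dim=1) == targets).sum().item()
                total += targets.numel()
            if correct / total > best_acc:
                best_acc, best_arch = correct / total, arch
    return best_arch, best_acc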
@@ -1,178 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
######################################################################################
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||
######################################################################################
|
||||
import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkSETN(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.mode = "urs"
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
assert mode in ["urs", "joint", "select", "dynamic"]
|
||||
self.mode = mode
|
||||
if mode == "dynamic":
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def get_cal_mode(self):
|
||||
return self.mode
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self.op_names)
|
||||
else:
|
||||
weights = alphas_cpu[self.edge2index[node_str]]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def get_log_prob(self, arch):
|
||||
with torch.no_grad():
|
||||
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
|
||||
select_logits = []
|
||||
for i, node_info in enumerate(arch.nodes):
|
||||
for op, xin in node_info:
|
||||
node_str = "{:}<-{:}".format(i + 1, xin)
|
||||
op_index = self.op_names.index(op)
|
||||
select_logits.append(logits[self.edge2index[node_str], op_index])
|
||||
return sum(select_logits).item()
|
||||
|
||||
def return_topK(self, K):
|
||||
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
|
||||
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
|
||||
if K < 0 or K >= len(archs):
|
||||
K = len(archs)
|
||||
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
|
||||
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
|
||||
return return_pairs
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
with torch.no_grad():
|
||||
alphas_cpu = alphas.detach().cpu()
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
if self.mode == "urs":
|
||||
feature = cell.forward_urs(feature)
|
||||
elif self.mode == "select":
|
||||
feature = cell.forward_select(feature, alphas_cpu)
|
||||
elif self.mode == "joint":
|
||||
feature = cell.forward_joint(feature, alphas)
|
||||
elif self.mode == "dynamic":
|
||||
feature = cell.forward_dynamic(feature, self.dynamic_cell)
|
||||
else:
|
||||
raise ValueError("invalid mode={:}".format(self.mode))
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -1,205 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
######################################################################################
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||
######################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import List, Text, Dict
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkSETN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C: int,
|
||||
N: int,
|
||||
steps: int,
|
||||
multiplier: int,
|
||||
stem_multiplier: int,
|
||||
num_classes: int,
|
||||
search_space: List[Text],
|
||||
affine: bool,
|
||||
track_running_stats: bool,
|
||||
):
|
||||
super(NASNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.arch_reduce_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.mode = "urs"
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
assert mode in ["urs", "joint", "select", "dynamic"]
|
||||
self.mode = mode
|
||||
if mode == "dynamic":
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_normal_parameters, self.arch_reduce_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
|
||||
)
|
||||
B = "arch-reduce-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}\n{:}".format(A, B)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self.op_names)
|
||||
else:
|
||||
weights = alphas_cpu[self.edge2index[node_str]]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
gene_reduce = _parse(
|
||||
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
"reduce": gene_reduce,
|
||||
"reduce_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
def forward(self, inputs):
|
||||
normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
|
||||
reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
# [TODO]
|
||||
raise NotImplementedError
|
||||
if cell.reduction:
|
||||
hardwts, index = reduce_hardwts, reduce_index
|
||||
else:
|
||||
hardwts, index = normal_hardwts, normal_index
|
||||
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
|
||||
out = self.lastact(s1)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
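The SETN models above switch behavior through set_cal_mode: "urs" samples sub-networks uniformly while the shared weights are trained, "joint" mixes all operations differentiably, and "select"/"dynamic" evaluate concrete candidates. A hedged sketch of the candidate-evaluation pattern for the tiny (NAS-Bench-201-style) model; `network` and `valid_loader` are placeholders:

import torch

def evaluate_setn_candidates(network, valid_loader, num_candidates=50):
    scores = []
    network.eval()
    with torch.no_grad():
        for _ in range(num_candidates):
            arch = network.dync_genotype(use_random=False)  # sample from the learned alphas
            network.set_cal_mode("dynamic", arch)
            correct, total = 0, 0
            for inputs, targets in valid_loader:
                _, logits = network(inputs)
                correct += (logits.argmax(dim=1) == targets).sum().item()
                total += targets.numel()
            scores.append((correct / total, arch))
    network.set_cal_mode("urs", None)  # restore the training mode
    return max(scores, key=lambda x: x[0])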
@@ -1,74 +0,0 @@
import torch
import torch.nn as nn


def copy_conv(module, init):
    assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init)
    new_i, new_o = module.in_channels, module.out_channels
    module.weight.copy_(init.weight.detach()[:new_o, :new_i])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:new_o])


def copy_bn(module, init):
    assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init)
    num_features = module.num_features
    if module.weight is not None:
        module.weight.copy_(init.weight.detach()[:num_features])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:num_features])
    if module.running_mean is not None:
        module.running_mean.copy_(init.running_mean.detach()[:num_features])
    if module.running_var is not None:
        module.running_var.copy_(init.running_var.detach()[:num_features])


def copy_fc(module, init):
    assert isinstance(module, nn.Linear), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Linear), "invalid module : {:}".format(init)
    new_i, new_o = module.in_features, module.out_features
    module.weight.copy_(init.weight.detach()[:new_o, :new_i])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:new_o])


def copy_base(module, init):
    assert type(module).__name__ in [
        "ConvBNReLU",
        "Downsample",
    ], "invalid module : {:}".format(module)
    assert type(init).__name__ in [
        "ConvBNReLU",
        "Downsample",
    ], "invalid module : {:}".format(init)
    if module.conv is not None:
        copy_conv(module.conv, init.conv)
    if module.bn is not None:
        copy_bn(module.bn, init.bn)


def copy_basic(module, init):
    copy_base(module.conv_a, init.conv_a)
    copy_base(module.conv_b, init.conv_b)
    if module.downsample is not None:
        if init.downsample is not None:
            copy_base(module.downsample, init.downsample)
    # else:
    #     import pdb; pdb.set_trace()


def init_from_model(network, init_model):
    with torch.no_grad():
        copy_fc(network.classifier, init_model.classifier)
        for base, target in zip(init_model.layers, network.layers):
            assert (
                type(base).__name__ == type(target).__name__
            ), "invalid type : {:} vs {:}".format(base, target)
            if type(base).__name__ == "ConvBNReLU":
                copy_base(target, base)
            elif type(base).__name__ == "ResNetBasicblock":
                copy_basic(target, base)
            else:
                raise ValueError("unknown type name : {:}".format(type(base).__name__))
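These helpers warm-start a slimmed (narrower) network from a pretrained full-width model by copying the leading channels of each weight tensor. A hedged usage sketch; the model variables and checkpoint path are placeholders, not names from this repository:

import torch

def warm_start_from_full(slim_net, full_net, ckpt_path="full_ckpt.pth"):
    # slim_net / full_net: Infer*ResNet-style models whose .layers align module-by-module;
    # ckpt_path is an illustrative file name.
    full_net.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    init_from_model(slim_net, full_net)  # copies the first out/in channels of every layer
    return slim_net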
@@ -1,16 +0,0 @@
import torch
import torch.nn as nn


def initialize_resnet(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)
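initialize_resnet is meant to be applied module-by-module through nn.Module.apply, which is how the Infer*ResNet classes below use it. A minimal illustrative usage (the small modules are examples only):

import torch.nn as nn

stem = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
head = nn.Linear(16, 10)
stem.apply(initialize_resnet)  # Kaiming-normal convs, BN weight=1 / bias=0
head.apply(initialize_resnet)  # Linear: N(0, 0.01) weights, zero bias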
@@ -1,286 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[1],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
iCs[1],
|
||||
iCs[2],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferCifarResNet(nn.Module):
|
||||
def __init__(
|
||||
self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual
|
||||
):
|
||||
super(InferCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
|
||||
|
||||
self.message = (
|
||||
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iCs,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL + 1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
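InferCifarResNet rebuilds a searched/pruned CIFAR ResNet from two lists: xblocks (how many blocks to keep per stage) and xchannels (the per-conv widths, flattened in block order). A hedged construction sketch; the widths below are illustrative values, not numbers taken from this repository, and the class is assumed importable from this package:

xblocks = [3, 3, 3]  # keep every block of a depth-20 network
xchannels = [3, 16] + [16, 16] * 3 + [32, 32] * 3 + [64, 64] * 3
net = InferCifarResNet("ResNetBasicblock", 20, xblocks, xchannels,
                       num_classes=10, zero_init_residual=False)
print(net.get_message())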
@@ -1,263 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
planes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferDepthCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
|
||||
super(InferDepthCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
|
||||
|
||||
self.message = (
|
||||
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
]
|
||||
)
|
||||
self.channels = [16]
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
planes,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.channels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -1,277 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[1],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
iCs[1],
|
||||
iCs[2],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferWidthCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
|
||||
super(InferWidthCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = (
|
||||
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iCs,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -1,324 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[1],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
iCs[1],
|
||||
iCs[2],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferImagenetResNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block_name,
|
||||
layers,
|
||||
xblocks,
|
||||
xchannels,
|
||||
deep_stem,
|
||||
num_classes,
|
||||
zero_init_residual,
|
||||
):
|
||||
super(InferImagenetResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "BasicBlock":
|
||||
block = ResNetBasicblock
|
||||
elif block_name == "Bottleneck":
|
||||
block = ResNetBottleneck
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
assert len(xblocks) == len(
|
||||
layers
|
||||
), "invalid layers : {:} vs xblocks : {:}".format(layers, xblocks)
|
||||
|
||||
self.message = "InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}".format(
|
||||
sum(layers) * block.num_conv, sum(xblocks) * block.num_conv, xblocks
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
if not deep_stem:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
7,
|
||||
2,
|
||||
3,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
last_channel_idx = 1
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
3,
|
||||
2,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
),
|
||||
ConvBNReLU(
|
||||
xchannels[1],
|
||||
xchannels[2],
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
last_channel_idx = 2
|
||||
self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||
for stage, layer_blocks in enumerate(layers):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iCs,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL + 1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
assert last_channel_idx + 1 == len(self.xchannels), "{:} vs {:}".format(
|
||||
last_channel_idx, len(self.xchannels)
|
||||
)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
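InferImagenetResNet follows the same xblocks/xchannels convention, with a deep_stem switch that replaces the single 7x7 stem by two 3x3 convolutions. A hedged, full-depth ResNet-18-style instantiation with illustrative widths (the class is assumed importable from this package):

layers, xblocks = [2, 2, 2, 2], [2, 2, 2, 2]
xchannels = [3, 64] + [64, 64] * 2 + [128, 128] * 2 + [256, 256] * 2 + [512, 512] * 2
net = InferImagenetResNet("BasicBlock", layers, xblocks, xchannels,
                          deep_stem=False, num_classes=1000, zero_init_residual=True)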
@@ -1,174 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn

from ..initialization import initialize_resnet
from ..SharedUtils import parse_channel_info


class ConvBNReLU(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
        if has_bn:
            self.bn = nn.BatchNorm2d(out_planes)
        else:
            self.bn = None
        if has_relu:
            self.relu = nn.ReLU6(inplace=True)
        else:
            self.relu = None

    def forward(self, x):
        out = self.conv(x)
        if self.bn:
            out = self.bn(out)
        if self.relu:
            out = self.relu(out)
        return out


class InvertedResidual(nn.Module):
    def __init__(self, channels, stride, expand_ratio, additive):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2], "invalid stride : {:}".format(stride)
        assert len(channels) in [2, 3], "invalid channels : {:}".format(channels)

        if len(channels) == 2:
            layers = []
        else:
            layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
        layers.extend(
            [
                # dw
                ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
                # pw-linear
                ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
            ]
        )
        self.conv = nn.Sequential(*layers)
        self.additive = additive
        if self.additive and channels[0] != channels[-1]:
            self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
        else:
            self.shortcut = None
        self.out_dim = channels[-1]

    def forward(self, x):
        out = self.conv(x)
        # if self.additive: return additive_func(out, x)
        if self.shortcut:
            return out + self.shortcut(x)
        else:
            return out


class InferMobileNetV2(nn.Module):
    def __init__(self, num_classes, xchannels, xblocks, dropout):
        super(InferMobileNetV2, self).__init__()
        block = InvertedResidual
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        assert len(inverted_residual_setting) == len(xblocks), "invalid number of layers : {:} vs {:}".format(
            len(inverted_residual_setting), len(xblocks)
        )
        for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
            assert block_num <= ir_setting[2], "{:} vs {:}".format(block_num, ir_setting)
        xchannels = parse_channel_info(xchannels)
        # for i, chs in enumerate(xchannels):
        #     if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
        self.xchannels = xchannels
        self.message = "InferMobileNetV2 : xblocks={:}".format(xblocks)
        # building first layer
        features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
        last_channel_idx = 1

        # building inverted residual blocks
        for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
            for i in range(n):
                stride = s if i == 0 else 1
                additv = True if i > 0 else False
                module = block(self.xchannels[last_channel_idx], stride, t, additv)
                features.append(module)
                self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(
                    stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c
                )
                last_channel_idx += 1
                if i + 1 == xblocks[stage]:
                    out_channel = module.out_dim
                    for iiL in range(i + 1, n):
                        last_channel_idx += 1
                        self.xchannels[last_channel_idx][0] = module.out_dim
                    break
        # building last several layers
        features.append(
            ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1)
        )
        assert last_channel_idx + 2 == len(self.xchannels), "{:} vs {:}".format(last_channel_idx, len(self.xchannels))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
        )

        # weight initialization
        self.apply(initialize_resnet)

    def get_message(self):
        return self.message

    def forward(self, inputs):
        features = self.features(inputs)
        vectors = features.mean([2, 3])
        predicts = self.classifier(vectors)
        return features, predicts
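# Illustrative note (not part of the original file): each row of
# `inverted_residual_setting` follows the standard MobileNetV2 convention
#   t = expansion ratio, c = output channels, n = number of repeated blocks,
#   s = stride of the first block in the stage.
# For example, [6, 24, 2, 2] describes a stage of two InvertedResidual blocks whose
# first block uses stride 2. In this pruned variant the actual widths come from
# `xchannels`; `c` is only reported as "original-C" in the log message.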
@@ -1,64 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from typing import List, Text, Any
import torch.nn as nn

from models.cell_operations import ResNetBasicblock
from models.cell_infers.cells import InferCell


class DynamicShapeTinyNet(nn.Module):
    def __init__(self, channels: List[int], genotype: Any, num_classes: int):
        super(DynamicShapeTinyNet, self).__init__()
        self._channels = channels
        if len(channels) % 3 != 2:
            raise ValueError("invalid number of layers : {:}".format(len(channels)))
        self._num_stage = N = len(channels) // 3

        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
        )

        # layer_channels   = [C    ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        c_prev = channels[0]
        self.cells = nn.ModuleList()
        for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(c_prev, c_curr, 2, True)
            else:
                cell = InferCell(genotype, c_prev, c_curr, 1)
            self.cells.append(cell)
            c_prev = cell.out_dim
        self._num_layer = len(self.cells)

        self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(c_prev, num_classes)

    def get_message(self) -> Text:
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return "{name}(C={_channels}, N={_num_stage}, L={_num_layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
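# Illustrative note (not part of the original file): `channels` must satisfy
# len(channels) % 3 == 2 because the network stacks N normal cells, one reduction
# cell, N normal cells, another reduction cell, and N more normal cells, i.e.
# 3*N + 2 per-cell channel counts. For example, N = 5 requires a list of 17 entries.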
@@ -1,9 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
from .InferTinyCellNet import DynamicShapeTinyNet
@@ -1,5 +0,0 @@
def parse_channel_info(xstring):
    blocks = xstring.split(" ")
    blocks = [x.split("-") for x in blocks]
    blocks = [[int(_) for _ in x] for x in blocks]
    return blocks
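# Illustrative usage sketch (not part of the original file): the channel string is a
# space-separated list of dash-joined integers, e.g.
#   parse_channel_info("3-16 16-32 32-64")  ->  [[3, 16], [16, 32], [32, 64]]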
@@ -1,760 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices


def get_depth_choices(nDepth, return_num):
    if nDepth == 2:
        choices = (1, 2)
    elif nDepth == 3:
        choices = (1, 2, 3)
    elif nDepth > 3:
        choices = list(range(1, nDepth + 1, 2))
        if choices[-1] < nDepth:
            choices.append(nDepth)
    else:
        raise ValueError("invalid nDepth : {:}".format(nDepth))
    if return_num:
        return len(choices)
    else:
        return choices


def conv_forward(inputs, conv, choices):
    iC = conv.in_channels
    fill_size = list(inputs.size())
    fill_size[1] = iC - fill_size[1]
    filled = torch.zeros(fill_size, device=inputs.device)
    xinputs = torch.cat((inputs, filled), dim=1)
    outputs = conv(xinputs)
    selecteds = [outputs[:, :oC] for oC in choices]
    return selecteds
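# Illustrative note (not part of the original file): get_depth_choices enumerates the
# candidate per-stage depths, e.g. get_depth_choices(18, False) -> [1, 3, 5, 7, 9, 11, 13, 15, 17, 18]
# (a ResNet-110 basicblock stage has (110 - 2) // 6 = 18 blocks). conv_forward zero-pads the
# input up to conv.in_channels, runs the full convolution once, and slices the first oC
# output channels for every width candidate, so a single forward pass serves all choices.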
class ConvBNReLU(nn.Module):
    num_conv = 1

    def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
        super(ConvBNReLU, self).__init__()
        self.InShape = None
        self.OutShape = None
        self.choices = get_width_choices(nOut)
        self.register_buffer("choices_tensor", torch.Tensor(self.choices))

        if has_avg:
            self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.avg = None
        self.conv = nn.Conv2d(
            nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias
        )
        # if has_bn : self.bn = nn.BatchNorm2d(nOut)
        # else      : self.bn = None
        self.has_bn = has_bn
        self.BNs = nn.ModuleList()
        for i, _out in enumerate(self.choices):
            self.BNs.append(nn.BatchNorm2d(_out))
        if has_relu:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None
        self.in_dim = nIn
        self.out_dim = nOut
        self.search_mode = "basic"

    def get_flops(self, channels, check_range=True, divide=1):
        iC, oC = channels
        if check_range:
            assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, "{:} vs {:} | {:} vs {:}".format(
                iC, self.conv.in_channels, oC, self.conv.out_channels
            )
        assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, "invalid in-shape : {:}".format(self.InShape)
        assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, "invalid out-shape : {:}".format(self.OutShape)
        # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
        conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
        all_positions = self.OutShape[0] * self.OutShape[1]
        flops = (conv_per_position_flops * all_positions / divide) * iC * oC
        if self.conv.bias is not None:
            flops += all_positions / divide
        return flops

    def get_range(self):
        return [self.choices]

    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def search_forward(self, tuple_inputs):
        assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, "invalid type input : {:}".format(
            type(tuple_inputs)
        )
        inputs, expected_inC, probability, index, prob = tuple_inputs
        index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
        probability = torch.squeeze(probability)
        assert len(index) == 2, "invalid length : {:}".format(index)
        # compute expected flop
        # coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
        expected_outC = (self.choices_tensor * probability).sum()
        expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        # convolutional layer
        out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
        out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
        # merge
        out_channel = max([x.size(1) for x in out_bns])
        outA = ChannelWiseInter(out_bns[0], out_channel)
        outB = ChannelWiseInter(out_bns[1], out_channel)
        out = outA * prob[0] + outB * prob[1]
        # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])

        if self.relu:
            out = self.relu(out)
        else:
            out = out
        return out, expected_outC, expected_flop

    def basic_forward(self, inputs):
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        conv = self.conv(out)
        if self.has_bn:
            out = self.BNs[-1](conv)
        else:
            out = conv
        if self.relu:
            out = self.relu(out)
        else:
            out = out
        if self.InShape is None:
            self.InShape = (inputs.size(-2), inputs.size(-1))
            self.OutShape = (out.size(-2), out.size(-1))
        return out
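# Illustrative worked example (not part of the original file), using the FLOP formula in
# ConvBNReLU.get_flops above: for a 3x3 convolution (groups=1) with iC=16, oC=32 and a
# 32x32 output map, flops = (3*3/1) * (32*32) * 16 * 32 = 4,718,592; with divide=1e6 the
# same call reports roughly 4.72 MFLOPs (a bias adds one op per output position).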
class ResNetBasicblock(nn.Module):
    expansion = 1
    num_conv = 2

    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
        self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
        if stride == 2:
            self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
        elif inplanes != planes:
            self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
        else:
            self.downsample = None
        self.out_dim = planes
        self.search_mode = "basic"

    def get_range(self):
        return self.conv_a.get_range() + self.conv_b.get_range()

    def get_flops(self, channels):
        assert len(channels) == 3, "invalid channels : {:}".format(channels)
        flop_A = self.conv_a.get_flops([channels[0], channels[1]])
        flop_B = self.conv_b.get_flops([channels[1], channels[2]])
        if hasattr(self.downsample, "get_flops"):
            flop_C = self.downsample.get_flops([channels[0], channels[-1]])
        else:
            flop_C = 0
        if channels[0] != channels[-1] and self.downsample is None:  # this short-cut will be added during the infer-train
            flop_C = channels[0] * channels[-1] * self.conv_b.OutShape[0] * self.conv_b.OutShape[1]
        return flop_A + flop_B + flop_C

    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def search_forward(self, tuple_inputs):
        assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, "invalid type input : {:}".format(
            type(tuple_inputs)
        )
        inputs, expected_inC, probability, indexes, probs = tuple_inputs
        assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
        out_a, expected_inC_a, expected_flop_a = self.conv_a(
            (inputs, expected_inC, probability[0], indexes[0], probs[0])
        )
        out_b, expected_inC_b, expected_flop_b = self.conv_b(
            (out_a, expected_inC_a, probability[1], indexes[1], probs[1])
        )
        if self.downsample is not None:
            residual, _, expected_flop_c = self.downsample(
                (inputs, expected_inC, probability[1], indexes[1], probs[1])
            )
        else:
            residual, expected_flop_c = inputs, 0
        out = additive_func(residual, out_b)
        return (
            nn.functional.relu(out, inplace=True),
            expected_inC_b,
            sum([expected_flop_a, expected_flop_b, expected_flop_c]),
        )

    def basic_forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)
        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, basicblock)
        return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
    expansion = 4
    num_conv = 3

    def __init__(self, inplanes, planes, stride):
        super(ResNetBottleneck, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
        self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
        self.conv_1x4 = ConvBNReLU(
            planes, planes * self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
        )
        if stride == 2:
            self.downsample = ConvBNReLU(
                inplanes, planes * self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False
            )
        elif inplanes != planes * self.expansion:
            self.downsample = ConvBNReLU(
                inplanes, planes * self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
            )
        else:
            self.downsample = None
        self.out_dim = planes * self.expansion
        self.search_mode = "basic"

    def get_range(self):
        return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()

    def get_flops(self, channels):
        assert len(channels) == 4, "invalid channels : {:}".format(channels)
        flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
        flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
        flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
        if hasattr(self.downsample, "get_flops"):
            flop_D = self.downsample.get_flops([channels[0], channels[-1]])
        else:
            flop_D = 0
        if channels[0] != channels[-1] and self.downsample is None:  # this short-cut will be added during the infer-train
            flop_D = channels[0] * channels[-1] * self.conv_1x4.OutShape[0] * self.conv_1x4.OutShape[1]
        return flop_A + flop_B + flop_C + flop_D

    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def basic_forward(self, inputs):
        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)
        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, bottleneck)
        return nn.functional.relu(out, inplace=True)

    def search_forward(self, tuple_inputs):
        assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, "invalid type input : {:}".format(
            type(tuple_inputs)
        )
        inputs, expected_inC, probability, indexes, probs = tuple_inputs
        assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
        out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
            (inputs, expected_inC, probability[0], indexes[0], probs[0])
        )
        out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
            (out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
        )
        out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
            (out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
        )
        if self.downsample is not None:
            residual, _, expected_flop_c = self.downsample(
                (inputs, expected_inC, probability[2], indexes[2], probs[2])
            )
        else:
            residual, expected_flop_c = inputs, 0
        out = additive_func(residual, out_1x4)
        return (
            nn.functional.relu(out, inplace=True),
            expected_inC_1x4,
            sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c]),
        )
class SearchShapeCifarResNet(nn.Module):
    def __init__(self, block_name, depth, num_classes):
        super(SearchShapeCifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        if block_name == "ResNetBasicblock":
            block = ResNetBasicblock
            assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
            layer_blocks = (depth - 2) // 6
        elif block_name == "ResNetBottleneck":
            block = ResNetBottleneck
            assert (depth - 2) % 9 == 0, "depth should be one of 164"
            layer_blocks = (depth - 2) // 9
        else:
            raise ValueError("invalid block : {:}".format(block_name))

        self.message = "SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format(depth, layer_blocks)
        self.num_classes = num_classes
        self.channels = [16]
        self.layers = nn.ModuleList(
            [ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True)]
        )
        self.InShape = None
        self.depth_info = OrderedDict()
        self.depth_at_i = OrderedDict()
        for stage in range(3):
            cur_block_choices = get_depth_choices(layer_blocks, False)
            assert cur_block_choices[-1] == layer_blocks, "stage={:}, {:} vs {:}".format(
                stage, cur_block_choices, layer_blocks
            )
            self.message += "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(
                stage, cur_block_choices, layer_blocks
            )
            block_choices, xstart = [], len(self.layers)
            for iL in range(layer_blocks):
                iC = self.channels[-1]
                planes = 16 * (2 ** stage)
                stride = 2 if stage > 0 and iL == 0 else 1
                module = block(iC, planes, stride)
                self.channels.append(module.out_dim)
                self.layers.append(module)
                self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
                    stage, iL, layer_blocks, len(self.layers) - 1, iC, module.out_dim, stride
                )
                # added for depth
                layer_index = len(self.layers) - 1
                if iL + 1 in cur_block_choices:
                    block_choices.append(layer_index)
                if iL + 1 == layer_blocks:
                    self.depth_info[layer_index] = {
                        "choices": block_choices,
                        "stage": stage,
                        "xstart": xstart,
                    }
        self.depth_info_list = []
        for xend, info in self.depth_info.items():
            self.depth_info_list.append((xend, info))
            xstart, xstage = info["xstart"], info["stage"]
            for ilayer in range(xstart, xend + 1):
                idx = bisect_right(info["choices"], ilayer - 1)
                self.depth_at_i[ilayer] = (xstage, idx)

        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(module.out_dim, num_classes)
        self.InShape = None
        self.tau = -1
        self.search_mode = "basic"
        # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)

        # parameters for width
        self.Ranges = []
        self.layer2indexRange = []
        for i, layer in enumerate(self.layers):
            start_index = len(self.Ranges)
            self.Ranges += layer.get_range()
            self.layer2indexRange.append((start_index, len(self.Ranges)))
        assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(len(self.Ranges) + 1, depth)

        self.register_parameter(
            "width_attentions",
            nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))),
        )
        self.register_parameter(
            "depth_attentions",
            nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))),
        )
        nn.init.normal_(self.width_attentions, 0, 0.01)
        nn.init.normal_(self.depth_attentions, 0, 0.01)
        self.apply(initialize_resnet)
    def arch_parameters(self, LR=None):
        if LR is None:
            return [self.width_attentions, self.depth_attentions]
        else:
            return [
                {"params": self.width_attentions, "lr": LR},
                {"params": self.depth_attentions, "lr": LR},
            ]

    def base_parameters(self):
        return (
            list(self.layers.parameters())
            + list(self.avgpool.parameters())
            + list(self.classifier.parameters())
        )

    def get_flop(self, mode, config_dict, extra_info):
        if config_dict is not None:
            config_dict = config_dict.copy()
        # select channels
        channels = [3]
        for i, weight in enumerate(self.width_attentions):
            if mode == "genotype":
                with torch.no_grad():
                    probe = nn.functional.softmax(weight, dim=0)
                    C = self.Ranges[i][torch.argmax(probe).item()]
            elif mode == "max":
                C = self.Ranges[i][-1]
            elif mode == "fix":
                C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
            elif mode == "random":
                assert isinstance(extra_info, float), "invalid extra_info : {:}".format(extra_info)
                with torch.no_grad():
                    prob = nn.functional.softmax(weight, dim=0)
                    approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
                    for j in range(prob.size(0)):
                        prob[j] = 1 / (abs(j - (approximate_C - self.Ranges[i][j])) + 0.2)
                    C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
            else:
                raise ValueError("invalid mode : {:}".format(mode))
            channels.append(C)
        # select depth
        if mode == "genotype":
            with torch.no_grad():
                depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
                choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
        elif mode == "max" or mode == "fix":
            # use the attention size directly so that "max" / "fix" do not depend on
            # depth_probs, which is only computed in the "genotype" / "random" branches
            choices = [self.depth_attentions.size(1) - 1 for _ in range(self.depth_attentions.size(0))]
        elif mode == "random":
            with torch.no_grad():
                depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
                choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
        else:
            raise ValueError("invalid mode : {:}".format(mode))
        selected_layers = []
        for choice, xvalue in zip(choices, self.depth_info_list):
            xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
            selected_layers.append(xtemp)
        flop = 0
        for i, layer in enumerate(self.layers):
            s, e = self.layer2indexRange[i]
            xchl = tuple(channels[s : e + 1])
            if i in self.depth_at_i:
                xstagei, xatti = self.depth_at_i[i]
                if xatti <= choices[xstagei]:  # keep this layer at the chosen depth
                    flop += layer.get_flops(xchl)
                else:
                    flop += 0  # do not use this layer
            else:
                flop += layer.get_flops(xchl)
        # the last fc layer
        flop += channels[-1] * self.classifier.out_features
        if config_dict is None:
            return flop / 1e6
        else:
            config_dict["xchannels"] = channels
            config_dict["xblocks"] = selected_layers
            config_dict["super_type"] = "infer-shape"
            config_dict["estimated_FLOP"] = flop / 1e6
            return flop / 1e6, config_dict
    def get_arch_info(self):
        string = "for depth and width, there are {:} + {:} attention probabilities.".format(
            len(self.depth_attentions), len(self.width_attentions)
        )
        string += "\n{:}".format(self.depth_info)
        discrepancy = []
        with torch.no_grad():
            for i, att in enumerate(self.depth_attentions):
                prob = nn.functional.softmax(att, dim=0)
                prob = prob.cpu()
                selc = prob.argmax().item()
                prob = prob.tolist()
                prob = ["{:.3f}".format(x) for x in prob]
                xstring = "{:03d}/{:03d}-th : {:}".format(i, len(self.depth_attentions), " ".join(prob))
                logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
                xstring += " || {:17s}".format(" ".join(logt))
                prob = sorted([float(x) for x in prob])
                disc = prob[-1] - prob[-2]
                xstring += " || discrepancy={:.2f} || select={:}/{:}".format(disc, selc, len(prob))
                discrepancy.append(disc)
                string += "\n{:}".format(xstring)
            string += "\n-----------------------------------------------"
            for i, att in enumerate(self.width_attentions):
                prob = nn.functional.softmax(att, dim=0)
                prob = prob.cpu()
                selc = prob.argmax().item()
                prob = prob.tolist()
                prob = ["{:.3f}".format(x) for x in prob]
                xstring = "{:03d}/{:03d}-th : {:}".format(i, len(self.width_attentions), " ".join(prob))
                logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
                xstring += " || {:52s}".format(" ".join(logt))
                prob = sorted([float(x) for x in prob])
                disc = prob[-1] - prob[-2]
                xstring += " || dis={:.2f} || select={:}/{:}".format(disc, selc, len(prob))
                discrepancy.append(disc)
                string += "\n{:}".format(xstring)
        return string, discrepancy

    def set_tau(self, tau_max, tau_min, epoch_ratio):
        assert epoch_ratio >= 0 and epoch_ratio <= 1, "invalid epoch-ratio : {:}".format(epoch_ratio)
        tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
        self.tau = tau

    def get_message(self):
        return self.message
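# Illustrative worked example (not part of the original file): set_tau follows a cosine
# schedule from tau_max down to tau_min. With tau_max = 10 and tau_min = 0.1:
#   epoch_ratio = 0.0 -> tau = 0.1 + 9.9 * (1 + cos(0)) / 2      = 10.0
#   epoch_ratio = 0.5 -> tau = 0.1 + 9.9 * (1 + cos(pi/2)) / 2   = 5.05
#   epoch_ratio = 1.0 -> tau = 0.1 + 9.9 * (1 + cos(pi)) / 2     = 0.1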
    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def search_forward(self, inputs):
        flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
        flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
        flop_depth_probs = torch.flip(torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1])
        selected_widths, selected_width_probs = select2withP(self.width_attentions, self.tau)
        selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
        with torch.no_grad():
            selected_widths = selected_widths.cpu()

        x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
        feature_maps = []
        for i, layer in enumerate(self.layers):
            selected_w_index = selected_widths[last_channel_idx : last_channel_idx + layer.num_conv]
            selected_w_probs = selected_width_probs[last_channel_idx : last_channel_idx + layer.num_conv]
            layer_prob = flop_width_probs[last_channel_idx : last_channel_idx + layer.num_conv]
            x, expected_inC, expected_flop = layer(
                (x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
            )
            feature_maps.append(x)
            last_channel_idx += layer.num_conv
            if i in self.depth_info:  # aggregate the information
                choices = self.depth_info[i]["choices"]
                xstagei = self.depth_info[i]["stage"]
                # print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
                # for A, W in zip(choices, selected_depth_probs[xstagei]):
                #     print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
                possible_tensors = []
                max_C = max(feature_maps[A].size(1) for A in choices)
                for tempi, A in enumerate(choices):
                    xtensor = ChannelWiseInter(feature_maps[A], max_C)
                    # drop_ratio = 1-(tempi+1.0)/len(choices)
                    # xtensor = drop_path(xtensor, drop_ratio)
                    possible_tensors.append(xtensor)
                weighted_sum = sum(
                    xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei])
                )
                x = weighted_sum

            if i in self.depth_at_i:
                xstagei, xatti = self.depth_at_i[i]
                x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
            else:
                x_expected_flop = expected_flop
            flops.append(x_expected_flop)
        flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = linear_forward(features, self.classifier)
        return logits, torch.stack([sum(flops)])

    def basic_forward(self, inputs):
        if self.InShape is None:
            self.InShape = (inputs.size(-2), inputs.size(-1))
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = self.classifier(features)
        return features, logits
@@ -1,515 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
from collections import OrderedDict
from bisect import bisect_right
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices


def get_depth_choices(nDepth, return_num):
    if nDepth == 2:
        choices = (1, 2)
    elif nDepth == 3:
        choices = (1, 2, 3)
    elif nDepth > 3:
        choices = list(range(1, nDepth + 1, 2))
        if choices[-1] < nDepth:
            choices.append(nDepth)
    else:
        raise ValueError("invalid nDepth : {:}".format(nDepth))
    if return_num:
        return len(choices)
    else:
        return choices
class ConvBNReLU(nn.Module):
    num_conv = 1

    def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
        super(ConvBNReLU, self).__init__()
        self.InShape = None
        self.OutShape = None
        self.choices = get_width_choices(nOut)
        self.register_buffer("choices_tensor", torch.Tensor(self.choices))

        if has_avg:
            self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.avg = None
        self.conv = nn.Conv2d(
            nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias
        )
        if has_bn:
            self.bn = nn.BatchNorm2d(nOut)
        else:
            self.bn = None
        if has_relu:
            self.relu = nn.ReLU(inplace=False)
        else:
            self.relu = None
        self.in_dim = nIn
        self.out_dim = nOut

    def get_flops(self, divide=1):
        iC, oC = self.in_dim, self.out_dim
        assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, "{:} vs {:} | {:} vs {:}".format(
            iC, self.conv.in_channels, oC, self.conv.out_channels
        )
        assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, "invalid in-shape : {:}".format(self.InShape)
        assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, "invalid out-shape : {:}".format(self.OutShape)
        # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
        conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
        all_positions = self.OutShape[0] * self.OutShape[1]
        flops = (conv_per_position_flops * all_positions / divide) * iC * oC
        if self.conv.bias is not None:
            flops += all_positions / divide
        return flops

    def forward(self, inputs):
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        conv = self.conv(out)
        if self.bn:
            out = self.bn(conv)
        else:
            out = conv
        if self.relu:
            out = self.relu(out)
        else:
            out = out
        if self.InShape is None:
            self.InShape = (inputs.size(-2), inputs.size(-1))
            self.OutShape = (out.size(-2), out.size(-1))
        return out
class ResNetBasicblock(nn.Module):
    expansion = 1
    num_conv = 2

    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
        self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
        if stride == 2:
            self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
        elif inplanes != planes:
            self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
        else:
            self.downsample = None
        self.out_dim = planes
        self.search_mode = "basic"

    def get_flops(self, divide=1):
        flop_A = self.conv_a.get_flops(divide)
        flop_B = self.conv_b.get_flops(divide)
        if hasattr(self.downsample, "get_flops"):
            flop_C = self.downsample.get_flops(divide)
        else:
            flop_C = 0
        return flop_A + flop_B + flop_C

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)
        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, basicblock)
        return nn.functional.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
    expansion = 4
    num_conv = 3

    def __init__(self, inplanes, planes, stride):
        super(ResNetBottleneck, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
        self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
        self.conv_1x4 = ConvBNReLU(
            planes, planes * self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
        )
        if stride == 2:
            self.downsample = ConvBNReLU(
                inplanes, planes * self.expansion, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False
            )
        elif inplanes != planes * self.expansion:
            self.downsample = ConvBNReLU(
                inplanes, planes * self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
            )
        else:
            self.downsample = None
        self.out_dim = planes * self.expansion
        self.search_mode = "basic"

    def get_range(self):
        return self.conv_1x1.get_range() + self.conv_3x3.get_range() + self.conv_1x4.get_range()

    def get_flops(self, divide):
        flop_A = self.conv_1x1.get_flops(divide)
        flop_B = self.conv_3x3.get_flops(divide)
        flop_C = self.conv_1x4.get_flops(divide)
        if hasattr(self.downsample, "get_flops"):
            flop_D = self.downsample.get_flops(divide)
        else:
            flop_D = 0
        return flop_A + flop_B + flop_C + flop_D

    def forward(self, inputs):
        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)
        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, bottleneck)
        return nn.functional.relu(out, inplace=True)
class SearchDepthCifarResNet(nn.Module):
    def __init__(self, block_name, depth, num_classes):
        super(SearchDepthCifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        if block_name == "ResNetBasicblock":
            block = ResNetBasicblock
            assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
            layer_blocks = (depth - 2) // 6
        elif block_name == "ResNetBottleneck":
            block = ResNetBottleneck
            assert (depth - 2) % 9 == 0, "depth should be one of 164"
            layer_blocks = (depth - 2) // 9
        else:
            raise ValueError("invalid block : {:}".format(block_name))

        self.message = "SearchDepthCifarResNet : Depth : {:} , Layers for each block : {:}".format(depth, layer_blocks)
        self.num_classes = num_classes
        self.channels = [16]
        self.layers = nn.ModuleList(
            [ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True)]
        )
        self.InShape = None
        self.depth_info = OrderedDict()
        self.depth_at_i = OrderedDict()
        for stage in range(3):
            cur_block_choices = get_depth_choices(layer_blocks, False)
            assert cur_block_choices[-1] == layer_blocks, "stage={:}, {:} vs {:}".format(
                stage, cur_block_choices, layer_blocks
            )
            self.message += "\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(
                stage, cur_block_choices, layer_blocks
            )
            block_choices, xstart = [], len(self.layers)
            for iL in range(layer_blocks):
                iC = self.channels[-1]
                planes = 16 * (2 ** stage)
                stride = 2 if stage > 0 and iL == 0 else 1
                module = block(iC, planes, stride)
                self.channels.append(module.out_dim)
                self.layers.append(module)
                self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
                    stage, iL, layer_blocks, len(self.layers) - 1, iC, module.out_dim, stride
                )
                # added for depth
                layer_index = len(self.layers) - 1
                if iL + 1 in cur_block_choices:
                    block_choices.append(layer_index)
                if iL + 1 == layer_blocks:
                    self.depth_info[layer_index] = {
                        "choices": block_choices,
                        "stage": stage,
                        "xstart": xstart,
                    }
        self.depth_info_list = []
        for xend, info in self.depth_info.items():
            self.depth_info_list.append((xend, info))
            xstart, xstage = info["xstart"], info["stage"]
            for ilayer in range(xstart, xend + 1):
                idx = bisect_right(info["choices"], ilayer - 1)
                self.depth_at_i[ilayer] = (xstage, idx)

        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(module.out_dim, num_classes)
        self.InShape = None
        self.tau = -1
        self.search_mode = "basic"
        # assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)

        self.register_parameter(
            "depth_attentions",
            nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))),
        )
        nn.init.normal_(self.depth_attentions, 0, 0.01)
        self.apply(initialize_resnet)
    def arch_parameters(self):
        return [self.depth_attentions]

    def base_parameters(self):
        return (
            list(self.layers.parameters())
            + list(self.avgpool.parameters())
            + list(self.classifier.parameters())
        )

    def get_flop(self, mode, config_dict, extra_info):
        if config_dict is not None:
            config_dict = config_dict.copy()
        # select depth
        if mode == "genotype":
            with torch.no_grad():
                depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
                choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
        elif mode == "max":
            # use the attention size directly so that "max" does not depend on
            # depth_probs, which is only computed in the "genotype" / "random" branches
            choices = [self.depth_attentions.size(1) - 1 for _ in range(self.depth_attentions.size(0))]
        elif mode == "random":
            with torch.no_grad():
                depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
                choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
        else:
            raise ValueError("invalid mode : {:}".format(mode))
        selected_layers = []
        for choice, xvalue in zip(choices, self.depth_info_list):
            xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
            selected_layers.append(xtemp)
        flop = 0
        for i, layer in enumerate(self.layers):
            if i in self.depth_at_i:
                xstagei, xatti = self.depth_at_i[i]
                if xatti <= choices[xstagei]:  # keep this layer at the chosen depth
                    flop += layer.get_flops()
                else:
                    flop += 0  # do not use this layer
            else:
                flop += layer.get_flops()
        # the last fc layer
        flop += self.classifier.in_features * self.classifier.out_features
        if config_dict is None:
            return flop / 1e6
        else:
            config_dict["xblocks"] = selected_layers
            config_dict["super_type"] = "infer-depth"
            config_dict["estimated_FLOP"] = flop / 1e6
            return flop / 1e6, config_dict
    def get_arch_info(self):
        string = "for depth, there are {:} attention probabilities.".format(len(self.depth_attentions))
        string += "\n{:}".format(self.depth_info)
        discrepancy = []
        with torch.no_grad():
            for i, att in enumerate(self.depth_attentions):
                prob = nn.functional.softmax(att, dim=0)
                prob = prob.cpu()
                selc = prob.argmax().item()
                prob = prob.tolist()
                prob = ["{:.3f}".format(x) for x in prob]
                xstring = "{:03d}/{:03d}-th : {:}".format(i, len(self.depth_attentions), " ".join(prob))
                logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
                xstring += " || {:17s}".format(" ".join(logt))
                prob = sorted([float(x) for x in prob])
                disc = prob[-1] - prob[-2]
                xstring += " || discrepancy={:.2f} || select={:}/{:}".format(disc, selc, len(prob))
                discrepancy.append(disc)
                string += "\n{:}".format(xstring)
        return string, discrepancy

    def set_tau(self, tau_max, tau_min, epoch_ratio):
        assert epoch_ratio >= 0 and epoch_ratio <= 1, "invalid epoch-ratio : {:}".format(epoch_ratio)
        tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
        self.tau = tau

    def get_message(self):
        return self.message
    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def search_forward(self, inputs):
        flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
        flop_depth_probs = torch.flip(torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1])
        selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)

        x, flops = inputs, []
        feature_maps = []
        for i, layer in enumerate(self.layers):
            layer_i = layer(x)
            feature_maps.append(layer_i)
            if i in self.depth_info:  # aggregate the information
                choices = self.depth_info[i]["choices"]
                xstagei = self.depth_info[i]["stage"]
                possible_tensors = []
                for tempi, A in enumerate(choices):
                    xtensor = feature_maps[A]
                    possible_tensors.append(xtensor)
                weighted_sum = sum(
                    xtensor * W for xtensor, W in zip(possible_tensors, selected_depth_probs[xstagei])
                )
                x = weighted_sum
            else:
                x = layer_i

            if i in self.depth_at_i:
                xstagei, xatti = self.depth_at_i[i]
                # print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6)))
                x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops(1e6)
            else:
                x_expected_flop = layer.get_flops(1e6)
            flops.append(x_expected_flop)
        flops.append(self.classifier.in_features * self.classifier.out_features * 1.0 / 1e6)

        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = linear_forward(features, self.classifier)
        return logits, torch.stack([sum(flops)])

    def basic_forward(self, inputs):
        if self.InShape is None:
            self.InShape = (inputs.size(-2), inputs.size(-1))
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = self.classifier(features)
        return features, logits
@@ -1,619 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices


def conv_forward(inputs, conv, choices):
    iC = conv.in_channels
    fill_size = list(inputs.size())
    fill_size[1] = iC - fill_size[1]
    filled = torch.zeros(fill_size, device=inputs.device)
    xinputs = torch.cat((inputs, filled), dim=1)
    outputs = conv(xinputs)
    selecteds = [outputs[:, :oC] for oC in choices]
    return selecteds
class ConvBNReLU(nn.Module):
    num_conv = 1

    def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
        super(ConvBNReLU, self).__init__()
        self.InShape = None
        self.OutShape = None
        self.choices = get_choices(nOut)
        self.register_buffer("choices_tensor", torch.Tensor(self.choices))

        if has_avg:
            self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.avg = None
        self.conv = nn.Conv2d(
            nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias
        )
        # if has_bn : self.bn = nn.BatchNorm2d(nOut)
        # else      : self.bn = None
        self.has_bn = has_bn
        self.BNs = nn.ModuleList()
        for i, _out in enumerate(self.choices):
            self.BNs.append(nn.BatchNorm2d(_out))
        if has_relu:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None
        self.in_dim = nIn
        self.out_dim = nOut
        self.search_mode = "basic"

    def get_flops(self, channels, check_range=True, divide=1):
        iC, oC = channels
        if check_range:
            assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, "{:} vs {:} | {:} vs {:}".format(
                iC, self.conv.in_channels, oC, self.conv.out_channels
            )
        assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, "invalid in-shape : {:}".format(self.InShape)
        assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, "invalid out-shape : {:}".format(self.OutShape)
        # conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
        conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
        all_positions = self.OutShape[0] * self.OutShape[1]
        flops = (conv_per_position_flops * all_positions / divide) * iC * oC
        if self.conv.bias is not None:
            flops += all_positions / divide
        return flops

    def get_range(self):
        return [self.choices]

    def forward(self, inputs):
        if self.search_mode == "basic":
            return self.basic_forward(inputs)
        elif self.search_mode == "search":
            return self.search_forward(inputs)
        else:
            raise ValueError("invalid search_mode = {:}".format(self.search_mode))

    def search_forward(self, tuple_inputs):
        assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, "invalid type input : {:}".format(
            type(tuple_inputs)
        )
        inputs, expected_inC, probability, index, prob = tuple_inputs
        index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
        probability = torch.squeeze(probability)
        assert len(index) == 2, "invalid length : {:}".format(index)
        # compute expected flop
        # coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
        expected_outC = (self.choices_tensor * probability).sum()
        expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        # convolutional layer
        out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
        out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
        # merge
        out_channel = max([x.size(1) for x in out_bns])
        outA = ChannelWiseInter(out_bns[0], out_channel)
        outB = ChannelWiseInter(out_bns[1], out_channel)
        out = outA * prob[0] + outB * prob[1]
        # out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])

        if self.relu:
            out = self.relu(out)
        else:
            out = out
        return out, expected_outC, expected_flop

    def basic_forward(self, inputs):
        if self.avg:
            out = self.avg(inputs)
        else:
            out = inputs
        conv = self.conv(out)
        if self.has_bn:
            out = self.BNs[-1](conv)
        else:
            out = conv
        if self.relu:
            out = self.relu(out)
        else:
            out = out
        if self.InShape is None:
            self.InShape = (inputs.size(-2), inputs.size(-1))
            self.OutShape = (out.size(-2), out.size(-1))
        return out
class ResNetBasicblock(nn.Module):
|
||||
expansion = 1
|
||||
num_conv = 2
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_a = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return self.conv_a.get_range() + self.conv_b.get_range()
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 3, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_C = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_C = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_b.OutShape[0]
|
||||
* self.conv_b.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C
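        # Note on the branch above: when the searched widths give channels[0] != channels[-1]
        # but this block has no explicit downsample, the inferred network will need a 1x1
        # projection on its shortcut, so that projection's cost
        # (C_in * C_out * H_out * W_out) is charged to this block as well.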
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
|
||||
out_a, expected_inC_a, expected_flop_a = self.conv_a(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_b, expected_inC_b, expected_flop_b = self.conv_b(
|
||||
(out_a, expected_inC_a, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_b)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_b,
|
||||
sum([expected_flop_a, expected_flop_b, expected_flop_c]),
|
||||
)
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, basicblock)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
planes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return (
|
||||
self.conv_1x1.get_range()
|
||||
+ self.conv_3x3.get_range()
|
||||
+ self.conv_1x4.get_range()
|
||||
)
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 4, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
|
||||
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_D = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_D = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_1x4.OutShape[0]
|
||||
* self.conv_1x4.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C + flop_D
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, bottleneck)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
|
||||
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
|
||||
(out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
|
||||
(out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_1x4)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_1x4,
|
||||
sum(
|
||||
[
|
||||
expected_flop_1x1,
|
||||
expected_flop_3x3,
|
||||
expected_flop_1x4,
|
||||
expected_flop_c,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SearchWidthCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, num_classes):
|
||||
super(SearchWidthCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = (
|
||||
"SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.channels = [16]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
]
|
||||
)
|
||||
self.InShape = None
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iC,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(module.out_dim, num_classes)
|
||||
self.InShape = None
|
||||
self.tau = -1
|
||||
self.search_mode = "basic"
|
||||
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
|
||||
|
||||
# parameters for width
|
||||
self.Ranges = []
|
||||
self.layer2indexRange = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
start_index = len(self.Ranges)
|
||||
self.Ranges += layer.get_range()
|
||||
self.layer2indexRange.append((start_index, len(self.Ranges)))
|
||||
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
|
||||
len(self.Ranges) + 1, depth
|
||||
)
|
||||
|
||||
self.register_parameter(
|
||||
"width_attentions",
|
||||
nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))),
|
||||
)
|
||||
nn.init.normal_(self.width_attentions, 0, 0.01)
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def arch_parameters(self):
|
||||
return [self.width_attentions]
|
||||
|
||||
def base_parameters(self):
|
||||
return (
|
||||
list(self.layers.parameters())
|
||||
+ list(self.avgpool.parameters())
|
||||
+ list(self.classifier.parameters())
|
||||
)
|
||||
|
||||
def get_flop(self, mode, config_dict, extra_info):
|
||||
if config_dict is not None:
|
||||
config_dict = config_dict.copy()
|
||||
# weights = [F.softmax(x, dim=0) for x in self.width_attentions]
|
||||
channels = [3]
|
||||
for i, weight in enumerate(self.width_attentions):
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
probe = nn.functional.softmax(weight, dim=0)
|
||||
C = self.Ranges[i][torch.argmax(probe).item()]
|
||||
elif mode == "max":
|
||||
C = self.Ranges[i][-1]
|
||||
elif mode == "fix":
|
||||
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
elif mode == "random":
|
||||
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
|
||||
extra_info
|
||||
)
|
||||
with torch.no_grad():
|
||||
prob = nn.functional.softmax(weight, dim=0)
|
||||
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
for j in range(prob.size(0)):
|
||||
prob[j] = 1 / (
|
||||
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
|
||||
)
|
||||
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
channels.append(C)
|
||||
flop = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
s, e = self.layer2indexRange[i]
|
||||
xchl = tuple(channels[s : e + 1])
|
||||
flop += layer.get_flops(xchl)
|
||||
# the last fc layer
|
||||
flop += channels[-1] * self.classifier.out_features
|
||||
if config_dict is None:
|
||||
return flop / 1e6
|
||||
else:
|
||||
config_dict["xchannels"] = channels
|
||||
config_dict["super_type"] = "infer-width"
|
||||
config_dict["estimated_FLOP"] = flop / 1e6
|
||||
return flop / 1e6, config_dict
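        # Hedged usage sketch (argument values below are assumptions, not from this commit):
        #   flop_in_M, cfg = net.get_flop("genotype", base_config_dict, None)
        #   cfg["xchannels"]      -> chosen channel count per convolution, prefixed by the 3 input channels
        #   cfg["super_type"]     -> "infer-width"
        #   cfg["estimated_FLOP"] -> flop_in_M, in millions of FLOPs (final FC layer included)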
|
||||
|
||||
def get_arch_info(self):
|
||||
string = "for width, there are {:} attention probabilities.".format(
|
||||
len(self.width_attentions)
|
||||
)
|
||||
discrepancy = []
|
||||
with torch.no_grad():
|
||||
for i, att in enumerate(self.width_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.width_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:52s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || dis={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
return string, discrepancy
|
||||
|
||||
def set_tau(self, tau_max, tau_min, epoch_ratio):
|
||||
assert (
|
||||
epoch_ratio >= 0 and epoch_ratio <= 1
|
||||
), "invalid epoch-ratio : {:}".format(epoch_ratio)
|
||||
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
|
||||
self.tau = tau
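        # set_tau cosine-anneals the Gumbel-softmax temperature from tau_max to tau_min:
        #   epoch_ratio = 0.0 -> tau = tau_max
        #   epoch_ratio = 0.5 -> tau = (tau_max + tau_min) / 2
        #   epoch_ratio = 1.0 -> tau = tau_min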
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, inputs):
|
||||
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
|
||||
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
|
||||
with torch.no_grad():
|
||||
selected_widths = selected_widths.cpu()
|
||||
|
||||
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
|
||||
for i, layer in enumerate(self.layers):
|
||||
selected_w_index = selected_widths[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
selected_w_probs = selected_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
layer_prob = flop_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
x, expected_inC, expected_flop = layer(
|
||||
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
|
||||
)
|
||||
last_channel_idx += layer.num_conv
|
||||
flops.append(expected_flop)
|
||||
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = linear_forward(features, self.classifier)
|
||||
return logits, torch.stack([sum(flops)])
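        # The second return value is the differentiable expected-FLOP estimate in millions.
        # A typical (assumed, not shown in this commit) way to use it is a FLOP-regularised loss:
        #   logits, e_flop = model(images)                      # search_mode == "search"
        #   loss = criterion(logits, labels) + lambda_flop * torch.relu(e_flop - flop_budget)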
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -1,766 +0,0 @@
|
||||
import math, torch
|
||||
from collections import OrderedDict
|
||||
from bisect import bisect_right
|
||||
import torch.nn as nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
from .SoftSelect import select2withP, ChannelWiseInter
|
||||
from .SoftSelect import linear_forward
|
||||
from .SoftSelect import get_width_choices
|
||||
|
||||
|
||||
def get_depth_choices(layers):
|
||||
min_depth = min(layers)
|
||||
info = {"num": min_depth}
|
||||
for i, depth in enumerate(layers):
|
||||
choices = []
|
||||
for j in range(1, min_depth + 1):
|
||||
choices.append(int(float(depth) * j / min_depth))
|
||||
info[i] = choices
|
||||
return info
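# Worked example (illustrative input, not from this commit): get_depth_choices([3, 4, 6])
# returns {"num": 3, 0: [1, 2, 3], 1: [1, 2, 4], 2: [2, 4, 6]}, i.e. every stage exposes
# min(layers) candidate depths and the last candidate is always the full stage depth.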
def conv_forward(inputs, conv, choices):
|
||||
iC = conv.in_channels
|
||||
fill_size = list(inputs.size())
|
||||
fill_size[1] = iC - fill_size[1]
|
||||
filled = torch.zeros(fill_size, device=inputs.device)
|
||||
xinputs = torch.cat((inputs, filled), dim=1)
|
||||
outputs = conv(xinputs)
|
||||
selecteds = [outputs[:, :oC] for oC in choices]
|
||||
return selecteds
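# conv_forward evaluates all candidate widths with a single convolution: the (possibly
# narrower) input is zero-padded back to conv.in_channels, the full conv runs once, and
# each candidate width just slices its first oC output channels. Toy call (sizes assumed):
#   conv = nn.Conv2d(16, 32, 3, padding=1)
#   x    = torch.randn(2, 12, 8, 8)           # only 12 of the 16 input channels are active
#   outs = conv_forward(x, conv, [8, 32])     # tensors of shape (2, 8, 8, 8) and (2, 32, 8, 8)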
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nIn,
|
||||
nOut,
|
||||
kernel,
|
||||
stride,
|
||||
padding,
|
||||
bias,
|
||||
has_avg,
|
||||
has_bn,
|
||||
has_relu,
|
||||
last_max_pool=False,
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.InShape = None
|
||||
self.OutShape = None
|
||||
self.choices = get_width_choices(nOut)
|
||||
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
|
||||
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
# else : self.bn = None
|
||||
self.has_bn = has_bn
|
||||
self.BNs = nn.ModuleList()
|
||||
for i, _out in enumerate(self.choices):
|
||||
self.BNs.append(nn.BatchNorm2d(_out))
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
if last_max_pool:
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
else:
|
||||
self.maxpool = None
|
||||
self.in_dim = nIn
|
||||
self.out_dim = nOut
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_flops(self, channels, check_range=True, divide=1):
|
||||
iC, oC = channels
|
||||
if check_range:
|
||||
assert (
|
||||
iC <= self.conv.in_channels and oC <= self.conv.out_channels
|
||||
), "{:} vs {:} | {:} vs {:}".format(
|
||||
iC, self.conv.in_channels, oC, self.conv.out_channels
|
||||
)
|
||||
assert (
|
||||
isinstance(self.InShape, tuple) and len(self.InShape) == 2
|
||||
), "invalid in-shape : {:}".format(self.InShape)
|
||||
assert (
|
||||
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
|
||||
), "invalid out-shape : {:}".format(self.OutShape)
|
||||
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
|
||||
conv_per_position_flops = (
|
||||
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
|
||||
)
|
||||
all_positions = self.OutShape[0] * self.OutShape[1]
|
||||
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
|
||||
if self.conv.bias is not None:
|
||||
flops += all_positions / divide
|
||||
return flops
|
||||
|
||||
def get_range(self):
|
||||
return [self.choices]
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, index, prob = tuple_inputs
|
||||
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
|
||||
probability = torch.squeeze(probability)
|
||||
assert len(index) == 2, "invalid length : {:}".format(index)
|
||||
# compute expected flop
|
||||
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
|
||||
expected_outC = (self.choices_tensor * probability).sum()
|
||||
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
# convolutional layer
|
||||
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
|
||||
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
|
||||
# merge
|
||||
out_channel = max([x.size(1) for x in out_bns])
|
||||
outA = ChannelWiseInter(out_bns[0], out_channel)
|
||||
outB = ChannelWiseInter(out_bns[1], out_channel)
|
||||
out = outA * prob[0] + outB * prob[1]
|
||||
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
|
||||
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
if self.maxpool:
|
||||
out = self.maxpool(out)
|
||||
return out, expected_outC, expected_flop
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.has_bn:
|
||||
out = self.BNs[-1](conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
self.OutShape = (out.size(-2), out.size(-1))
|
||||
if self.maxpool:
|
||||
out = self.maxpool(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
expansion = 1
|
||||
num_conv = 2
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_a = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return self.conv_a.get_range() + self.conv_b.get_range()
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 3, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_C = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_C = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_b.OutShape[0]
|
||||
* self.conv_b.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
|
||||
# import pdb; pdb.set_trace()
|
||||
out_a, expected_inC_a, expected_flop_a = self.conv_a(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_b, expected_inC_b, expected_flop_b = self.conv_b(
|
||||
(out_a, expected_inC_a, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_b)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_b,
|
||||
sum([expected_flop_a, expected_flop_b, expected_flop_c]),
|
||||
)
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, basicblock)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
planes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return (
|
||||
self.conv_1x1.get_range()
|
||||
+ self.conv_3x3.get_range()
|
||||
+ self.conv_1x4.get_range()
|
||||
)
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 4, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
|
||||
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_D = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_D = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_1x4.OutShape[0]
|
||||
* self.conv_1x4.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C + flop_D
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, bottleneck)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
|
||||
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
|
||||
(out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
|
||||
(out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_1x4)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_1x4,
|
||||
sum(
|
||||
[
|
||||
expected_flop_1x1,
|
||||
expected_flop_3x3,
|
||||
expected_flop_1x4,
|
||||
expected_flop_c,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SearchShapeImagenetResNet(nn.Module):
|
||||
def __init__(self, block_name, layers, deep_stem, num_classes):
|
||||
super(SearchShapeImagenetResNet, self).__init__()
|
||||
|
||||
        # Model type specifies the block type and per-stage layer counts for the ImageNet model
|
||||
if block_name == "BasicBlock":
|
||||
block = ResNetBasicblock
|
||||
elif block_name == "Bottleneck":
|
||||
block = ResNetBottleneck
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = (
|
||||
"SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
sum(layers) * block.num_conv, layers
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
if not deep_stem:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3,
|
||||
64,
|
||||
7,
|
||||
2,
|
||||
3,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
last_max_pool=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
self.channels = [64]
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 32, 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
),
|
||||
ConvBNReLU(
|
||||
32,
|
||||
64,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
last_max_pool=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
self.channels = [32, 64]
|
||||
|
||||
meta_depth_info = get_depth_choices(layers)
|
||||
self.InShape = None
|
||||
self.depth_info = OrderedDict()
|
||||
self.depth_at_i = OrderedDict()
|
||||
for stage, layer_blocks in enumerate(layers):
|
||||
cur_block_choices = meta_depth_info[stage]
|
||||
assert (
|
||||
cur_block_choices[-1] == layer_blocks
|
||||
), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks)
|
||||
block_choices, xstart = [], len(self.layers)
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 64 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iC,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
# added for depth
|
||||
layer_index = len(self.layers) - 1
|
||||
if iL + 1 in cur_block_choices:
|
||||
block_choices.append(layer_index)
|
||||
if iL + 1 == layer_blocks:
|
||||
self.depth_info[layer_index] = {
|
||||
"choices": block_choices,
|
||||
"stage": stage,
|
||||
"xstart": xstart,
|
||||
}
|
||||
self.depth_info_list = []
|
||||
for xend, info in self.depth_info.items():
|
||||
self.depth_info_list.append((xend, info))
|
||||
xstart, xstage = info["xstart"], info["stage"]
|
||||
for ilayer in range(xstart, xend + 1):
|
||||
idx = bisect_right(info["choices"], ilayer - 1)
|
||||
self.depth_at_i[ilayer] = (xstage, idx)
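        # Bookkeeping summary (illustrative indices): for a stage with 4 blocks at global
        # indices i1..i4 and candidate depths [1, 2, 4], block_choices == [i1, i2, i4] and
        #   depth_at_i[i1] = (stage, 0)   # kept under every depth option
        #   depth_at_i[i2] = (stage, 1)   # kept once the 2nd-or-deeper option is chosen
        #   depth_at_i[i3] = (stage, 2)   # kept only with the deepest option
        #   depth_at_i[i4] = (stage, 2)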
|
||||
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.classifier = nn.Linear(module.out_dim, num_classes)
|
||||
self.InShape = None
|
||||
self.tau = -1
|
||||
self.search_mode = "basic"
|
||||
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
|
||||
|
||||
# parameters for width
|
||||
self.Ranges = []
|
||||
self.layer2indexRange = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
start_index = len(self.Ranges)
|
||||
self.Ranges += layer.get_range()
|
||||
self.layer2indexRange.append((start_index, len(self.Ranges)))
|
||||
|
||||
self.register_parameter(
|
||||
"width_attentions",
|
||||
nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))),
|
||||
)
|
||||
self.register_parameter(
|
||||
"depth_attentions",
|
||||
nn.Parameter(torch.Tensor(len(layers), meta_depth_info["num"])),
|
||||
)
|
||||
nn.init.normal_(self.width_attentions, 0, 0.01)
|
||||
nn.init.normal_(self.depth_attentions, 0, 0.01)
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def arch_parameters(self, LR=None):
|
||||
if LR is None:
|
||||
return [self.width_attentions, self.depth_attentions]
|
||||
else:
|
||||
return [
|
||||
{"params": self.width_attentions, "lr": LR},
|
||||
{"params": self.depth_attentions, "lr": LR},
|
||||
]
|
||||
|
||||
def base_parameters(self):
|
||||
return (
|
||||
list(self.layers.parameters())
|
||||
+ list(self.avgpool.parameters())
|
||||
+ list(self.classifier.parameters())
|
||||
)
|
||||
|
||||
def get_flop(self, mode, config_dict, extra_info):
|
||||
if config_dict is not None:
|
||||
config_dict = config_dict.copy()
|
||||
# select channels
|
||||
channels = [3]
|
||||
for i, weight in enumerate(self.width_attentions):
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
probe = nn.functional.softmax(weight, dim=0)
|
||||
C = self.Ranges[i][torch.argmax(probe).item()]
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
channels.append(C)
|
||||
# select depth
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
selected_layers = []
|
||||
for choice, xvalue in zip(choices, self.depth_info_list):
|
||||
xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
|
||||
selected_layers.append(xtemp)
|
||||
flop = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
s, e = self.layer2indexRange[i]
|
||||
xchl = tuple(channels[s : e + 1])
|
||||
if i in self.depth_at_i:
|
||||
xstagei, xatti = self.depth_at_i[i]
|
||||
if xatti <= choices[xstagei]: # leave this depth
|
||||
flop += layer.get_flops(xchl)
|
||||
else:
|
||||
flop += 0 # do not use this layer
|
||||
else:
|
||||
flop += layer.get_flops(xchl)
|
||||
# the last fc layer
|
||||
flop += channels[-1] * self.classifier.out_features
|
||||
if config_dict is None:
|
||||
return flop / 1e6
|
||||
else:
|
||||
config_dict["xchannels"] = channels
|
||||
config_dict["xblocks"] = selected_layers
|
||||
config_dict["super_type"] = "infer-shape"
|
||||
config_dict["estimated_FLOP"] = flop / 1e6
|
||||
return flop / 1e6, config_dict
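        # Compared with the width-only variant, the genotype here also records how many
        # blocks each stage keeps. Hedged output sketch (field values assumed):
        #   cfg["xchannels"] -> per-convolution channel list, prefixed by the 3 input channels
        #   cfg["xblocks"]   -> one entry per stage: number of residual blocks to keep
        #   cfg["super_type"] = "infer-shape", cfg["estimated_FLOP"] in millions of FLOPs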
|
||||
|
||||
def get_arch_info(self):
|
||||
string = (
|
||||
"for depth and width, there are {:} + {:} attention probabilities.".format(
|
||||
len(self.depth_attentions), len(self.width_attentions)
|
||||
)
|
||||
)
|
||||
string += "\n{:}".format(self.depth_info)
|
||||
discrepancy = []
|
||||
with torch.no_grad():
|
||||
for i, att in enumerate(self.depth_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.depth_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:17s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || discrepancy={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
string += "\n-----------------------------------------------"
|
||||
for i, att in enumerate(self.width_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.width_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:52s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || dis={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
return string, discrepancy
|
||||
|
||||
def set_tau(self, tau_max, tau_min, epoch_ratio):
|
||||
assert (
|
||||
epoch_ratio >= 0 and epoch_ratio <= 1
|
||||
), "invalid epoch-ratio : {:}".format(epoch_ratio)
|
||||
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
|
||||
self.tau = tau
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, inputs):
|
||||
flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
|
||||
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
flop_depth_probs = torch.flip(
|
||||
torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1]
|
||||
)
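        # The flip / cumsum / flip above turns per-option depth probabilities into "keep"
        # probabilities, i.e. P(at least this many blocks survive). Made-up example:
        #   softmax over depth options : [0.2, 0.3, 0.5]
        #   after flip-cumsum-flip     : [1.0, 0.8, 0.5]
        # so blocks gated by option 0 always count their FLOPs, option 2 only with weight 0.5.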
|
||||
selected_widths, selected_width_probs = select2withP(
|
||||
self.width_attentions, self.tau
|
||||
)
|
||||
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
|
||||
with torch.no_grad():
|
||||
selected_widths = selected_widths.cpu()
|
||||
|
||||
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
|
||||
feature_maps = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
selected_w_index = selected_widths[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
selected_w_probs = selected_width_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
layer_prob = flop_width_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
x, expected_inC, expected_flop = layer(
|
||||
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
|
||||
)
|
||||
feature_maps.append(x)
|
||||
last_channel_idx += layer.num_conv
|
||||
if i in self.depth_info: # aggregate the information
|
||||
choices = self.depth_info[i]["choices"]
|
||||
xstagei = self.depth_info[i]["stage"]
|
||||
# print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
|
||||
# for A, W in zip(choices, selected_depth_probs[xstagei]):
|
||||
# print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
|
||||
possible_tensors = []
|
||||
max_C = max(feature_maps[A].size(1) for A in choices)
|
||||
for tempi, A in enumerate(choices):
|
||||
xtensor = ChannelWiseInter(feature_maps[A], max_C)
|
||||
possible_tensors.append(xtensor)
|
||||
weighted_sum = sum(
|
||||
xtensor * W
|
||||
for xtensor, W in zip(
|
||||
possible_tensors, selected_depth_probs[xstagei]
|
||||
)
|
||||
)
|
||||
x = weighted_sum
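                # Aggregation of this stage's candidate "exits": every candidate depth's
                # feature map is channel-interpolated to the widest one and mixed with the
                # sampled depth weights, e.g. with choices [i1, i2, i4] and weights [w0, w1, w2]:
                #   x = w0 * interp(F[i1]) + w1 * interp(F[i2]) + w2 * interp(F[i4])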
|
||||
|
||||
if i in self.depth_at_i:
|
||||
xstagei, xatti = self.depth_at_i[i]
|
||||
x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
|
||||
else:
|
||||
x_expected_flop = expected_flop
|
||||
flops.append(x_expected_flop)
|
||||
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = linear_forward(features, self.classifier)
|
||||
return logits, torch.stack([sum(flops)])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -1,466 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
from .SoftSelect import select2withP, ChannelWiseInter
|
||||
from .SoftSelect import linear_forward
|
||||
from .SoftSelect import get_width_choices as get_choices
|
||||
|
||||
|
||||
def conv_forward(inputs, conv, choices):
|
||||
iC = conv.in_channels
|
||||
fill_size = list(inputs.size())
|
||||
fill_size[1] = iC - fill_size[1]
|
||||
filled = torch.zeros(fill_size, device=inputs.device)
|
||||
xinputs = torch.cat((inputs, filled), dim=1)
|
||||
outputs = conv(xinputs)
|
||||
selecteds = [outputs[:, :oC] for oC in choices]
|
||||
return selecteds
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.InShape = None
|
||||
self.OutShape = None
|
||||
self.choices = get_choices(nOut)
|
||||
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
|
||||
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
# else : self.bn = None
|
||||
self.has_bn = has_bn
|
||||
self.BNs = nn.ModuleList()
|
||||
for i, _out in enumerate(self.choices):
|
||||
self.BNs.append(nn.BatchNorm2d(_out))
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
self.in_dim = nIn
|
||||
self.out_dim = nOut
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_flops(self, channels, check_range=True, divide=1):
|
||||
iC, oC = channels
|
||||
if check_range:
|
||||
assert (
|
||||
iC <= self.conv.in_channels and oC <= self.conv.out_channels
|
||||
), "{:} vs {:} | {:} vs {:}".format(
|
||||
iC, self.conv.in_channels, oC, self.conv.out_channels
|
||||
)
|
||||
assert (
|
||||
isinstance(self.InShape, tuple) and len(self.InShape) == 2
|
||||
), "invalid in-shape : {:}".format(self.InShape)
|
||||
assert (
|
||||
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
|
||||
), "invalid out-shape : {:}".format(self.OutShape)
|
||||
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
|
||||
conv_per_position_flops = (
|
||||
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
|
||||
)
|
||||
all_positions = self.OutShape[0] * self.OutShape[1]
|
||||
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
|
||||
if self.conv.bias is not None:
|
||||
flops += all_positions / divide
|
||||
return flops
|
||||
|
||||
def get_range(self):
|
||||
return [self.choices]
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, index, prob = tuple_inputs
|
||||
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
|
||||
probability = torch.squeeze(probability)
|
||||
assert len(index) == 2, "invalid length : {:}".format(index)
|
||||
# compute expected flop
|
||||
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
|
||||
expected_outC = (self.choices_tensor * probability).sum()
|
||||
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
# convolutional layer
|
||||
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
|
||||
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
|
||||
# merge
|
||||
out_channel = max([x.size(1) for x in out_bns])
|
||||
outA = ChannelWiseInter(out_bns[0], out_channel)
|
||||
outB = ChannelWiseInter(out_bns[1], out_channel)
|
||||
out = outA * prob[0] + outB * prob[1]
|
||||
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
|
||||
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
return out, expected_outC, expected_flop
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.has_bn:
|
||||
out = self.BNs[-1](conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
self.OutShape = (out.size(-2), out.size(-1))
|
||||
return out
|
||||
|
||||
|
||||
class SimBlock(nn.Module):
|
||||
expansion = 1
|
||||
num_conv = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(SimBlock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return self.conv.get_range()
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 2, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv.get_flops([channels[0], channels[1]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_C = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_C = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv.OutShape[0]
|
||||
* self.conv.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_C
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert (
|
||||
indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1
|
||||
), "invalid size : {:}, {:}, {:}".format(
|
||||
indexes.size(), probs.size(), probability.size()
|
||||
)
|
||||
out, expected_next_inC, expected_flop = self.conv(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[-1], indexes[-1], probs[-1])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_next_inC,
|
||||
sum([expected_flop, expected_flop_c]),
|
||||
)
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv(inputs)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, basicblock)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
|
||||
class SearchWidthSimResNet(nn.Module):
|
||||
def __init__(self, depth, num_classes):
|
||||
super(SearchWidthSimResNet, self).__init__()
|
||||
|
||||
assert (
|
||||
depth - 2
|
||||
) % 3 == 0, "depth should be one of 5, 8, 11, 14, ... instead of {:}".format(
|
||||
depth
|
||||
)
|
||||
layer_blocks = (depth - 2) // 3
|
||||
self.message = (
|
||||
"SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.channels = [16]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
]
|
||||
)
|
||||
self.InShape = None
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = SimBlock(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iC,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(module.out_dim, num_classes)
|
||||
self.InShape = None
|
||||
self.tau = -1
|
||||
self.search_mode = "basic"
|
||||
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
|
||||
|
||||
# parameters for width
|
||||
self.Ranges = []
|
||||
self.layer2indexRange = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
start_index = len(self.Ranges)
|
||||
self.Ranges += layer.get_range()
|
||||
self.layer2indexRange.append((start_index, len(self.Ranges)))
|
||||
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
|
||||
len(self.Ranges) + 1, depth
|
||||
)
|
||||
|
||||
self.register_parameter(
|
||||
"width_attentions",
|
||||
nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))),
|
||||
)
|
||||
nn.init.normal_(self.width_attentions, 0, 0.01)
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def arch_parameters(self):
|
||||
return [self.width_attentions]
|
||||
|
||||
def base_parameters(self):
|
||||
return (
|
||||
list(self.layers.parameters())
|
||||
+ list(self.avgpool.parameters())
|
||||
+ list(self.classifier.parameters())
|
||||
)
|
||||
|
||||
def get_flop(self, mode, config_dict, extra_info):
|
||||
if config_dict is not None:
|
||||
config_dict = config_dict.copy()
|
||||
# weights = [F.softmax(x, dim=0) for x in self.width_attentions]
|
||||
channels = [3]
|
||||
for i, weight in enumerate(self.width_attentions):
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
probe = nn.functional.softmax(weight, dim=0)
|
||||
C = self.Ranges[i][torch.argmax(probe).item()]
|
||||
elif mode == "max":
|
||||
C = self.Ranges[i][-1]
|
||||
elif mode == "fix":
|
||||
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
elif mode == "random":
|
||||
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
|
||||
extra_info
|
||||
)
|
||||
with torch.no_grad():
|
||||
prob = nn.functional.softmax(weight, dim=0)
|
||||
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
for j in range(prob.size(0)):
|
||||
prob[j] = 1 / (
|
||||
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
|
||||
)
|
||||
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
channels.append(C)
|
||||
flop = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
s, e = self.layer2indexRange[i]
|
||||
xchl = tuple(channels[s : e + 1])
|
||||
flop += layer.get_flops(xchl)
|
||||
# the last fc layer
|
||||
flop += channels[-1] * self.classifier.out_features
|
||||
if config_dict is None:
|
||||
return flop / 1e6
|
||||
else:
|
||||
config_dict["xchannels"] = channels
|
||||
config_dict["super_type"] = "infer-width"
|
||||
config_dict["estimated_FLOP"] = flop / 1e6
|
||||
return flop / 1e6, config_dict
|
||||
|
||||
def get_arch_info(self):
|
||||
string = "for width, there are {:} attention probabilities.".format(
|
||||
len(self.width_attentions)
|
||||
)
|
||||
discrepancy = []
|
||||
with torch.no_grad():
|
||||
for i, att in enumerate(self.width_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.width_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:52s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || dis={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
return string, discrepancy
|
||||
|
||||
def set_tau(self, tau_max, tau_min, epoch_ratio):
|
||||
assert (
|
||||
epoch_ratio >= 0 and epoch_ratio <= 1
|
||||
), "invalid epoch-ratio : {:}".format(epoch_ratio)
|
||||
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
|
||||
self.tau = tau
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, inputs):
|
||||
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
|
||||
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
|
||||
with torch.no_grad():
|
||||
selected_widths = selected_widths.cpu()
|
||||
|
||||
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
|
||||
for i, layer in enumerate(self.layers):
|
||||
selected_w_index = selected_widths[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
selected_w_probs = selected_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
layer_prob = flop_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
x, expected_inC, expected_flop = layer(
|
||||
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
|
||||
)
|
||||
last_channel_idx += layer.num_conv
|
||||
flops.append(expected_flop)
|
||||
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = linear_forward(features, self.classifier)
|
||||
return logits, torch.stack([sum(flops)])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -1,128 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
|
||||
if tau <= 0:
|
||||
new_logits = logits
|
||||
probs = nn.functional.softmax(new_logits, dim=1)
|
||||
else:
|
||||
while True: # a trick to avoid the gumbels bug
|
||||
gumbels = -torch.empty_like(logits).exponential_().log()
|
||||
new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
|
||||
probs = nn.functional.softmax(new_logits, dim=1)
|
||||
if (
|
||||
(not torch.isinf(gumbels).any())
|
||||
and (not torch.isinf(probs).any())
|
||||
and (not torch.isnan(probs).any())
|
||||
):
|
||||
break
|
||||
|
||||
if just_prob:
|
||||
return probs
|
||||
|
||||
# with torch.no_grad(): # add eps for unexpected torch error
|
||||
# probs = nn.functional.softmax(new_logits, dim=1)
|
||||
# selected_index = torch.multinomial(probs + eps, 2, False)
|
||||
with torch.no_grad(): # add eps for unexpected torch error
|
||||
probs = probs.cpu()
|
||||
selected_index = torch.multinomial(probs + eps, num, False).to(logits.device)
|
||||
selected_logit = torch.gather(new_logits, 1, selected_index)
|
||||
selected_probs = nn.functional.softmax(selected_logit, dim=1)
|
||||
return selected_index, selected_probs
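# Minimal usage sketch (added for illustration, not in the original file); shapes are assumptions:
#
#   logits = torch.randn(5, 8)                   # 5 width decisions, 8 candidates each
#   index, prob = select2withP(logits, tau=1.0)  # index: [5, 2] LongTensor of sampled candidates
#   # prob re-normalizes the softmax over the two sampled candidates, so gradients flow
#   # only through the selected logits.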
|
||||
|
||||
|
||||
def ChannelWiseInter(inputs, oC, mode="v2"):
|
||||
if mode == "v1":
|
||||
return ChannelWiseInterV1(inputs, oC)
|
||||
elif mode == "v2":
|
||||
return ChannelWiseInterV2(inputs, oC)
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
|
||||
|
||||
def ChannelWiseInterV1(inputs, oC):
|
||||
assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size())
|
||||
|
||||
def start_index(a, b, c):
|
||||
return int(math.floor(float(a * c) / b))
|
||||
|
||||
def end_index(a, b, c):
|
||||
return int(math.ceil(float((a + 1) * c) / b))
|
||||
|
||||
batch, iC, H, W = inputs.size()
|
||||
outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device)
|
||||
if iC == oC:
|
||||
return inputs
|
||||
for ot in range(oC):
|
||||
istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC)
|
||||
values = inputs[:, istartT:iendT].mean(dim=1)
|
||||
outputs[:, ot, :, :] = values
|
||||
return outputs
|
||||
|
||||
|
||||
def ChannelWiseInterV2(inputs, oC):
|
||||
assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size())
|
||||
batch, C, H, W = inputs.size()
|
||||
if C == oC:
|
||||
return inputs
|
||||
else:
|
||||
return nn.functional.adaptive_avg_pool3d(inputs, (oC, H, W))
|
||||
# inputs_5D = inputs.view(batch, 1, C, H, W)
|
||||
# otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None)
|
||||
# otputs = otputs_5D.view(batch, oC, H, W)
|
||||
# otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False)
|
||||
# return otputs
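# Illustrative note (added by the editor): ChannelWiseInter re-sizes a feature map along the
# channel axis by adaptive average pooling, so tensors with different channel counts can be
# mixed. For example (shapes are assumptions):
#
#   x = torch.rand(2, 32, 7, 7)
#   y = ChannelWiseInter(x, 48)   # y.shape == (2, 48, 7, 7)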
|
||||
|
||||
|
||||
def linear_forward(inputs, linear):
|
||||
if linear is None:
|
||||
return inputs
|
||||
iC = inputs.size(1)
|
||||
weight = linear.weight[:, :iC]
|
||||
if linear.bias is None:
|
||||
bias = None
|
||||
else:
|
||||
bias = linear.bias
|
||||
return nn.functional.linear(inputs, weight, bias)
|
||||
|
||||
|
||||
def get_width_choices(nOut):
|
||||
xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
||||
if nOut is None:
|
||||
return len(xsrange)
|
||||
else:
|
||||
Xs = [int(nOut * i) for i in xsrange]
|
||||
# xs = [ int(nOut * i // 10) for i in range(2, 11)]
|
||||
# Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1]
|
||||
Xs = sorted(list(set(Xs)))
|
||||
return tuple(Xs)
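# For example (added for illustration): get_width_choices(64) returns the sorted, de-duplicated
# candidates (19, 25, 32, 38, 44, 51, 57, 64) under the ratios in `xsrange`, while
# get_width_choices(None) returns the number of ratios, i.e., 8.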
|
||||
|
||||
|
||||
def get_depth_choices(nDepth):
|
||||
if nDepth is None:
|
||||
return 3
|
||||
else:
|
||||
assert nDepth >= 3, "nDepth should be greater than 2 vs {:}".format(nDepth)
|
||||
if nDepth == 1:
|
||||
return (1, 1, 1)
|
||||
elif nDepth == 2:
|
||||
return (1, 1, 2)
|
||||
elif nDepth >= 3:
|
||||
return (nDepth // 3, nDepth * 2 // 3, nDepth)
|
||||
else:
|
||||
raise ValueError("invalid Depth : {:}".format(nDepth))
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.0:
|
||||
keep_prob = 1.0 - drop_prob
|
||||
mask = x.new_zeros(x.size(0), 1, 1, 1)
|
||||
mask = mask.bernoulli_(keep_prob)
|
||||
x = x * (mask / keep_prob)
|
||||
# x.div_(keep_prob)
|
||||
# x.mul_(mask)
|
||||
return x
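# Sanity sketch (added for illustration): with drop_prob=0.2, roughly 20% of the samples in a
# batch have this path zeroed out, and the surviving samples are scaled by 1/0.8 so that the
# expected value of the output is unchanged, e.g.
#
#   y = drop_path(torch.ones(8, 4, 1, 1), drop_prob=0.2)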
|
||||
@@ -1,9 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .SearchCifarResNet_width import SearchWidthCifarResNet
|
||||
from .SearchCifarResNet_depth import SearchDepthCifarResNet
|
||||
from .SearchCifarResNet import SearchShapeCifarResNet
|
||||
from .SearchSimResNet_width import SearchWidthSimResNet
|
||||
from .SearchImagenetResNet import SearchShapeImagenetResNet
|
||||
from .generic_size_tiny_cell_model import GenericNAS301Model
|
||||
@@ -1,209 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
# Here, we utilized three techniques to search for the number of channels:
|
||||
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
|
||||
# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
|
||||
# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
|
||||
from typing import List, Text, Any
|
||||
import random, torch
|
||||
import torch.nn as nn
|
||||
|
||||
from models.cell_operations import ResNetBasicblock
|
||||
from models.cell_infers.cells import InferCell
|
||||
from models.shape_searchs.SoftSelect import select2withP, ChannelWiseInter
|
||||
|
||||
|
||||
class GenericNAS301Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
candidate_Cs: List[int],
|
||||
max_num_Cs: int,
|
||||
genotype: Any,
|
||||
num_classes: int,
|
||||
affine: bool,
|
||||
track_running_stats: bool,
|
||||
):
|
||||
super(GenericNAS301Model, self).__init__()
|
||||
self._max_num_Cs = max_num_Cs
|
||||
self._candidate_Cs = candidate_Cs
|
||||
if max_num_Cs % 3 != 2:
|
||||
raise ValueError("invalid number of layers : {:}".format(max_num_Cs))
|
||||
self._num_stage = N = max_num_Cs // 3
|
||||
self._max_C = max(candidate_Cs)
|
||||
|
||||
stem = nn.Sequential(
|
||||
nn.Conv2d(3, self._max_C, kernel_size=3, padding=1, bias=not affine),
|
||||
nn.BatchNorm2d(
|
||||
self._max_C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
)
|
||||
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
c_prev = self._max_C
|
||||
self._cells = nn.ModuleList()
|
||||
self._cells.append(stem)
|
||||
for index, reduction in enumerate(layer_reductions):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(c_prev, self._max_C, 2, True)
|
||||
else:
|
||||
cell = InferCell(
|
||||
genotype, c_prev, self._max_C, 1, affine, track_running_stats
|
||||
)
|
||||
self._cells.append(cell)
|
||||
c_prev = cell.out_dim
|
||||
self._num_layer = len(self._cells)
|
||||
|
||||
self.lastact = nn.Sequential(
|
||||
nn.BatchNorm2d(
|
||||
c_prev, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(c_prev, num_classes)
|
||||
# algorithm related
|
||||
self.register_buffer("_tau", torch.zeros(1))
|
||||
self._algo = None
|
||||
self._warmup_ratio = None
|
||||
|
||||
def set_algo(self, algo: Text):
|
||||
# used for searching
|
||||
assert self._algo is None, "This function can only be called once."
|
||||
assert algo in ["mask_gumbel", "mask_rl", "tas"], "invalid algo : {:}".format(
|
||||
algo
|
||||
)
|
||||
self._algo = algo
|
||||
self._arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(self._max_num_Cs, len(self._candidate_Cs))
|
||||
)
|
||||
# if algo == 'mask_gumbel' or algo == 'mask_rl':
|
||||
self.register_buffer(
|
||||
"_masks", torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))
|
||||
)
|
||||
for i in range(len(self._candidate_Cs)):
|
||||
self._masks.data[i, : self._candidate_Cs[i]] = 1
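# Illustrative sketch (added by the editor, not part of the original file): with
# candidate_Cs = [8, 16, 24], `_masks` is a 3 x 24 binary matrix whose i-th row keeps the first
# candidate_Cs[i] channels. The "mask_gumbel" branch in forward() then builds a soft mask as
#
#   weights = nn.functional.gumbel_softmax(arch_param_row, tau=self.tau, dim=-1)  # [1, 3]
#   mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)                   # [1, 24, 1, 1]
#
# so the feature map is attenuated channel-wise according to the sampled width.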
|
||||
|
||||
@property
|
||||
def tau(self):
|
||||
return self._tau
|
||||
|
||||
def set_tau(self, tau):
|
||||
self._tau.data[:] = tau
|
||||
|
||||
@property
|
||||
def warmup_ratio(self):
|
||||
return self._warmup_ratio
|
||||
|
||||
def set_warmup_ratio(self, ratio: float):
|
||||
self._warmup_ratio = ratio
|
||||
|
||||
@property
|
||||
def weights(self):
|
||||
xlist = list(self._cells.parameters())
|
||||
xlist += list(self.lastact.parameters())
|
||||
xlist += list(self.global_pooling.parameters())
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
@property
|
||||
def alphas(self):
|
||||
return [self._arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self._arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
@property
|
||||
def random(self):
|
||||
cs = []
|
||||
for i in range(self._max_num_Cs):
|
||||
index = random.randint(0, len(self._candidate_Cs) - 1)
|
||||
cs.append(str(self._candidate_Cs[index]))
|
||||
return ":".join(cs)
|
||||
|
||||
@property
|
||||
def genotype(self):
|
||||
cs = []
|
||||
for i in range(self._max_num_Cs):
|
||||
with torch.no_grad():
|
||||
index = self._arch_parameters[i].argmax().item()
|
||||
cs.append(str(self._candidate_Cs[index]))
|
||||
return ":".join(cs)
|
||||
|
||||
def get_message(self) -> Text:
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self._cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self._cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(candidates={_candidate_Cs}, num={_max_num_Cs}, N={_num_stage}, L={_num_layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
feature = inputs
|
||||
|
||||
log_probs = []
|
||||
for i, cell in enumerate(self._cells):
|
||||
feature = cell(feature)
|
||||
# apply different searching algorithms
|
||||
idx = max(0, i - 1)
|
||||
if self._warmup_ratio is not None:
|
||||
if random.random() < self._warmup_ratio:
|
||||
mask = self._masks[-1]
|
||||
else:
|
||||
mask = self._masks[random.randint(0, len(self._masks) - 1)]
|
||||
feature = feature * mask.view(1, -1, 1, 1)
|
||||
elif self._algo == "mask_gumbel":
|
||||
weights = nn.functional.gumbel_softmax(
|
||||
self._arch_parameters[idx : idx + 1], tau=self.tau, dim=-1
|
||||
)
|
||||
mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
|
||||
feature = feature * mask
|
||||
elif self._algo == "tas":
|
||||
selected_cs, selected_probs = select2withP(
|
||||
self._arch_parameters[idx : idx + 1], self.tau, num=2
|
||||
)
|
||||
with torch.no_grad():
|
||||
i1, i2 = selected_cs.cpu().view(-1).tolist()
|
||||
c1, c2 = self._candidate_Cs[i1], self._candidate_Cs[i2]
|
||||
out_channel = max(c1, c2)
|
||||
out1 = ChannelWiseInter(feature[:, :c1], out_channel)
|
||||
out2 = ChannelWiseInter(feature[:, :c2], out_channel)
|
||||
out = out1 * selected_probs[0, 0] + out2 * selected_probs[0, 1]
|
||||
if feature.shape[1] == out.shape[1]:
|
||||
feature = out
|
||||
else:
|
||||
miss = torch.zeros(
|
||||
feature.shape[0],
|
||||
feature.shape[1] - out.shape[1],
|
||||
feature.shape[2],
|
||||
feature.shape[3],
|
||||
device=feature.device,
|
||||
)
|
||||
feature = torch.cat((out, miss), dim=1)
|
||||
elif self._algo == "mask_rl":
|
||||
prob = nn.functional.softmax(
|
||||
self._arch_parameters[idx : idx + 1], dim=-1
|
||||
)
|
||||
dist = torch.distributions.Categorical(prob)
|
||||
action = dist.sample()
|
||||
log_probs.append(dist.log_prob(action))
|
||||
mask = self._masks[action.item()].view(1, -1, 1, 1)
|
||||
feature = feature * mask
|
||||
else:
|
||||
raise ValueError("invalid algorithm : {:}".format(self._algo))
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits, log_probs
|
||||
@@ -1,20 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from SoftSelect import ChannelWiseInter
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
tensors = torch.rand((16, 128, 7, 7))
|
||||
|
||||
for oc in range(200, 210):
|
||||
out_v1 = ChannelWiseInter(tensors, oc, "v1")
|
||||
out_v2 = ChannelWiseInter(tensors, oc, "v2")
|
||||
assert (out_v1 == out_v2).any().item() == 1
|
||||
for oc in range(48, 160):
|
||||
out_v1 = ChannelWiseInter(tensors, oc, "v1")
|
||||
out_v2 = ChannelWiseInter(tensors, oc, "v2")
|
||||
assert (out_v1 == out_v2).any().item() == 1
|
||||
@@ -1,67 +0,0 @@
|
||||
#######################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#######################################################
|
||||
# Use module in xlayers to construct different models #
|
||||
#######################################################
|
||||
from typing import List, Text, Dict, Any
|
||||
import torch
|
||||
|
||||
__all__ = ["get_model"]
|
||||
|
||||
|
||||
from xlayers.super_core import SuperSequential
|
||||
from xlayers.super_core import SuperLinear
|
||||
from xlayers.super_core import SuperDropout
|
||||
from xlayers.super_core import super_name2norm
|
||||
from xlayers.super_core import super_name2activation
|
||||
|
||||
|
||||
def get_model(config: Dict[Text, Any], **kwargs):
|
||||
model_type = config.get("model_type", "simple_mlp")
|
||||
if model_type == "simple_mlp":
|
||||
act_cls = super_name2activation[kwargs["act_cls"]]
|
||||
norm_cls = super_name2norm[kwargs["norm_cls"]]
|
||||
mean, std = kwargs.get("mean", None), kwargs.get("std", None)
|
||||
if "hidden_dim" in kwargs:
|
||||
hidden_dim1 = kwargs.get("hidden_dim")
|
||||
hidden_dim2 = kwargs.get("hidden_dim")
|
||||
else:
|
||||
hidden_dim1 = kwargs.get("hidden_dim1", 200)
|
||||
hidden_dim2 = kwargs.get("hidden_dim2", 100)
|
||||
model = SuperSequential(
|
||||
norm_cls(mean=mean, std=std),
|
||||
SuperLinear(kwargs["input_dim"], hidden_dim1),
|
||||
act_cls(),
|
||||
SuperLinear(hidden_dim1, hidden_dim2),
|
||||
act_cls(),
|
||||
SuperLinear(hidden_dim2, kwargs["output_dim"]),
|
||||
)
|
||||
elif model_type == "norm_mlp":
|
||||
act_cls = super_name2activation[kwargs["act_cls"]]
|
||||
norm_cls = super_name2norm[kwargs["norm_cls"]]
|
||||
sub_layers, last_dim = [], kwargs["input_dim"]
|
||||
for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
|
||||
if last_dim > 1:
|
||||
sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
|
||||
sub_layers.append(SuperLinear(last_dim, hidden_dim))
|
||||
sub_layers.append(act_cls())
|
||||
last_dim = hidden_dim
|
||||
sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
|
||||
model = SuperSequential(*sub_layers)
|
||||
elif model_type == "dual_norm_mlp":
|
||||
act_cls = super_name2activation[kwargs["act_cls"]]
|
||||
norm_cls = super_name2norm[kwargs["norm_cls"]]
|
||||
sub_layers, last_dim = [], kwargs["input_dim"]
|
||||
for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
|
||||
if i > 0:
|
||||
sub_layers.append(norm_cls(last_dim, elementwise_affine=False))
|
||||
sub_layers.append(SuperLinear(last_dim, hidden_dim))
|
||||
sub_layers.append(SuperDropout(kwargs["dropout"]))
|
||||
sub_layers.append(SuperLinear(hidden_dim, hidden_dim))
|
||||
sub_layers.append(act_cls())
|
||||
last_dim = hidden_dim
|
||||
sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))
|
||||
model = SuperSequential(*sub_layers)
|
||||
else:
|
||||
raise TypeError("Unkonwn model type: {:}".format(model_type))
|
||||
return model
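# Hypothetical usage sketch (added for illustration; the registry names "relu" and "identity"
# are assumptions about xlayers.super_core, not verified against it):
#
#   config = dict(model_type="simple_mlp")
#   model = get_model(config, input_dim=32, output_dim=1,
#                     hidden_dim1=200, hidden_dim2=100,
#                     act_cls="relu", norm_cls="identity", mean=None, std=None)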
|
||||
@@ -1,15 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################################
|
||||
# This API will not be updated after 2020.09.16. #
|
||||
# Please use our new API in NATS-Bench, which is #
|
||||
# more efficient and contains info of more architecture candidates. #
|
||||
#####################################################################
|
||||
from .api_utils import ArchResults, ResultsCount
|
||||
from .api_201 import NASBench201API
|
||||
|
||||
# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
|
||||
# NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
|
||||
# NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16]
|
||||
NAS_BENCH_201_API_VERSION="v2.0" # [2020.06.30]
|
||||
|
||||
@@ -1,274 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
############################################################################################
|
||||
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
||||
############################################################################################
|
||||
# The history of benchmark files:
|
||||
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
||||
# [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 architectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.
|
||||
#
|
||||
# I'm still actively enhancing our benchmark, while for the future benchmark file, please follow news from NATS-Bench (an extended version of NAS-Bench-201).
|
||||
#
|
||||
import os, copy, random, torch, numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Text, Union, Dict, Optional
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
from .api_utils import ArchResults
|
||||
from .api_utils import NASBenchMetaAPI
|
||||
from .api_utils import remap_dataset_set_names
|
||||
|
||||
|
||||
ALL_BENCHMARK_FILES = ['NAS-Bench-201-v1_0-e61699.pth', 'NAS-Bench-201-v1_1-096897.pth']
|
||||
ALL_ARCHIVE_DIRS = ['NAS-Bench-201-v1_1-archive']
|
||||
|
||||
|
||||
def print_information(information, extra_info=None, show=False):
|
||||
dataset_names = information.get_dataset_names()
|
||||
strings = [information.arch_str, 'datasets : {:}, extra-info : {:}'.format(dataset_names, extra_info)]
|
||||
def metric2str(loss, acc):
|
||||
return 'loss = {:.3f}, top1 = {:.2f}%'.format(loss, acc)
|
||||
|
||||
for ida, dataset in enumerate(dataset_names):
|
||||
metric = information.get_compute_costs(dataset)
|
||||
flop, param, latency = metric['flops'], metric['params'], metric['latency']
|
||||
str1 = '{:14s} FLOP={:6.2f} M, Params={:.3f} MB, latency={:} ms.'.format(dataset, flop, param, '{:.2f}'.format(latency*1000) if latency is not None and latency > 0 else None)
|
||||
train_info = information.get_metrics(dataset, 'train')
|
||||
if dataset == 'cifar10-valid':
|
||||
valid_info = information.get_metrics(dataset, 'x-valid')
|
||||
str2 = '{:14s} train : [{:}], valid : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']))
|
||||
elif dataset == 'cifar10':
|
||||
test__info = information.get_metrics(dataset, 'ori-test')
|
||||
str2 = '{:14s} train : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
|
||||
else:
|
||||
valid_info = information.get_metrics(dataset, 'x-valid')
|
||||
test__info = information.get_metrics(dataset, 'x-test')
|
||||
str2 = '{:14s} train : [{:}], valid : [{:}], test : [{:}]'.format(dataset, metric2str(train_info['loss'], train_info['accuracy']), metric2str(valid_info['loss'], valid_info['accuracy']), metric2str(test__info['loss'], test__info['accuracy']))
|
||||
strings += [str1, str2]
|
||||
if show: print('\n'.join(strings))
|
||||
return strings
|
||||
|
||||
|
||||
"""
|
||||
This is the class for the API of NAS-Bench-201.
|
||||
"""
|
||||
class NASBench201API(NASBenchMetaAPI):
|
||||
|
||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None,
|
||||
verbose: bool=True):
|
||||
self.filename = None
|
||||
self.reset_time()
|
||||
if file_path_or_dict is None:
|
||||
file_path_or_dict = os.path.join(os.environ['TORCH_HOME'], ALL_BENCHMARK_FILES[-1])
|
||||
print ('Try to use the default NAS-Bench-201 path from {:}.'.format(file_path_or_dict))
|
||||
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
|
||||
file_path_or_dict = str(file_path_or_dict)
|
||||
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||
self.filename = Path(file_path_or_dict).name
|
||||
file_path_or_dict = torch.load(file_path_or_dict, map_location='cpu')
|
||||
elif isinstance(file_path_or_dict, dict):
|
||||
file_path_or_dict = copy.deepcopy(file_path_or_dict)
|
||||
else: raise ValueError('invalid type : {:} not in [str, dict]'.format(type(file_path_or_dict)))
|
||||
assert isinstance(file_path_or_dict, dict), 'It should be a dict instead of {:}'.format(type(file_path_or_dict))
|
||||
self.verbose = verbose # [TODO] a flag indicating whether to print more logs
|
||||
keys = ('meta_archs', 'arch2infos', 'evaluated_indexes')
|
||||
for key in keys: assert key in file_path_or_dict, 'Can not find key[{:}] in the dict'.format(key)
|
||||
self.meta_archs = copy.deepcopy( file_path_or_dict['meta_archs'] )
|
||||
# This is a dict mapping each architecture to a dict, where the key is #epochs and the value is ArchResults
|
||||
self.arch2infos_dict = OrderedDict()
|
||||
self._avaliable_hps = set(['12', '200'])
|
||||
for xkey in sorted(list(file_path_or_dict['arch2infos'].keys())):
|
||||
all_info = file_path_or_dict['arch2infos'][xkey]
|
||||
hp2archres = OrderedDict()
|
||||
# self.arch2infos_less[xkey] = ArchResults.create_from_state_dict( all_info['less'] )
|
||||
# self.arch2infos_full[xkey] = ArchResults.create_from_state_dict( all_info['full'] )
|
||||
hp2archres['12'] = ArchResults.create_from_state_dict(all_info['less'])
|
||||
hp2archres['200'] = ArchResults.create_from_state_dict(all_info['full'])
|
||||
self.arch2infos_dict[xkey] = hp2archres
|
||||
self.evaluated_indexes = sorted(list(file_path_or_dict['evaluated_indexes']))
|
||||
self.archstr2index = {}
|
||||
for idx, arch in enumerate(self.meta_archs):
|
||||
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
|
||||
self.archstr2index[ arch ] = idx
|
||||
|
||||
def reload(self, archive_root: Text = None, index: int = None):
|
||||
"""Overwrite all information of the 'index'-th architecture in the search space.
|
||||
It will load its data from 'archive_root'.
|
||||
"""
|
||||
if archive_root is None:
|
||||
archive_root = os.path.join(os.environ['TORCH_HOME'], ALL_ARCHIVE_DIRS[-1])
|
||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||
if index is None:
|
||||
indexes = list(range(len(self)))
|
||||
else:
|
||||
indexes = [index]
|
||||
for idx in indexes:
|
||||
assert 0 <= idx < len(self.meta_archs), 'invalid index of {:}'.format(idx)
|
||||
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(idx))
|
||||
assert os.path.isfile(xfile_path), 'invalid data path : {:}'.format(xfile_path)
|
||||
xdata = torch.load(xfile_path, map_location='cpu')
|
||||
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
|
||||
hp2archres = OrderedDict()
|
||||
hp2archres['12'] = ArchResults.create_from_state_dict(xdata['less'])
|
||||
hp2archres['200'] = ArchResults.create_from_state_dict(xdata['full'])
|
||||
self.arch2infos_dict[idx] = hp2archres
|
||||
|
||||
def query_info_str_by_arch(self, arch, hp: Text='12'):
|
||||
""" This function is used to query the information of a specific architecture
|
||||
'arch' can be an architecture index or an architecture string
|
||||
When hp=12, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/12E.config'
|
||||
When hp=200, the hyper-parameters used to train a model are in 'configs/nas-benchmark/hyper-opts/200E.config'
|
||||
The difference between these configurations is the number of training epochs.
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call query_info_str_by_arch with arch={:} and hp={:}'.format(arch, hp))
|
||||
return self._query_info_str_by_arch(arch, hp, print_information)
|
||||
|
||||
# obtain the metric for the `index`-th architecture
|
||||
# `dataset` indicates the dataset:
|
||||
# 'cifar10-valid' : using the proposed train set of CIFAR-10 as the training set
|
||||
# 'cifar10' : using the proposed train+valid set of CIFAR-10 as the training set
|
||||
# 'cifar100' : using the proposed train set of CIFAR-100 as the training set
|
||||
# 'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
|
||||
# `iepoch` indicates the index of training epochs from 0 to 11/199.
|
||||
# When iepoch=None, it will return the metric for the last training epoch
|
||||
# When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
|
||||
# `use_12epochs_result` indicates different hyper-parameters for training
|
||||
# When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
|
||||
# When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
|
||||
# `is_random`
|
||||
# When is_random=True, the performance of a random architecture will be returned
|
||||
# When is_random=False, the performance of all trials will be averaged.
|
||||
def get_more_info(self, index, dataset, iepoch=None, hp='12', is_random=True):
|
||||
if self.verbose:
|
||||
print('Call the get_more_info function with index={:}, dataset={:}, iepoch={:}, hp={:}, and is_random={:}.'.format(index, dataset, iepoch, hp, is_random))
|
||||
index = self.query_index_by_arch(index) # in case the input is a string or an instance of an arch object
|
||||
if index not in self.arch2infos_dict:
|
||||
raise ValueError('Did not find {:} from arch2infos_dict.'.format(index))
|
||||
archresult = self.arch2infos_dict[index][str(hp)]
|
||||
# if randomly select one trial, select the seed at first
|
||||
if isinstance(is_random, bool) and is_random:
|
||||
seeds = archresult.get_dataset_seeds(dataset)
|
||||
is_random = random.choice(seeds)
|
||||
# collect the training information
|
||||
train_info = archresult.get_metrics(dataset, 'train', iepoch=iepoch, is_random=is_random)
|
||||
total = train_info['iepoch'] + 1
|
||||
xinfo = {'train-loss' : train_info['loss'],
|
||||
'train-accuracy': train_info['accuracy'],
|
||||
'train-per-time': train_info['all_time'] / total if train_info['all_time'] is not None else None,
|
||||
'train-all-time': train_info['all_time']}
|
||||
# collect the evaluation information
|
||||
if dataset == 'cifar10-valid':
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
|
||||
try:
|
||||
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
test_info = None
|
||||
valtest_info = None
|
||||
else:
|
||||
try: # collect results on the proposed test set
|
||||
if dataset == 'cifar10':
|
||||
test_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||
else:
|
||||
test_info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
test_info = None
|
||||
try: # collect results on the proposed validation set
|
||||
valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
|
||||
except:
|
||||
valid_info = None
|
||||
try:
|
||||
if dataset != 'cifar10':
|
||||
valtest_info = archresult.get_metrics(dataset, 'ori-test', iepoch=iepoch, is_random=is_random)
|
||||
else:
|
||||
valtest_info = None
|
||||
except:
|
||||
valtest_info = None
|
||||
if valid_info is not None:
|
||||
xinfo['valid-loss'] = valid_info['loss']
|
||||
xinfo['valid-accuracy'] = valid_info['accuracy']
|
||||
xinfo['valid-per-time'] = valid_info['all_time'] / total if valid_info['all_time'] is not None else None
|
||||
xinfo['valid-all-time'] = valid_info['all_time']
|
||||
if test_info is not None:
|
||||
xinfo['test-loss'] = test_info['loss']
|
||||
xinfo['test-accuracy'] = test_info['accuracy']
|
||||
xinfo['test-per-time'] = test_info['all_time'] / total if test_info['all_time'] is not None else None
|
||||
xinfo['test-all-time'] = test_info['all_time']
|
||||
if valtest_info is not None:
|
||||
xinfo['valtest-loss'] = valtest_info['loss']
|
||||
xinfo['valtest-accuracy'] = valtest_info['accuracy']
|
||||
xinfo['valtest-per-time'] = valtest_info['all_time'] / total if valtest_info['all_time'] is not None else None
|
||||
xinfo['valtest-all-time'] = valtest_info['all_time']
|
||||
return xinfo
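# Illustrative usage (added by the editor; the benchmark file name comes from ALL_BENCHMARK_FILES):
#
#   api = NASBench201API('NAS-Bench-201-v1_1-096897.pth')
#   info = api.get_more_info(123, 'cifar100', iepoch=None, hp='200', is_random=False)
#   print(info['test-accuracy'], info['train-all-time'])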
|
||||
|
||||
def show(self, index: int = -1) -> None:
|
||||
"""This function will print the information of a specific (or all) architecture(s)."""
|
||||
self._show(index, print_information)
|
||||
|
||||
@staticmethod
|
||||
def str2lists(arch_str: Text) -> List[tuple]:
|
||||
"""
|
||||
This function shows how to read the string-based architecture encoding.
|
||||
It is the same as the `str2structure` func in `AutoDL-Projects/lib/models/cell_searchs/genotypes.py`
|
||||
|
||||
:param
|
||||
arch_str: the input is a string indicates the architecture topology, such as
|
||||
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
|
||||
:return: a list of tuple, contains multiple (op, input_node_index) pairs.
|
||||
|
||||
:usage
|
||||
arch = api.str2lists( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
|
||||
print ('there are {:} nodes in this arch'.format(len(arch)+1)) # arch is a list
|
||||
for i, node in enumerate(arch):
|
||||
print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
|
||||
"""
|
||||
node_strs = arch_str.split('+')
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(node_strs):
|
||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||
inputs = ( xi.split('~') for xi in inputs )
|
||||
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
|
||||
genotypes.append( input_infos )
|
||||
return genotypes
|
||||
|
||||
@staticmethod
|
||||
def str2matrix(arch_str: Text,
|
||||
search_space: List[Text] = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']) -> np.ndarray:
|
||||
"""
|
||||
This func shows how to convert the string-based architecture encoding to the encoding strategy in NAS-Bench-101.
|
||||
|
||||
:param
|
||||
arch_str: the input is a string indicates the architecture topology, such as
|
||||
|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|
|
||||
search_space: a list of operation string, the default list is the search space for NAS-Bench-201
|
||||
the default value should be consistent with this line https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/models/cell_operations.py#L24
|
||||
:return
|
||||
the numpy matrix (2-D np.ndarray) representing the DAG of this architecture topology
|
||||
:usage
|
||||
matrix = api.str2matrix( '|nor_conv_1x1~0|+|none~0|none~1|+|none~0|none~1|skip_connect~2|' )
|
||||
This matrix is 4-by-4 matrix representing a cell with 4 nodes (only the lower left triangle is useful).
|
||||
[ [0, 0, 0, 0], # the first line represents the input (0-th) node
|
||||
[2, 0, 0, 0], # the second line represents the 1-st node, is calculated by 2-th-op( 0-th-node )
|
||||
[0, 0, 0, 0], # the third line represents the 2-nd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node )
|
||||
[0, 0, 1, 0] ] # the fourth line represents the 3-rd node, is calculated by 0-th-op( 0-th-node ) + 0-th-op( 1-th-node ) + 1-th-op( 2-th-node )
|
||||
In NAS-Bench-201 search space, 0-th-op is 'none', 1-th-op is 'skip_connect',
|
||||
2-th-op is 'nor_conv_1x1', 3-th-op is 'nor_conv_3x3', 4-th-op is 'avg_pool_3x3'.
|
||||
:(NOTE)
|
||||
If a node has two input-edges from the same node, this function does not work, because one edge will overwrite the other.
|
||||
"""
|
||||
node_strs = arch_str.split('+')
|
||||
num_nodes = len(node_strs) + 1
|
||||
matrix = np.zeros((num_nodes, num_nodes))
|
||||
for i, node_str in enumerate(node_strs):
|
||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||
for xi in inputs:
|
||||
op, idx = xi.split('~')
|
||||
if op not in search_space: raise ValueError('this op ({:}) is not in {:}'.format(op, search_space))
|
||||
op_idx, node_idx = search_space.index(op), int(idx)
|
||||
matrix[i+1, node_idx] = op_idx
|
||||
return matrix
|
||||
|
||||
@@ -1,748 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
############################################################################################
|
||||
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
||||
############################################################################################
|
||||
# In this Python file, we define NASBenchMetaAPI, the abstract class for benchmark APIs.
|
||||
# We also define the class ArchResults, which contains all information of a single architecture trained by one kind of hyper-parameters on three datasets.
|
||||
# We also define the class ResultsCount, which contains all information of a single trial for a single architecture.
|
||||
############################################################################################
|
||||
#
|
||||
import os, abc, copy, random, torch, numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Text, Union, Dict, Optional
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
|
||||
def remap_dataset_set_names(dataset, metric_on_set, verbose=False):
|
||||
"""re-map the metric_on_set to internal keys"""
|
||||
if verbose:
|
||||
print('Call internal function _remap_dataset_set_names with dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
|
||||
if dataset == 'cifar10' and metric_on_set == 'valid':
|
||||
dataset, metric_on_set = 'cifar10-valid', 'x-valid'
|
||||
elif dataset == 'cifar10' and metric_on_set == 'test':
|
||||
dataset, metric_on_set = 'cifar10', 'ori-test'
|
||||
elif dataset == 'cifar10' and metric_on_set == 'train':
|
||||
dataset, metric_on_set = 'cifar10', 'train'
|
||||
elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'valid':
|
||||
metric_on_set = 'x-valid'
|
||||
elif (dataset == 'cifar100' or dataset == 'ImageNet16-120') and metric_on_set == 'test':
|
||||
metric_on_set = 'x-test'
|
||||
if verbose:
|
||||
print(' return dataset={:} and metric_on_set={:}'.format(dataset, metric_on_set))
|
||||
return dataset, metric_on_set
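# For example (added for illustration): remap_dataset_set_names('cifar10', 'valid') returns
# ('cifar10-valid', 'x-valid'), and remap_dataset_set_names('cifar100', 'test') returns
# ('cifar100', 'x-test').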
|
||||
|
||||
|
||||
class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, file_path_or_dict: Optional[Union[Text, Dict]]=None, verbose: bool=True):
|
||||
"""The initialization function that takes the dataset file path (or a dict loaded from that path) as input."""
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
return copy.deepcopy(self.meta_archs[index])
|
||||
|
||||
def arch(self, index: int):
|
||||
"""Return the topology structure of the `index`-th architecture."""
|
||||
if self.verbose:
|
||||
print('Call the arch function with index={:}'.format(index))
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
||||
return copy.deepcopy(self.meta_archs[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.meta_archs)
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}({num}/{total} architectures, file={filename})'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs), filename=self.filename))
|
||||
|
||||
@property
|
||||
def avaliable_hps(self):
|
||||
return list(copy.deepcopy(self._avaliable_hps))
|
||||
|
||||
@property
|
||||
def used_time(self):
|
||||
return self._used_time
|
||||
|
||||
def reset_time(self):
|
||||
self._used_time = 0
|
||||
|
||||
def simulate_train_eval(self, arch, dataset, iepoch=None, hp='12', account_time=True):
|
||||
index = self.query_index_by_arch(arch)
|
||||
all_names = ('cifar10', 'cifar100', 'ImageNet16-120')
|
||||
assert dataset in all_names, 'Invalid dataset name : {:} vs {:}'.format(dataset, all_names)
|
||||
if dataset == 'cifar10':
|
||||
info = self.get_more_info(index, 'cifar10-valid', iepoch=iepoch, hp=hp, is_random=True)
|
||||
else:
|
||||
info = self.get_more_info(index, dataset, iepoch=iepoch, hp=hp, is_random=True)
|
||||
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
|
||||
latency = self.get_latency(index, dataset)
|
||||
if account_time:
|
||||
self._used_time += time_cost
|
||||
return valid_acc, latency, time_cost, self._used_time
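# Illustrative sketch (added by the editor): simulate_train_eval lets a search algorithm charge
# itself the training cost without actually training; `api` and the time budget are hypothetical.
#
#   api.reset_time()
#   best_acc, best_arch = -1.0, None
#   while api.used_time < 12000:   # a simulated budget in seconds
#       arch = api.random()
#       acc, latency, cost, total = api.simulate_train_eval(arch, 'cifar10', hp='12')
#       if acc > best_acc:
#           best_acc, best_arch = acc, arch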
|
||||
|
||||
def random(self):
|
||||
"""Return a random index of all architectures."""
|
||||
return random.randint(0, len(self.meta_archs)-1)
|
||||
|
||||
def query_index_by_arch(self, arch):
|
||||
""" This function is used to query the index of an architecture in the search space.
|
||||
In the topology search space, the input arch can be an architecture string such as '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|';
|
||||
or an instance that has the 'tostr' function that can generate the architecture string;
|
||||
or directly an architecture index; in this case, we will check whether it is valid or not.
|
||||
This function will return the index.
|
||||
If return -1, it means this architecture is not in the search space.
|
||||
Otherwise, it will return an int in [0, the-number-of-candidates-in-the-search-space).
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call query_index_by_arch with arch={:}'.format(arch))
|
||||
if isinstance(arch, int):
|
||||
if 0 <= arch < len(self):
|
||||
return arch
|
||||
else:
|
||||
raise ValueError('Invalid architecture index {:} vs [{:}, {:}].'.format(arch, 0, len(self)))
|
||||
elif isinstance(arch, str):
|
||||
if arch in self.archstr2index: arch_index = self.archstr2index[ arch ]
|
||||
else : arch_index = -1
|
||||
elif hasattr(arch, 'tostr'):
|
||||
if arch.tostr() in self.archstr2index: arch_index = self.archstr2index[ arch.tostr() ]
|
||||
else : arch_index = -1
|
||||
else: arch_index = -1
|
||||
return arch_index
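# Illustrative usage (added by the editor), reusing the architecture string from the docstring:
#
#   index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
#   # index >= 0 if this architecture belongs to the search space, otherwise -1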
|
||||
|
||||
def query_by_arch(self, arch, hp):
|
||||
# This is to make the current version be compatible with the old version.
|
||||
return self.query_info_str_by_arch(arch, hp)
|
||||
|
||||
@abc.abstractmethod
|
||||
def reload(self, archive_root: Text = None, index: int = None):
|
||||
"""Overwrite all information of the 'index'-th architecture in the search space, where the data will be loaded from 'archive_root'.
|
||||
If index is None, overwrite all ckps.
|
||||
"""
|
||||
|
||||
def clear_params(self, index: int, hp: Optional[Text]=None):
|
||||
"""Remove the architecture's weights to save memory.
|
||||
:arg
|
||||
index: the index of the target architecture
|
||||
hp: a flag to control how to clear the parameters.
|
||||
-- None: clear all the weights in '01'/'12'/'90', which indicates the number of training epochs.
|
||||
-- '01' or '12' or '90': clear all the weights in arch2infos_dict[index][hp].
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call clear_params with index={:} and hp={:}'.format(index, hp))
|
||||
if hp is None:
|
||||
for key, result in self.arch2infos_dict[index].items():
|
||||
result.clear_params()
|
||||
else:
|
||||
if str(hp) not in self.arch2infos_dict[index]:
|
||||
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(index, list(self.arch2infos_dict[index].keys()), hp))
|
||||
self.arch2infos_dict[index][str(hp)].clear_params()
|
||||
|
||||
@abc.abstractmethod
|
||||
def query_info_str_by_arch(self, arch, hp: Text='12'):
|
||||
"""This function is used to query the information of a specific architecture."""
|
||||
|
||||
def _query_info_str_by_arch(self, arch, hp: Text='12', print_information=None):
|
||||
arch_index = self.query_index_by_arch(arch)
|
||||
if arch_index in self.arch2infos_dict:
|
||||
if hp not in self.arch2infos_dict[arch_index]:
|
||||
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
|
||||
info = self.arch2infos_dict[arch_index][hp]
|
||||
strings = print_information(info, 'arch-index={:}'.format(arch_index))
|
||||
return '\n'.join(strings)
|
||||
else:
|
||||
print ('Find this arch-index : {:}, but this arch is not evaluated.'.format(arch_index))
|
||||
return None
|
||||
|
||||
def query_meta_info_by_index(self, arch_index, hp: Text = '12'):
|
||||
"""Return the ArchResults for the 'arch_index'-th architecture. This function is similar to query_by_index."""
|
||||
if self.verbose:
|
||||
print('Call query_meta_info_by_index with arch_index={:}, hp={:}'.format(arch_index, hp))
|
||||
if arch_index in self.arch2infos_dict:
|
||||
if hp not in self.arch2infos_dict[arch_index]:
|
||||
raise ValueError('The {:}-th architecture only has hyper-parameters of {:} instead of {:}.'.format(arch_index, list(self.arch2infos_dict[arch_index].keys()), hp))
|
||||
info = self.arch2infos_dict[arch_index][hp]
|
||||
else:
|
||||
raise ValueError('arch_index [{:}] does not in arch2infos'.format(arch_index))
|
||||
return copy.deepcopy(info)
|
||||
|
||||
def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None, hp: Text = '12'):
|
||||
""" This 'query_by_index' function is used to query information with the training of 01 epochs, 12 epochs, 90 epochs, or 200 epochs.
|
||||
------
|
||||
If hp=01, we train the model by 01 epochs (see config in configs/nas-benchmark/hyper-opts/01E.config)
|
||||
If hp=12, we train the model by 12 epochs (see config in configs/nas-benchmark/hyper-opts/12E.config)
|
||||
If hp=90, we train the model by 90 epochs (see config in configs/nas-benchmark/hyper-opts/90E.config)
|
||||
If hp=200, we train the model by 200 epochs (see config in configs/nas-benchmark/hyper-opts/200E.config)
|
||||
------
|
||||
If dataname is None, return the ArchResults
|
||||
else, return a dict with all trials on that dataset (the key is the seed)
|
||||
Options are 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
|
||||
-- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||
-- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
-- cifar100 : training the model on the CIFAR-100 training set.
|
||||
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call query_by_index with arch_index={:}, dataname={:}, hp={:}'.format(arch_index, dataname, hp))
|
||||
info = self.query_meta_info_by_index(arch_index, hp)
|
||||
if dataname is None: return info
|
||||
else:
|
||||
if dataname not in info.get_dataset_names():
|
||||
raise ValueError('invalid dataset-name : {:} vs. {:}'.format(dataname, info.get_dataset_names()))
|
||||
return info.query(dataname)
|
||||
|
||||
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, hp: Text = '12'):
|
||||
"""Find the architecture with the highest accuracy based on some constraints."""
|
||||
if self.verbose:
|
||||
print('Call find_best with dataset={:}, metric_on_set={:}, hp={:} | with #FLOPs < {:} and #Params < {:}'.format(dataset, metric_on_set, hp, FLOP_max, Param_max))
|
||||
dataset, metric_on_set = remap_dataset_set_names(dataset, metric_on_set, self.verbose)
|
||||
best_index, highest_accuracy = -1, None
|
||||
for i, arch_index in enumerate(self.evaluated_indexes):
|
||||
arch_info = self.arch2infos_dict[arch_index][hp]
|
||||
info = arch_info.get_compute_costs(dataset) # the information of costs
|
||||
flop, param, latency = info['flops'], info['params'], info['latency']
|
||||
if FLOP_max is not None and flop > FLOP_max : continue
|
||||
if Param_max is not None and param > Param_max: continue
|
||||
xinfo = arch_info.get_metrics(dataset, metric_on_set) # the information of loss and accuracy
|
||||
loss, accuracy = xinfo['loss'], xinfo['accuracy']
|
||||
if best_index == -1:
|
||||
best_index, highest_accuracy = arch_index, accuracy
|
||||
elif highest_accuracy < accuracy:
|
||||
best_index, highest_accuracy = arch_index, accuracy
|
||||
if self.verbose:
|
||||
print(' the best architecture : [{:}] {:} with accuracy={:.3f}%'.format(best_index, self.arch(best_index), highest_accuracy))
|
||||
return best_index, highest_accuracy
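# Illustrative usage (added by the editor): find the most accurate architecture on the CIFAR-100
# test set among those below the given FLOP and parameter budgets (the thresholds are placeholders
# and use the same units as the 'flops'/'params' entries of get_compute_costs):
#
#   best_index, best_acc = api.find_best('cifar100', 'test', FLOP_max=100, Param_max=1.0, hp='200')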
|
||||
|
||||
def get_net_param(self, index, dataset, seed: Optional[int], hp: Text = '12'):
|
||||
"""
|
||||
This function is used to obtain the trained weights of the `index`-th architecture on `dataset` with the seed of `seed`
|
||||
Args [seed]:
|
||||
-- None : return a dict containing the trained weights of all trials, where each key is a seed and its corresponding value is the weights.
|
||||
-- an integer : return the weights of a specific trial, whose seed is this integer.
|
||||
Args [hp]:
|
||||
-- 01 : train the model by 01 epochs
|
||||
-- 12 : train the model by 12 epochs
|
||||
-- 90 : train the model by 90 epochs
|
||||
-- 200 : train the model by 200 epochs
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call the get_net_param function with index={:}, dataset={:}, seed={:}, hp={:}'.format(index, dataset, seed, hp))
|
||||
info = self.query_meta_info_by_index(index, hp)
|
||||
return info.get_net_param(dataset, seed)
|
||||
|
||||
def get_net_config(self, index: int, dataset: Text):
|
||||
"""
|
||||
This function is used to obtain the configuration for the `index`-th architecture on `dataset`.
|
||||
Args [dataset] (4 possible options):
|
||||
-- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||
-- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
-- cifar100 : training the model on the CIFAR-100 training set.
|
||||
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
This function will return a dict.
|
||||
========= Some examples for using this function:
|
||||
config = api.get_net_config(128, 'cifar10')
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call the get_net_config function with index={:}, dataset={:}.'.format(index, dataset))
|
||||
if index in self.arch2infos_dict:
|
||||
info = self.arch2infos_dict[index]
|
||||
else:
|
||||
raise ValueError('The arch_index={:} is not in arch2infos_dict.'.format(index))
|
||||
info = next(iter(info.values()))
|
||||
results = info.query(dataset, None)
|
||||
results = next(iter(results.values()))
|
||||
return results.get_config(None)
|
||||
|
||||
def get_cost_info(self, index: int, dataset: Text, hp: Text = '12') -> Dict[Text, float]:
|
||||
"""To obtain the cost metric for the `index`-th architecture on a dataset."""
|
||||
if self.verbose:
|
||||
print('Call the get_cost_info function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
|
||||
info = self.query_meta_info_by_index(index, hp)
|
||||
return info.get_compute_costs(dataset)
|
||||
|
||||
def get_latency(self, index: int, dataset: Text, hp: Text = '12') -> float:
|
||||
"""
|
||||
To obtain the latency of the network (by default it will return the latency with the batch size of 256).
|
||||
:param index: the index of the target architecture
|
||||
:param dataset: the dataset name (cifar10-valid, cifar10, cifar100, ImageNet16-120)
|
||||
:return: return a float value in seconds
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Call the get_latency function with index={:}, dataset={:}, and hp={:}.'.format(index, dataset, hp))
|
||||
cost_dict = self.get_cost_info(index, dataset, hp)
|
||||
return cost_dict['latency']
|
||||
|
||||
@abc.abstractmethod
|
||||
def show(self, index=-1):
|
||||
"""This function will print the information of a specific (or all) architecture(s)."""
|
||||
|
||||
def _show(self, index=-1, print_information=None) -> None:
|
||||
"""
|
||||
This function will print the information of a specific (or all) architecture(s).
|
||||
|
||||
:param index: If the index < 0: it will loop for all architectures and print their information one by one.
|
||||
else: it will print the information of the 'index'-th architecture.
|
||||
:return: nothing
|
||||
"""
|
||||
if index < 0: # show all architectures
|
||||
print(self)
|
||||
for i, idx in enumerate(self.evaluated_indexes):
|
||||
print('\n' + '-' * 10 + ' The ({:5d}/{:5d}) {:06d}-th architecture! '.format(i, len(self.evaluated_indexes), idx) + '-'*10)
|
||||
print('arch : {:}'.format(self.meta_archs[idx]))
|
||||
for key, result in self.arch2infos_dict[idx].items():
|
||||
strings = print_information(result)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
else:
|
||||
if 0 <= index < len(self.meta_archs):
|
||||
if index not in self.evaluated_indexes: print('The {:}-th architecture has not been evaluated or not saved.'.format(index))
|
||||
else:
|
||||
arch_info = self.arch2infos_dict[index]
|
||||
for key, result in self.arch2infos_dict[index].items():
|
||||
strings = print_information(result)
|
||||
print('>' * 40 + ' {:03d} epochs '.format(result.get_total_epoch()) + '>' * 40)
|
||||
print('\n'.join(strings))
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
else:
|
||||
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
||||
|
||||
def statistics(self, dataset: Text, hp: Union[Text, int]) -> Dict[int, int]:
|
||||
"""This function will count the number of total trials."""
|
||||
if self.verbose:
|
||||
print('Call the statistics function with dataset={:} and hp={:}.'.format(dataset, hp))
|
||||
valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
|
||||
if dataset not in valid_datasets:
|
||||
raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
|
||||
nums, hp = defaultdict(lambda: 0), str(hp)
|
||||
for index in range(len(self)):
|
||||
archInfo = self.arch2infos_dict[index][hp]
|
||||
dataset_seed = archInfo.dataset_seed
|
||||
if dataset not in dataset_seed:
|
||||
nums[0] += 1
|
||||
else:
|
||||
nums[len(dataset_seed[dataset])] += 1
|
||||
return dict(nums)
|
||||
|
||||
|
||||
class ArchResults(object):
|
||||
|
||||
def __init__(self, arch_index, arch_str):
|
||||
self.arch_index = int(arch_index)
|
||||
self.arch_str = copy.deepcopy(arch_str)
|
||||
self.all_results = dict()
|
||||
self.dataset_seed = dict()
|
||||
self.clear_net_done = False
|
||||
|
||||
def get_compute_costs(self, dataset):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
|
||||
flops = [result.flop for result in results]
|
||||
params = [result.params for result in results]
|
||||
latencies = [result.get_latency() for result in results]
|
||||
latencies = [x for x in latencies if x > 0]
|
||||
mean_latency = np.mean(latencies) if len(latencies) > 0 else None
|
||||
time_infos = defaultdict(list)
|
||||
for result in results:
|
||||
time_info = result.get_times()
|
||||
for key, value in time_info.items(): time_infos[key].append( value )
|
||||
|
||||
info = {'flops' : np.mean(flops),
|
||||
'params' : np.mean(params),
|
||||
'latency': mean_latency}
|
||||
for key, value in time_infos.items():
|
||||
if len(value) > 0 and value[0] is not None:
|
||||
info[key] = np.mean(value)
|
||||
else: info[key] = None
|
||||
return info
|
||||
|
||||
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
|
||||
"""
|
||||
This `get_metrics` function is used to obtain the loss, accuracy, and related information on a specific dataset.
|
||||
If not specified, each set refers to the proposed split in the NAS-Bench-201 paper.
|
||||
If some args return None or raise an error, the corresponding data is not available.
|
||||
========================================
|
||||
Args [dataset] (4 possible options):
|
||||
-- cifar10-valid : training the model on the CIFAR-10 training set.
|
||||
-- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
-- cifar100 : training the model on the CIFAR-100 training set.
|
||||
-- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
Args [setname] (each dataset has different setnames):
|
||||
-- When dataset = cifar10-valid, you can use 'train', 'x-valid', 'ori-test'
|
||||
------ 'train' : the metric on the training set.
|
||||
------ 'x-valid' : the metric on the validation set.
|
||||
------ 'ori-test' : the metric on the test set.
|
||||
-- When dataset = cifar10, you can use 'train', 'ori-test'.
|
||||
------ 'train' : the metric on the training + validation set.
|
||||
------ 'ori-test' : the metric on the test set.
|
||||
-- When dataset = cifar100 or ImageNet16-120, you can use 'train', 'ori-test', 'x-valid', 'x-test'
|
||||
------ 'train' : the metric on the training set.
|
||||
------ 'x-valid' : the metric on the validation set.
|
||||
------ 'x-test' : the metric on the test set.
|
||||
------ 'ori-test' : the metric on the validation + test set.
|
||||
Args [iepoch] (None or an integer in [0, the-number-of-total-training-epochs)
|
||||
------ None : return the metric after the last training epoch.
|
||||
------ an integer i : return the metric after the i-th training epoch.
|
||||
Args [is_random]:
|
||||
------ True : return the metric of a randomly selected trial.
|
||||
------ False : return the averaged metric of all avaliable trials.
|
||||
------ an integer indicating the 'seed' value : return the metric of a specific trial (whose random seed is 'is_random').
|
||||
"""
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
infos = defaultdict(list)
|
||||
for result in results:
|
||||
if setname == 'train':
|
||||
info = result.get_train(iepoch)
|
||||
else:
|
||||
info = result.get_eval(setname, iepoch)
|
||||
for key, value in info.items(): infos[key].append( value )
|
||||
return_info = dict()
|
||||
if isinstance(is_random, bool) and is_random: # randomly select one
|
||||
index = random.randint(0, len(results)-1)
|
||||
for key, value in infos.items(): return_info[key] = value[index]
|
||||
elif isinstance(is_random, bool) and not is_random: # average
|
||||
for key, value in infos.items():
|
||||
if len(value) > 0 and value[0] is not None:
|
||||
return_info[key] = np.mean(value)
|
||||
else: return_info[key] = None
|
||||
elif isinstance(is_random, int): # specify the seed
|
||||
if is_random not in x_seeds: raise ValueError('cannot find the random seed ({:}) in {:}'.format(is_random, x_seeds))
|
||||
index = x_seeds.index(is_random)
|
||||
for key, value in infos.items(): return_info[key] = value[index]
|
||||
else:
|
||||
raise ValueError('invalid value for is_random: {:}'.format(is_random))
|
||||
return return_info
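# Usage sketch (added for clarity; not part of the original file). Assuming `arch_result`
# is an ArchResults instance, e.g. one queried from the NAS-Bench-201 API:
#   arch_result.get_metrics('cifar10-valid', 'x-valid')            # averaged over all trials
#   arch_result.get_metrics('cifar100', 'x-test', iepoch=99)       # after the 100-th training epoch
#   arch_result.get_metrics('cifar10', 'ori-test', is_random=777)  # the single trial with seed 777
# Each call returns a dict such as {'iepoch': ..., 'loss': ..., 'accuracy': ..., 'cur_time': ..., 'all_time': ...}.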
|
||||
|
||||
def show(self, is_print=False):
|
||||
return print_information(self, None, is_print)
|
||||
|
||||
def get_dataset_names(self):
|
||||
return list(self.dataset_seed.keys())
|
||||
|
||||
def get_dataset_seeds(self, dataset):
|
||||
return copy.deepcopy( self.dataset_seed[dataset] )
|
||||
|
||||
def get_net_param(self, dataset: Text, seed: Union[None, int] =None):
|
||||
"""
|
||||
This function will return the trained network's weights on the 'dataset'.
|
||||
:arg
|
||||
dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', and 'ImageNet16-120'.
|
||||
seed: an integer indicating the seed value, or None to return the weights of all trials.
|
||||
"""
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
|
||||
else:
|
||||
xkey = (dataset, seed)
|
||||
if xkey in self.all_results:
|
||||
return self.all_results[xkey].get_net_param()
|
||||
else:
|
||||
raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys())))
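# Usage sketch (added comment): `arch_result.get_net_param('cifar10-valid', 777)` returns the
# state_dict of the network trained with seed 777 (assuming that trial exists), which can then be
# loaded into a network built from the matching config via `net.load_state_dict(...)`;
# calling it with seed=None returns a {seed: state_dict} dictionary over all trials.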
|
||||
|
||||
def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
|
||||
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
|
||||
if seed is None:
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
self.all_results[(dataset, seed)].update_latency([latency])
|
||||
else:
|
||||
self.all_results[(dataset, seed)].update_latency([latency])
|
||||
|
||||
def reset_pseudo_train_times(self, dataset: Text, seed: Union[None, Text], estimated_per_epoch_time: float) -> None:
|
||||
"""This function is used to reset the train-times in all corresponding ResultsCount(s)."""
|
||||
if seed is None:
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
|
||||
else:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_train_times(estimated_per_epoch_time)
|
||||
|
||||
def reset_pseudo_eval_times(self, dataset: Text, seed: Union[None, Text], eval_name: Text, estimated_per_epoch_time: float) -> None:
|
||||
"""This function is used to reset the eval-times in all corresponding ResultsCount(s)."""
|
||||
if seed is None:
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
|
||||
else:
|
||||
self.all_results[(dataset, seed)].reset_pseudo_eval_times(eval_name, estimated_per_epoch_time)
|
||||
|
||||
def get_latency(self, dataset: Text) -> float:
|
||||
"""Get the latency of a model on the target dataset. [Timestamp: 2020.03.09]"""
|
||||
latencies = []
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
latency = self.all_results[(dataset, seed)].get_latency()
|
||||
if not isinstance(latency, float) or latency <= 0:
|
||||
raise ValueError('invalid latency of {:} with seed={:} : {:}'.format(dataset, seed, latency))
|
||||
latencies.append(latency)
|
||||
return sum(latencies) / len(latencies)
|
||||
|
||||
def get_total_epoch(self, dataset=None):
|
||||
"""Return the total number of training epochs."""
|
||||
if dataset is None:
|
||||
epochss = []
|
||||
for xdata, x_seeds in self.dataset_seed.items():
|
||||
epochss += [self.all_results[(xdata, seed)].get_total_epoch() for seed in x_seeds]
|
||||
elif isinstance(dataset, str):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
epochss = [self.all_results[(dataset, seed)].get_total_epoch() for seed in x_seeds]
|
||||
else:
|
||||
raise ValueError('invalid dataset={:}'.format(dataset))
|
||||
if len(set(epochss)) > 1: raise ValueError('Each trial must have the same number of training epochs : {:}'.format(epochss))
|
||||
return epochss[-1]
|
||||
|
||||
def query(self, dataset, seed=None):
|
||||
"""Return the ResultsCount object (containing all information of a single trial) for 'dataset' and 'seed'"""
|
||||
if seed is None:
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return {seed: self.all_results[(dataset, seed)] for seed in x_seeds}
|
||||
else:
|
||||
return self.all_results[(dataset, seed)]
|
||||
|
||||
def arch_idx_str(self):
|
||||
return '{:06d}'.format(self.arch_index)
|
||||
|
||||
def update(self, dataset_name, seed, result):
|
||||
if dataset_name not in self.dataset_seed:
|
||||
self.dataset_seed[dataset_name] = []
|
||||
assert seed not in self.dataset_seed[dataset_name], '{:}-th arch already has this seed ({:}) on {:}'.format(self.arch_index, seed, dataset_name)
|
||||
self.dataset_seed[ dataset_name ].append( seed )
|
||||
self.dataset_seed[ dataset_name ] = sorted( self.dataset_seed[ dataset_name ] )
|
||||
assert (dataset_name, seed) not in self.all_results
|
||||
self.all_results[ (dataset_name, seed) ] = result
|
||||
self.clear_net_done = False
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = dict()
|
||||
for key, value in self.__dict__.items():
|
||||
if key == 'all_results': # contain the class of ResultsCount
|
||||
xvalue = dict()
|
||||
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
|
||||
for _k, _v in value.items():
|
||||
assert isinstance(_v, ResultsCount), 'invalid type of value for {:}/{:} : {:}'.format(key, _k, type(_v))
|
||||
xvalue[_k] = _v.state_dict()
|
||||
else:
|
||||
xvalue = value
|
||||
state_dict[key] = xvalue
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
new_state_dict = dict()
|
||||
for key, value in state_dict.items():
|
||||
if key == 'all_results': # to convert to the class of ResultsCount
|
||||
xvalue = dict()
|
||||
assert isinstance(value, dict), 'invalid type of value for {:} : {:}'.format(key, type(value))
|
||||
for _k, _v in value.items():
|
||||
xvalue[_k] = ResultsCount.create_from_state_dict(_v)
|
||||
else: xvalue = value
|
||||
new_state_dict[key] = xvalue
|
||||
self.__dict__.update(new_state_dict)
|
||||
|
||||
@staticmethod
|
||||
def create_from_state_dict(state_dict_or_file):
|
||||
x = ArchResults(-1, -1)
|
||||
if isinstance(state_dict_or_file, str): # a file path
|
||||
state_dict = torch.load(state_dict_or_file, map_location='cpu')
|
||||
elif isinstance(state_dict_or_file, dict):
|
||||
state_dict = state_dict_or_file
|
||||
else:
|
||||
raise ValueError('invalid type of state_dict_or_file : {:}'.format(type(state_dict_or_file)))
|
||||
x.load_state_dict(state_dict)
|
||||
return x
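# Usage sketch (added comment): an ArchResults object can be restored either from a checkpoint
# file path, e.g. `ArchResults.create_from_state_dict('arch-000001.pth')` (an assumed file name),
# or from an in-memory dict previously produced by `arch_results.state_dict()`.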
|
||||
|
||||
# This function is used to clear the weights saved in each 'result'
|
||||
# This can help reduce the memory footprint.
|
||||
def clear_params(self):
|
||||
for key, result in self.all_results.items():
|
||||
del result.net_state_dict
|
||||
result.net_state_dict = None
|
||||
self.clear_net_done = True
|
||||
|
||||
def debug_test(self):
|
||||
"""This function is used for me to debug and test, which will call most methods."""
|
||||
all_dataset = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
|
||||
for dataset in all_dataset:
|
||||
print('---->>>> {:}'.format(dataset))
|
||||
print('The latency on {:} is {:} s'.format(dataset, self.get_latency(dataset)))
|
||||
for seed in self.dataset_seed[dataset]:
|
||||
result = self.all_results[(dataset, seed)]
|
||||
print(' ==>> result = {:}'.format(result))
|
||||
print(' ==>> cost = {:}'.format(result.get_times()))
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(arch-index={index}, arch={arch}, {num} runs, clear={clear})'.format(name=self.__class__.__name__, index=self.arch_index, arch=self.arch_str, num=len(self.all_results), clear=self.clear_net_done))
|
||||
|
||||
|
||||
"""
|
||||
This class (ResultsCount) is used to save the information of one trial for a single architecture.
|
||||
There are few comments in this class because it is the lowest-level class in the NAS-Bench-201 API and is rarely called directly.
|
||||
If you have any question regarding this class, please open an issue or email me.
|
||||
"""
|
||||
class ResultsCount(object):
|
||||
|
||||
def __init__(self, name, state_dict, train_accs, train_losses, params, flop, arch_config, seed, epochs, latency):
|
||||
self.name = name
|
||||
self.net_state_dict = state_dict
|
||||
self.train_acc1es = copy.deepcopy(train_accs)
|
||||
self.train_acc5es = None
|
||||
self.train_losses = copy.deepcopy(train_losses)
|
||||
self.train_times = None
|
||||
self.arch_config = copy.deepcopy(arch_config)
|
||||
self.params = params
|
||||
self.flop = flop
|
||||
self.seed = seed
|
||||
self.epochs = epochs
|
||||
self.latency = latency
|
||||
# evaluation results
|
||||
self.reset_eval()
|
||||
|
||||
def update_train_info(self, train_acc1es, train_acc5es, train_losses, train_times) -> None:
|
||||
self.train_acc1es = train_acc1es
|
||||
self.train_acc5es = train_acc5es
|
||||
self.train_losses = train_losses
|
||||
self.train_times = train_times
|
||||
|
||||
def reset_pseudo_train_times(self, estimated_per_epoch_time: float) -> None:
|
||||
"""Assign the training times."""
|
||||
train_times = OrderedDict()
|
||||
for i in range(self.epochs):
|
||||
train_times[i] = estimated_per_epoch_time
|
||||
self.train_times = train_times
|
||||
|
||||
def reset_pseudo_eval_times(self, eval_name: Text, estimated_per_epoch_time: float) -> None:
|
||||
"""Assign the evaluation times."""
|
||||
if eval_name not in self.eval_names: raise ValueError('invalid eval name : {:}'.format(eval_name))
|
||||
for i in range(self.epochs):
|
||||
self.eval_times['{:}@{:}'.format(eval_name,i)] = estimated_per_epoch_time
|
||||
|
||||
def reset_eval(self):
|
||||
self.eval_names = []
|
||||
self.eval_acc1es = {}
|
||||
self.eval_times = {}
|
||||
self.eval_losses = {}
|
||||
|
||||
def update_latency(self, latency):
|
||||
self.latency = copy.deepcopy( latency )
|
||||
|
||||
def get_latency(self) -> float:
|
||||
"""Return the latency value in seconds. -1 represents not avaliable ; otherwise it should be a float value"""
|
||||
if self.latency is None: return -1.0
|
||||
else: return sum(self.latency) / len(self.latency)
|
||||
|
||||
def update_eval(self, accs, losses, times): # new version
|
||||
data_names = set([x.split('@')[0] for x in accs.keys()])
|
||||
for data_name in data_names:
|
||||
assert data_name not in self.eval_names, '{:} has already been added into eval-names'.format(data_name)
|
||||
self.eval_names.append( data_name )
|
||||
for iepoch in range(self.epochs):
|
||||
xkey = '{:}@{:}'.format(data_name, iepoch)
|
||||
self.eval_acc1es[ xkey ] = accs[ xkey ]
|
||||
self.eval_losses[ xkey ] = losses[ xkey ]
|
||||
self.eval_times [ xkey ] = times[ xkey ]
|
||||
|
||||
def update_OLD_eval(self, name, accs, losses): # old version
|
||||
assert name not in self.eval_names, '{:} has already been added'.format(name)
|
||||
self.eval_names.append( name )
|
||||
for iepoch in range(self.epochs):
|
||||
if iepoch in accs:
|
||||
self.eval_acc1es['{:}@{:}'.format(name,iepoch)] = accs[iepoch]
|
||||
self.eval_losses['{:}@{:}'.format(name,iepoch)] = losses[iepoch]
|
||||
|
||||
def __repr__(self):
|
||||
num_eval = len(self.eval_names)
|
||||
set_name = '[' + ', '.join(self.eval_names) + ']'
|
||||
return ('{name}({xname}, arch={arch}, FLOP={flop:.2f}M, Param={param:.3f}MB, seed={seed}, {num_eval} eval-sets: {set_name})'.format(name=self.__class__.__name__, xname=self.name, arch=self.arch_config['arch_str'], flop=self.flop, param=self.params, seed=self.seed, num_eval=num_eval, set_name=set_name))
|
||||
|
||||
def get_total_epoch(self):
|
||||
return copy.deepcopy(self.epochs)
|
||||
|
||||
def get_times(self):
|
||||
"""Obtain the information regarding both training and evaluation time."""
|
||||
if self.train_times is not None and isinstance(self.train_times, dict):
|
||||
train_times = list( self.train_times.values() )
|
||||
time_info = {'T-train@epoch': np.mean(train_times), 'T-train@total': np.sum(train_times)}
|
||||
else:
|
||||
time_info = {'T-train@epoch': None, 'T-train@total': None }
|
||||
for name in self.eval_names:
|
||||
try:
|
||||
xtimes = [self.eval_times['{:}@{:}'.format(name,i)] for i in range(self.epochs)]
|
||||
time_info['T-{:}@epoch'.format(name)] = np.mean(xtimes)
|
||||
time_info['T-{:}@total'.format(name)] = np.sum(xtimes)
|
||||
except:
|
||||
time_info['T-{:}@epoch'.format(name)] = None
|
||||
time_info['T-{:}@total'.format(name)] = None
|
||||
return time_info
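# Note (added comment): the returned dict uses the keys 'T-train@epoch' / 'T-train@total' for the
# training time, plus 'T-{set-name}@epoch' / 'T-{set-name}@total' for every evaluation set;
# a value of None means the corresponding timing was not recorded for this trial.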
|
||||
|
||||
def get_eval_set(self):
|
||||
return self.eval_names
|
||||
|
||||
# get the training information
|
||||
def get_train(self, iepoch=None):
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} vs. total epochs {:}'.format(iepoch, self.epochs)
|
||||
if self.train_times is not None:
|
||||
xtime = self.train_times[iepoch]
|
||||
atime = sum([self.train_times[i] for i in range(iepoch+1)])
|
||||
else: xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.train_losses[iepoch],
|
||||
'accuracy': self.train_acc1es[iepoch],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
|
||||
def get_eval(self, name, iepoch=None):
|
||||
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} vs. total epochs {:}'.format(iepoch, self.epochs)
|
||||
def _internal_query(xname):
|
||||
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
|
||||
xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)]
|
||||
atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)])
|
||||
else:
|
||||
xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.eval_losses['{:}@{:}'.format(xname, iepoch)],
|
||||
'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
if name == 'valid':
|
||||
return _internal_query('x-valid')
|
||||
else:
|
||||
return _internal_query(name)
|
||||
|
||||
def get_net_param(self, clone=False):
|
||||
if clone: return copy.deepcopy(self.net_state_dict)
|
||||
else: return self.net_state_dict
|
||||
|
||||
def get_config(self, str2structure):
|
||||
"""This function is used to obtain the config dict for this architecture."""
|
||||
if str2structure is None:
|
||||
# In this case, this is to handle the size search space.
|
||||
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
|
||||
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
|
||||
'genotype': self.arch_config['genotype'], 'num_classes': self.arch_config['class_num']}
|
||||
# In this case, this is NAS-Bench-201
|
||||
else:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
'arch_str': self.arch_config['arch_str'], 'num_classes': self.arch_config['class_num']}
|
||||
else:
|
||||
# In this case, this is to handle the size search space.
|
||||
if 'name' in self.arch_config and self.arch_config['name'] == 'infer.shape.tiny':
|
||||
return {'name': 'infer.shape.tiny', 'channels': self.arch_config['channels'],
|
||||
'genotype': str2structure(self.arch_config['genotype']), 'num_classes': self.arch_config['class_num']}
|
||||
# In this case, this is NAS-Bench-201
|
||||
else:
|
||||
return {'name': 'infer.tiny', 'C': self.arch_config['channel'],
|
||||
'N' : self.arch_config['num_cells'],
|
||||
'genotype': str2structure(self.arch_config['arch_str']), 'num_classes': self.arch_config['class_num']}
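# Usage sketch (added comment, assuming `result` is a ResultsCount instance and the usual
# AutoDL helpers are importable):
#   from models import CellStructure, get_cell_based_tiny_net
#   from config_utils import dict2config
#   config  = result.get_config(CellStructure.str2structure)     # genotype parsed into a Structure
#   network = get_cell_based_tiny_net(dict2config(config, None))
# Passing `str2structure=None` keeps 'arch_str' as a plain string instead of a parsed genotype.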
|
||||
|
||||
def state_dict(self):
|
||||
_state_dict = {key: value for key, value in self.__dict__.items()}
|
||||
return _state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
@staticmethod
|
||||
def create_from_state_dict(state_dict):
|
||||
x = ResultsCount(None, None, None, None, None, None, None, None, None, None)
|
||||
x.load_state_dict(state_dict)
|
||||
return x
|
||||
@@ -1,76 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .construct_utils import drop_path
|
||||
from .head_utils import CifarHEAD, AuxiliaryHeadCIFAR
|
||||
from .base_cells import InferCell
|
||||
|
||||
|
||||
class NetworkCIFAR(nn.Module):
|
||||
|
||||
def __init__(self, C, N, stem_multiplier, auxiliary, genotype, num_classes):
|
||||
super(NetworkCIFAR, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._stem_multiplier = stem_multiplier
|
||||
|
||||
C_curr = self._stem_multiplier * C
|
||||
self.stem = CifarHEAD(C_curr)
|
||||
|
||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
block_indexs = [0 ] * N + [-1 ] + [1 ] * N + [-1 ] + [2 ] * N
|
||||
block2index = {0:[], 1:[], 2:[]}
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
|
||||
reduction_prev, spatial, dims = False, 1, []
|
||||
self.auxiliary_index = None
|
||||
self.auxiliary_head = None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||
cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells.append( cell )
|
||||
C_prev_prev, C_prev = C_prev, cell._multiplier*C_curr
|
||||
if reduction and C_curr == C*4:
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
|
||||
self.auxiliary_index = index
|
||||
|
||||
if reduction: spatial *= 2
|
||||
dims.append( (C_prev, spatial) )
|
||||
|
||||
self._Layer= len(self.cells)
|
||||
|
||||
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.drop_path_prob = -1
|
||||
|
||||
def update_drop_path(self, drop_path_prob):
|
||||
self.drop_path_prob = drop_path_prob
|
||||
|
||||
def auxiliary_param(self):
|
||||
if self.auxiliary_head is None: return []
|
||||
else: return list( self.auxiliary_head.parameters() )
|
||||
|
||||
def get_message(self):
|
||||
return self.extra_repr()
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_C}, N={_layerN}, L={_Layer}, stem={_stem_multiplier}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, inputs):
|
||||
stem_feature, logits_aux = self.stem(inputs), None
|
||||
cell_results = [stem_feature, stem_feature]
|
||||
for i, cell in enumerate(self.cells):
|
||||
cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
|
||||
cell_results.append( cell_feature )
|
||||
|
||||
if self.auxiliary_index is not None and i == self.auxiliary_index and self.training:
|
||||
logits_aux = self.auxiliary_head( cell_results[-1] )
|
||||
out = self.global_pooling( cell_results[-1] )
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
if logits_aux is None: return out, logits
|
||||
else : return out, [logits, logits_aux]
|
||||
@@ -1,77 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .construct_utils import drop_path
|
||||
from .base_cells import InferCell
|
||||
from .head_utils import ImageNetHEAD, AuxiliaryHeadImageNet
|
||||
|
||||
|
||||
class NetworkImageNet(nn.Module):
|
||||
|
||||
def __init__(self, C, N, auxiliary, genotype, num_classes):
|
||||
super(NetworkImageNet, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
self.stem0 = nn.Sequential(
|
||||
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
self.stem1 = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = C, C, C, True
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
self.auxiliary_index = None
|
||||
for i, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||
cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, cell._multiplier * C_curr
|
||||
if reduction and C_curr == C*4:
|
||||
C_to_auxiliary = C_prev
|
||||
self.auxiliary_index = i
|
||||
|
||||
self._NNN = len(self.cells)
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
|
||||
else:
|
||||
self.auxiliary_head = None
|
||||
self.global_pooling = nn.AvgPool2d(7)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.drop_path_prob = -1
|
||||
|
||||
def update_drop_path(self, drop_path_prob):
|
||||
self.drop_path_prob = drop_path_prob
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_C}, N=[{_layerN}, {_NNN}], aux-index={auxiliary_index}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def get_message(self):
|
||||
return self.extra_repr()
|
||||
|
||||
def auxiliary_param(self):
|
||||
if self.auxiliary_head is None: return []
|
||||
else: return list( self.auxiliary_head.parameters() )
|
||||
|
||||
def forward(self, inputs):
|
||||
s0 = self.stem0(inputs)
|
||||
s1 = self.stem1(s0)
|
||||
logits_aux = None
|
||||
for i, cell in enumerate(self.cells):
|
||||
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
|
||||
if i == self.auxiliary_index and self.auxiliary_head and self.training:
|
||||
logits_aux = self.auxiliary_head(s1)
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0), -1))
|
||||
|
||||
if logits_aux is None: return out, logits
|
||||
else : return out, [logits, logits_aux]
|
||||
@@ -1,5 +0,0 @@
|
||||
# Performance-Aware Template Network for One-Shot Neural Architecture Search
|
||||
from .CifarNet import NetworkCIFAR as CifarNet
|
||||
from .ImageNet import NetworkImageNet as ImageNet
|
||||
from .genotypes import Networks
|
||||
from .genotypes import build_genotype_from_dict
|
||||
@@ -1,173 +0,0 @@
|
||||
import math
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .construct_utils import drop_path
|
||||
from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
|
||||
def __init__(self, C, stride, PRIMITIVES):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
self.name2idx = {}
|
||||
for idx, primitive in enumerate(PRIMITIVES):
|
||||
op = OPS[primitive](C, C, stride, False)
|
||||
self._ops.append(op)
|
||||
assert primitive not in self.name2idx, '{:} is already in name2idx'.format(primitive)
|
||||
self.name2idx[primitive] = idx
|
||||
|
||||
def forward(self, x, weights, op_name):
|
||||
if op_name is None:
|
||||
if weights is None:
|
||||
return [op(x) for op in self._ops]
|
||||
else:
|
||||
return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
else:
|
||||
op_index = self.name2idx[op_name]
|
||||
return self._ops[op_index](x)
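# Note (added comment): MixedOp.forward supports three modes --
#   forward(x, weights, None)          -> weighted sum over all candidate ops (differentiable search),
#   forward(x, None, None)             -> list of every candidate op's output (used by __forwardOnlyW),
#   forward(x, None, 'sep_conv_3x3')   -> only the named op is evaluated (used with a fixed genotype).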
|
||||
|
||||
|
||||
|
||||
class SearchCell(nn.Module):
|
||||
|
||||
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
|
||||
super(SearchCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.PRIMITIVES = deepcopy(PRIMITIVES)
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self._use_residual = use_residual
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
for i in range(self._steps):
|
||||
for j in range(2+i):
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(C, stride, self.PRIMITIVES)
|
||||
self._ops.append(op)
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
|
||||
if modes[0] is None:
|
||||
if modes[1] == 'normal':
|
||||
output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
|
||||
elif modes[1] == 'only_W':
|
||||
output = self.__forwardOnlyW(S0, S1, drop_prob)
|
||||
else:
|
||||
test_genotype = modes[0]
|
||||
if self.reduction: operations, concats = test_genotype.reduce, test_genotype.reduce_concat
|
||||
else : operations, concats = test_genotype.normal, test_genotype.normal_concat
|
||||
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
|
||||
states, offset = [s0, s1], 0
|
||||
assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations))
|
||||
for i, (opA, opB) in enumerate(operations):
|
||||
A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0])
|
||||
B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0])
|
||||
state = A + B
|
||||
offset += len(states)
|
||||
states.append(state)
|
||||
output = torch.cat([states[i] for i in concats], dim=1)
|
||||
if self._use_residual and S1.size() == output.size():
|
||||
return S1 + output
|
||||
else: return output
|
||||
|
||||
def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
|
||||
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
|
||||
states, offset = [s0, s1], 0
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
x = self._ops[offset+j](h, weights[offset+j], None)
|
||||
if self.training and drop_prob > 0.:
|
||||
x = drop_path(x, math.pow(drop_prob, 1./len(states)))
|
||||
clist.append( x )
|
||||
connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
|
||||
state = sum(w * node for w, node in zip(connection, clist))
|
||||
offset += len(states)
|
||||
states.append(state)
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
||||
def __forwardOnlyW(self, S0, S1, drop_prob):
|
||||
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
|
||||
states, offset = [s0, s1], 0
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
xs = self._ops[offset+j](h, None, None)
|
||||
clist += xs
|
||||
if self.training and drop_prob > 0.:
|
||||
xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist]
|
||||
else: xlist = clist
|
||||
state = sum(xlist) * 2 / len(xlist)
|
||||
offset += len(states)
|
||||
states.append(state)
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
||||
|
||||
|
||||
class InferCell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(InferCell, self).__init__()
|
||||
print(C_prev_prev, C_prev, C)
|
||||
|
||||
if reduction_prev is None:
|
||||
self.preprocess0 = Identity()
|
||||
elif reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
|
||||
|
||||
if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
|
||||
else : step_ops, concat = genotype.normal, genotype.normal_concat
|
||||
self._steps = len(step_ops)
|
||||
self._concat = concat
|
||||
self._multiplier = len(concat)
|
||||
self._ops = nn.ModuleList()
|
||||
self._indices = []
|
||||
for operations in step_ops:
|
||||
for name, index in operations:
|
||||
stride = 2 if reduction and index < 2 else 1
|
||||
if reduction_prev is None and index == 0:
|
||||
op = OPS[name](C_prev_prev, C, stride, True)
|
||||
else:
|
||||
op = OPS[name](C , C, stride, True)
|
||||
self._ops.append( op )
|
||||
self._indices.append( index )
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, S0, S1, drop_prob):
|
||||
s0 = self.preprocess0(S0)
|
||||
s1 = self.preprocess1(S1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
h1 = states[self._indices[2*i]]
|
||||
h2 = states[self._indices[2*i+1]]
|
||||
op1 = self._ops[2*i]
|
||||
op2 = self._ops[2*i+1]
|
||||
h1 = op1(h1)
|
||||
h2 = op2(h2)
|
||||
if self.training and drop_prob > 0.:
|
||||
if not isinstance(op1, Identity):
|
||||
h1 = drop_path(h1, drop_prob)
|
||||
if not isinstance(op2, Identity):
|
||||
h2 = drop_path(h2, drop_prob)
|
||||
|
||||
state = h1 + h2
|
||||
states += [state]
|
||||
output = torch.cat([states[i] for i in self._concat], dim=1)
|
||||
return output
|
||||
@@ -1,60 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.:
|
||||
keep_prob = 1. - drop_prob
|
||||
mask = x.new_zeros(x.size(0), 1, 1, 1)
|
||||
mask = mask.bernoulli_(keep_prob)
|
||||
x = torch.div(x, keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
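# Usage sketch (added comment): drop_path implements DARTS-style stochastic path dropping;
# a minimal example, assuming a 4-D feature tensor:
#   x = torch.randn(8, 16, 32, 32)
#   y = drop_path(x, 0.2)   # each sample is zeroed with probability 0.2, the rest rescaled by 1/0.8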
|
||||
|
||||
|
||||
def return_alphas_str(basemodel):
|
||||
if hasattr(basemodel, 'alphas_normal'):
|
||||
string = 'normal [{:}] : \n-->>{:}'.format(basemodel.alphas_normal.size(), F.softmax(basemodel.alphas_normal, dim=-1) )
|
||||
else: string = ''
|
||||
if hasattr(basemodel, 'alphas_reduce'):
|
||||
string = string + '\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) )
|
||||
|
||||
if hasattr(basemodel, 'get_adjacency'):
|
||||
adjacency = basemodel.get_adjacency()
|
||||
for i in range( len(adjacency) ):
|
||||
weight = F.softmax( basemodel.connect_normal[str(i)], dim=-1 )
|
||||
adj = torch.mm(weight, adjacency[i]).view(-1)
|
||||
adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
|
||||
string = string + '\nnormal--{:}-->{:}'.format(i, ', '.join(adj))
|
||||
for i in range( len(adjacency) ):
|
||||
weight = F.softmax( basemodel.connect_reduce[str(i)], dim=-1 )
|
||||
adj = torch.mm(weight, adjacency[i]).view(-1)
|
||||
adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
|
||||
string = string + '\nreduce--{:}-->{:}'.format(i, ', '.join(adj))
|
||||
|
||||
if hasattr(basemodel, 'alphas_connect'):
|
||||
weight = F.softmax(basemodel.alphas_connect, dim=-1).cpu()
|
||||
ZERO = ['{:.3f}'.format(x) for x in weight[:,0].tolist()]
|
||||
IDEN = ['{:.3f}'.format(x) for x in weight[:,1].tolist()]
|
||||
string = string + '\nconnect [{:}] : \n ->{:}\n ->{:}'.format( list(basemodel.alphas_connect.size()), ZERO, IDEN )
|
||||
else:
|
||||
string = string + '\nconnect = None'
|
||||
|
||||
if hasattr(basemodel, 'get_gcn_out'):
|
||||
outputs = basemodel.get_gcn_out(True)
|
||||
for i, output in enumerate(outputs):
|
||||
string = string + '\nnormal:[{:}] : {:}'.format(i, F.softmax(output, dim=-1) )
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def remove_duplicate_archs(all_archs):
|
||||
archs = []
|
||||
str_archs = ['{:}'.format(x) for x in all_archs]
|
||||
for i, arch_x in enumerate(str_archs):
|
||||
choose = True
|
||||
for j in range(i):
|
||||
if arch_x == str_archs[j]:
|
||||
choose = False; break
|
||||
if choose: archs.append(all_archs[i])
|
||||
return archs
|
||||
@@ -1,182 +0,0 @@
|
||||
from collections import namedtuple
|
||||
|
||||
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat connectN connects')
|
||||
#Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
|
||||
|
||||
PRIMITIVES_small = [
|
||||
'max_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'conv_3x1_1x3',
|
||||
]
|
||||
|
||||
PRIMITIVES_large = [
|
||||
'max_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'dil_conv_3x3',
|
||||
'dil_conv_5x5',
|
||||
'conv_3x1_1x3',
|
||||
]
|
||||
|
||||
PRIMITIVES_huge = [
|
||||
'skip_connect',
|
||||
'nor_conv_1x1',
|
||||
'max_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'nor_conv_3x3',
|
||||
'sep_conv_3x3',
|
||||
'dil_conv_3x3',
|
||||
'conv_3x1_1x3',
|
||||
'sep_conv_5x5',
|
||||
'dil_conv_5x5',
|
||||
'sep_conv_7x7',
|
||||
'conv_7x1_1x7',
|
||||
'att_squeeze',
|
||||
]
|
||||
|
||||
PRIMITIVES = {'small': PRIMITIVES_small,
|
||||
'large': PRIMITIVES_large,
|
||||
'huge' : PRIMITIVES_huge}
|
||||
|
||||
NASNet = Genotype(
|
||||
normal = [
|
||||
(('sep_conv_5x5', 1), ('sep_conv_3x3', 0)),
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 0)),
|
||||
(('avg_pool_3x3', 1), ('skip_connect', 0)),
|
||||
(('avg_pool_3x3', 0), ('avg_pool_3x3', 0)),
|
||||
(('sep_conv_3x3', 1), ('skip_connect', 1)),
|
||||
],
|
||||
normal_concat = [2, 3, 4, 5, 6],
|
||||
reduce = [
|
||||
(('sep_conv_5x5', 1), ('sep_conv_7x7', 0)),
|
||||
(('max_pool_3x3', 1), ('sep_conv_7x7', 0)),
|
||||
(('avg_pool_3x3', 1), ('sep_conv_5x5', 0)),
|
||||
(('skip_connect', 3), ('avg_pool_3x3', 2)),
|
||||
(('sep_conv_3x3', 2), ('max_pool_3x3', 1)),
|
||||
],
|
||||
reduce_concat = [4, 5, 6],
|
||||
connectN=None, connects=None,
|
||||
)
|
||||
|
||||
PNASNet = Genotype(
|
||||
normal = [
|
||||
(('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
|
||||
(('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_3x3', 0), ('skip_connect', 1)),
|
||||
],
|
||||
normal_concat = [2, 3, 4, 5, 6],
|
||||
reduce = [
|
||||
(('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
|
||||
(('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_3x3', 0), ('skip_connect', 1)),
|
||||
],
|
||||
reduce_concat = [2, 3, 4, 5, 6],
|
||||
connectN=None, connects=None,
|
||||
)
|
||||
|
||||
|
||||
DARTS_V1 = Genotype(
|
||||
normal=[
|
||||
(('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1
|
||||
(('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2
|
||||
(('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3
|
||||
(('sep_conv_3x3', 0), ('skip_connect', 2)) # step 4
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
|
||||
(('skip_connect', 2), ('max_pool_3x3', 0)), # step 2
|
||||
(('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
|
||||
(('skip_connect', 2), ('avg_pool_3x3', 0)) # step 4
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
connectN=None, connects=None,
|
||||
)
|
||||
|
||||
# DARTS: Differentiable Architecture Search, ICLR 2019
|
||||
DARTS_V2 = Genotype(
|
||||
normal=[
|
||||
(('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1
|
||||
(('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2
|
||||
(('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3
|
||||
(('skip_connect', 0), ('dil_conv_3x3', 2)) # step 4
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
|
||||
(('skip_connect', 2), ('max_pool_3x3', 1)), # step 2
|
||||
(('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
|
||||
(('skip_connect', 2), ('max_pool_3x3', 1)) # step 4
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
connectN=None, connects=None,
|
||||
)
|
||||
|
||||
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
|
||||
SETN = Genotype(
|
||||
normal=[
|
||||
(('skip_connect', 0), ('sep_conv_5x5', 1)),
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('sep_conv_5x5', 3)),
|
||||
(('max_pool_3x3', 1), ('conv_3x1_1x3', 4))],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('sep_conv_3x3', 0), ('sep_conv_5x5', 1)),
|
||||
(('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
|
||||
(('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
|
||||
(('avg_pool_3x3', 0), ('skip_connect', 1))],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
connectN=None, connects=None
|
||||
)
|
||||
|
||||
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
|
||||
GDAS_V1 = Genotype(
|
||||
normal=[
|
||||
(('skip_connect', 0), ('skip_connect', 1)),
|
||||
(('skip_connect', 0), ('sep_conv_5x5', 2)),
|
||||
(('sep_conv_3x3', 3), ('skip_connect', 0)),
|
||||
(('sep_conv_5x5', 4), ('sep_conv_3x3', 3))],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_5x5', 2), ('sep_conv_5x5', 1)),
|
||||
(('dil_conv_5x5', 2), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_5x5', 0), ('sep_conv_5x5', 1))],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
connectN=None, connects=None
|
||||
)
|
||||
|
||||
|
||||
|
||||
Networks = {'DARTS_V1': DARTS_V1,
|
||||
'DARTS_V2': DARTS_V2,
|
||||
'DARTS' : DARTS_V2,
|
||||
'NASNet' : NASNet,
|
||||
'GDAS_V1' : GDAS_V1,
|
||||
'PNASNet' : PNASNet,
|
||||
'SETN' : SETN,
|
||||
}
|
||||
|
||||
# This function will return a Genotype from a dict.
|
||||
def build_genotype_from_dict(xdict):
|
||||
def remove_value(nodes):
|
||||
return [tuple([(x[0], x[1]) for x in node]) for node in nodes]
|
||||
genotype = Genotype(
|
||||
normal=remove_value(xdict['normal']),
|
||||
normal_concat=xdict['normal_concat'],
|
||||
reduce=remove_value(xdict['reduce']),
|
||||
reduce_concat=xdict['reduce_concat'],
|
||||
connectN=None, connects=None
|
||||
)
|
||||
return genotype
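# Usage sketch (added for illustration; the dict below is a made-up two-step genotype,
# e.g. one stored in a search checkpoint under xdata['genotypes'][epoch]):
def _example_build_genotype():
    xdict = {
        'normal': [[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)],
                   [('skip_connect', 0), ('sep_conv_3x3', 1)]],
        'normal_concat': [2, 3],
        'reduce': [[('max_pool_3x3', 0), ('max_pool_3x3', 1)],
                   [('skip_connect', 2), ('max_pool_3x3', 0)]],
        'reduce_concat': [2, 3],
    }
    return build_genotype_from_dict(xdict)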
|
||||
@@ -1,65 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ImageNetHEAD(nn.Sequential):
|
||||
def __init__(self, C, stride=2):
|
||||
super(ImageNetHEAD, self).__init__()
|
||||
self.add_module('conv1', nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False))
|
||||
self.add_module('bn1' , nn.BatchNorm2d(C // 2))
|
||||
self.add_module('relu1', nn.ReLU(inplace=True))
|
||||
self.add_module('conv2', nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False))
|
||||
self.add_module('bn2' , nn.BatchNorm2d(C))
|
||||
|
||||
|
||||
class CifarHEAD(nn.Sequential):
|
||||
def __init__(self, C):
|
||||
super(CifarHEAD, self).__init__()
|
||||
self.add_module('conv', nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
|
||||
self.add_module('bn', nn.BatchNorm2d(C))
|
||||
|
||||
|
||||
class AuxiliaryHeadCIFAR(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 8x8"""
|
||||
super(AuxiliaryHeadCIFAR, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0),-1))
|
||||
return x
|
||||
|
||||
|
||||
class AuxiliaryHeadImageNet(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 14x14"""
|
||||
super(AuxiliaryHeadImageNet, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0),-1))
|
||||
return x
|
||||
@@ -1,31 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
# This package makes AutoDL-Projects compatible with the old GDAS project.
|
||||
# Ideally, this package will be merged into lib/models/cell_infers in the future.
|
||||
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
|
||||
##################################################
|
||||
|
||||
import os, torch
|
||||
|
||||
def obtain_nas_infer_model(config, extra_model_path=None):
|
||||
|
||||
if config.arch == 'dxys':
|
||||
from .DXYs import CifarNet, ImageNet, Networks
|
||||
from .DXYs import build_genotype_from_dict
|
||||
if config.genotype is None:
|
||||
if extra_model_path is not None and not os.path.isfile(extra_model_path):
|
||||
raise ValueError('When genotype in config is None, extra_model_path must be set to a valid path instead of {:}'.format(extra_model_path))
|
||||
xdata = torch.load(extra_model_path)
|
||||
current_epoch = xdata['epoch']
|
||||
genotype_dict = xdata['genotypes'][current_epoch-1]
|
||||
genotype = build_genotype_from_dict(genotype_dict)
|
||||
else:
|
||||
genotype = Networks[config.genotype]
|
||||
if config.dataset == 'cifar':
|
||||
return CifarNet(config.ichannel, config.layers, config.stem_multi, config.auxiliary, genotype, config.class_num)
|
||||
elif config.dataset == 'imagenet':
|
||||
return ImageNet(config.ichannel, config.layers, config.auxiliary, genotype, config.class_num)
|
||||
else: raise ValueError('invalid dataset : {:}'.format(config.dataset))
|
||||
else:
|
||||
raise ValueError('invalid nas arch type : {:}'.format(config.arch))
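# Usage sketch (added comment): `config` is expected to behave like an argparse.Namespace with the
# fields used above, e.g. (illustrative values only)
#   from types import SimpleNamespace
#   config = SimpleNamespace(arch='dxys', genotype='GDAS_V1', dataset='cifar', ichannel=36,
#                            layers=6, stem_multi=3, auxiliary=0.4, class_num=10)
#   model = obtain_nas_infer_model(config)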
|
||||
@@ -1,183 +0,0 @@
|
||||
##############################################################################################
|
||||
# This code is copied and modified from Hanxiao Liu's work (https://github.com/quark0/darts) #
|
||||
##############################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
OPS = {
|
||||
'none' : lambda C_in, C_out, stride, affine: Zero(stride),
|
||||
'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
|
||||
'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
|
||||
'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), affine),
|
||||
'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), affine),
|
||||
'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), affine),
|
||||
'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
|
||||
'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine),
|
||||
'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine),
|
||||
'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine),
|
||||
'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine),
|
||||
'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine),
|
||||
'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: Conv717(C_in, C_out, stride, affine),
|
||||
'conv_3x1_1x3' : lambda C_in, C_out, stride, affine: Conv313(C_in, C_out, stride, affine)
|
||||
}
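# Usage sketch (added comment): every entry of OPS is a factory taking (C_in, C_out, stride, affine),
# e.g. `op = OPS['sep_conv_3x3'](16, 16, 1, True)` builds a separable 3x3 convolution block and
# `op(torch.randn(2, 16, 32, 32))` returns a tensor with the same spatial size.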
|
||||
|
||||
|
||||
class POOLING(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, mode):
|
||||
super(POOLING, self).__init__()
|
||||
if C_in == C_out:
|
||||
self.preprocess = None
|
||||
else:
|
||||
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0)
|
||||
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
|
||||
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.preprocess is not None:
|
||||
x = self.preprocess(inputs)
|
||||
else: x = inputs
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Conv313(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, affine):
|
||||
super(Conv313, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in , C_out, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
|
||||
nn.Conv2d(C_out, C_out, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Conv717(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, affine):
|
||||
super(Conv717, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in , C_out, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
|
||||
nn.Conv2d(C_out, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DilConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
|
||||
super(DilConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride= 1, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
|
||||
def __init__(self, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 1:
|
||||
return x.mul(0.)
|
||||
return x[:,:,::self.stride,::self.stride].mul(0.)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'stride={stride}'.format(**self.__dict__)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, affine=True):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
self.stride = stride
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
if stride == 2:
|
||||
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
|
||||
C_outs = [C_out // 2, C_out - C_out // 2]
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(2):
|
||||
self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) )
|
||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||
elif stride == 4:
|
||||
assert C_out % 4 == 0, 'C_out : {:}'.format(C_out)
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(4):
|
||||
self.convs.append( nn.Conv2d(C_in, C_out // 4, 1, stride=stride, padding=0, bias=False) )
|
||||
self.pad = nn.ConstantPad2d((0, 3, 0, 3), 0)
|
||||
else:
|
||||
raise ValueError('Invalid stride : {:}'.format(stride))
|
||||
|
||||
self.bn = nn.BatchNorm2d(C_out, affine=affine)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
if self.stride == 2:
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
|
||||
else:
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:-2,1:-2]),
|
||||
self.convs[2](y[:,:,2:-1,2:-1]), self.convs[3](y[:,:,3:,3:])], dim=1)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
def extra_repr(self):
|
||||
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
|
||||
@@ -1,36 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .starts import prepare_seed
|
||||
from .starts import prepare_logger
|
||||
from .starts import get_machine_info
|
||||
from .starts import save_checkpoint
|
||||
from .starts import copy_checkpoint
|
||||
from .optimizers import get_optim_scheduler
|
||||
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
|
||||
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
|
||||
from .funcs_nasbench import get_nas_bench_loaders
|
||||
|
||||
|
||||
def get_procedures(procedure):
|
||||
from .basic_main import basic_train, basic_valid
|
||||
from .search_main import search_train, search_valid
|
||||
from .search_main_v2 import search_train_v2
|
||||
from .simple_KD_main import simple_KD_train, simple_KD_valid
|
||||
|
||||
train_funcs = {
|
||||
"basic": basic_train,
|
||||
"search": search_train,
|
||||
"Simple-KD": simple_KD_train,
|
||||
"search-v2": search_train_v2,
|
||||
}
|
||||
valid_funcs = {
|
||||
"basic": basic_valid,
|
||||
"search": search_valid,
|
||||
"Simple-KD": simple_KD_valid,
|
||||
"search-v2": search_valid,
|
||||
}
|
||||
|
||||
train_func = train_funcs[procedure]
|
||||
valid_func = valid_funcs[procedure]
|
||||
return train_func, valid_func
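# Usage sketch (added comment): `train_func, valid_func = get_procedures('basic')` selects the plain
# classification training/validation loops, while 'search', 'search-v2', and 'Simple-KD' select the
# corresponding NAS-search or knowledge-distillation procedures imported above.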
|
||||
@@ -1,100 +0,0 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
|
||||
#####################################################
|
||||
# To be finished.
|
||||
#
|
||||
import os, sys, time, torch
|
||||
from typing import Optional, Text, Callable
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
from log_utils import time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def get_device(tensors):
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return get_device(tensors[0])
|
||||
elif isinstance(tensors, dict):
|
||||
for key, value in tensors.items():
|
||||
return get_device(value)
|
||||
else:
|
||||
return tensors.device
|
||||
|
||||
|
||||
def basic_train_fn(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
optimizer,
|
||||
metric,
|
||||
logger,
|
||||
):
|
||||
results = procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
optimizer,
|
||||
metric,
|
||||
"train",
|
||||
logger,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def basic_eval_fn(xloader, network, metric, logger):
|
||||
with torch.no_grad():
|
||||
results = procedure(
|
||||
xloader,
|
||||
network,
|
||||
None,
|
||||
None,
|
||||
metric,
|
||||
"valid",
|
||||
logger,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
optimizer,
|
||||
metric,
|
||||
mode: Text,
|
||||
logger_fn: Optional[Callable] = None,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
if mode.lower() == "train":
|
||||
network.train()
|
||||
elif mode.lower() == "valid":
|
||||
network.eval()
|
||||
else:
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
|
||||
if mode == "train":
|
||||
optimizer.zero_grad()
|
||||
|
||||
outputs = network(inputs)
|
||||
targets = targets.to(get_device(outputs))
|
||||
|
||||
if mode == "train":
|
||||
loss = criterion(outputs, targets)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
with torch.no_grad():
|
||||
results = metric(outputs, targets)
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return metric.get_info()
|
||||
@@ -1,155 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
from log_utils import time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def basic_train(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
"train",
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def basic_valid(
|
||||
xloader, network, criterion, optim_config, extra_info, print_freq, logger
|
||||
):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
None,
|
||||
None,
|
||||
"valid",
|
||||
None,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def procedure(
    xloader, network, criterion, scheduler, optimizer,
    mode, config, extra_info, print_freq, logger,
):
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(),
    )
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))

    # logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
    logger.log(
        "[{:5s}] config :: auxiliary={:}".format(
            mode, config.auxiliary if hasattr(config, "auxiliary") else -1
        )
    )
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        targets = targets.cuda(non_blocking=True)

        if mode == "train":
            optimizer.zero_grad()

        features, logits = network(inputs)
        if isinstance(logits, list):
            assert len(logits) == 2, "logits must have {:} items instead of {:}".format(
                2, len(logits)
            )
            logits, logits_aux = logits
        else:
            logits, logits_aux = logits, None
        loss = criterion(logits, targets)
        if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
            loss_aux = criterion(logits_aux, targets)
            loss += config.auxiliary * loss_aux

        if mode == "train":
            loss.backward()
            optimizer.step()

        # record
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 or (i + 1) == len(xloader):
            Sstr = (
                " {:5s} ".format(mode.upper())
                + time_string()
                + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
            )
            if scheduler is not None:
                Sstr += " {:}".format(scheduler.get_min_info())
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=losses, top1=top1, top5=top5
            )
            Istr = "Size={:}".format(list(inputs.size()))
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)

    logger.log(
        " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
            mode=mode.upper(), top1=top1, top5=top5,
            error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg,
        )
    )
    return losses.avg, top1.avg, top5.avg
@@ -1,20 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
import abc
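

# Top-k accuracy: take the k highest-scoring classes per sample and count a hit when
# the ground-truth label appears among them; each value is returned as a percentage.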
def obtain_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
@@ -1,438 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import os, time, copy, torch, pathlib

# modules in AutoDL
import datasets
from config_utils import load_config
from procedures import prepare_seed, get_optim_scheduler
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net
from utils import get_model_infos
from .eval_funcs import obtain_accuracy


__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]
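

# Inference-only evaluation that also records per-batch latency (batch time minus
# data-loading time) for batches of the same size; the first measurement is dropped
# as warm-up once enough batches have been seen.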
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies, device = [], torch.cuda.current_device()
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(device=device, non_blocking=True)
            inputs = inputs.cuda(device=device, non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2:
        latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies


def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    device = torch.cuda.current_device()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))

        targets = targets.cuda(device=device, non_blocking=True)
        if mode == "train":
            optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == "train":
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum
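

# Train a single sampled architecture from scratch with the given seed, evaluating it
# on every loader in valid_loaders after each epoch, and return a dictionary with the
# per-epoch losses/accuracies, timings, learning rates, and the final network weights.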
def evaluate_for_seed(
    arch_config, opt_config, train_loader, valid_loaders, seed: int, logger
):
    prepare_seed(seed)  # random seed
    net = get_cell_based_tiny_net(arch_config)
    # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
    flop, param = get_model_infos(net, opt_config.xshape)
    logger.log("Network : {:}".format(net.get_message()), False)
    logger.log(
        "{:} Seed-------------------------- {:} --------------------------".format(
            time_string(), seed
        )
    )
    logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
    default_device = torch.cuda.current_device()
    network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(
        device=default_device
    )
    criterion = criterion.cuda(device=default_device)
    # start training
    start_time, epoch_time, total_epoch = (
        time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup,
    )
    (
        train_losses, train_acc1es, train_acc5es,
        valid_losses, valid_acc1es, valid_acc5es,
    ) = ({}, {}, {}, {}, {}, {})
    train_times, valid_times, lrs = {}, {}, {}
    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)
        lr = min(scheduler.get_lr())
        train_loss, train_acc1, train_acc5, train_tm = procedure(
            train_loader, network, criterion, scheduler, optimizer, "train"
        )
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm
        lrs[epoch] = lr
        with torch.no_grad():
            for key, xloader in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
                    xloader, network, criterion, None, None, "valid"
                )
                valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
                valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
                valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
                valid_times["{:}@{:}".format(key, epoch)] = valid_tm

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
        )
        logger.log(
            "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format(
                time_string(), need_time, epoch, total_epoch,
                train_loss, train_acc1, train_acc5,
                valid_loss, valid_acc1, valid_acc5, lr,
            )
        )
    info_seed = {
        "flop": flop,
        "param": param,
        "arch_config": arch_config._asdict(),
        "opt_config": opt_config._asdict(),
        "total_epoch": total_epoch,
        "train_losses": train_losses,
        "train_acc1es": train_acc1es,
        "train_acc5es": train_acc5es,
        "train_times": train_times,
        "valid_losses": valid_losses,
        "valid_acc1es": valid_acc1es,
        "valid_acc5es": valid_acc5es,
        "valid_times": valid_times,
        "learning_rates": lrs,
        "net_state_dict": net.state_dict(),
        "net_string": "{:}".format(net),
        "finish-train": True,
    }
    return info_seed
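

# Build the CIFAR-10 / CIFAR-100 / ImageNet-16-120 loaders used by the NAS benchmark.
# The returned dictionary is keyed as "<dataset>@<split>", e.g. "cifar10@train".
# The hard-coded asserts below are sanity checks that the on-disk split files match
# the splits the benchmark was generated with.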
def get_nas_bench_loaders(workers):
    torch.set_num_threads(workers)

    root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve()
    torch_dir = pathlib.Path(os.environ["TORCH_HOME"])
    # cifar
    cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config"
    cifar_config = load_config(cifar_config_path, None, None)
    get_datasets = datasets.get_datasets  # a function to return the dataset
    break_line = "-" * 150
    print("{:} Create data-loader for all datasets".format(time_string()))
    print(break_line)
    TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets(
        "cifar10", str(torch_dir / "cifar.python"), -1
    )
    print(
        "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
            len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num
        )
    )
    cifar10_splits = load_config(
        root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None
    )
    assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24]
    assert cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14]
    temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
    temp_dataset.transform = VALID_CIFAR10.transform
    # data loader
    trainval_cifar10_loader = torch.utils.data.DataLoader(
        TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True,
        num_workers=workers, pin_memory=True,
    )
    train_cifar10_loader = torch.utils.data.DataLoader(
        TRAIN_CIFAR10, batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train),
        num_workers=workers, pin_memory=True,
    )
    valid_cifar10_loader = torch.utils.data.DataLoader(
        temp_dataset, batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid),
        num_workers=workers, pin_memory=True,
    )
    test__cifar10_loader = torch.utils.data.DataLoader(
        VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False,
        num_workers=workers, pin_memory=True,
    )
    print("CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format(
        len(trainval_cifar10_loader), cifar_config.batch_size))
    print("CIFAR-10 : train-loader has {:3d} batch with {:} per batch".format(
        len(train_cifar10_loader), cifar_config.batch_size))
    print("CIFAR-10 : valid-loader has {:3d} batch with {:} per batch".format(
        len(valid_cifar10_loader), cifar_config.batch_size))
    print("CIFAR-10 : test--loader has {:3d} batch with {:} per batch".format(
        len(test__cifar10_loader), cifar_config.batch_size))
    print(break_line)
    # CIFAR-100
    TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets(
        "cifar100", str(torch_dir / "cifar.python"), -1
    )
    print(
        "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
            len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num
        )
    )
    cifar100_splits = load_config(
        root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None
    )
    assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16]
    assert cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24]
    train_cifar100_loader = torch.utils.data.DataLoader(
        TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True,
        num_workers=workers, pin_memory=True,
    )
    valid_cifar100_loader = torch.utils.data.DataLoader(
        VALID_CIFAR100, batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
        num_workers=workers, pin_memory=True,
    )
    test__cifar100_loader = torch.utils.data.DataLoader(
        VALID_CIFAR100, batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
        num_workers=workers, pin_memory=True,
    )
    print("CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader)))
    print("CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)))
    print("CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader)))
    print(break_line)

    imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config"
    imagenet16_config = load_config(imagenet16_config_path, None, None)
    TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets(
        "ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1
    )
    print(
        "original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
            len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num
        )
    )
    imagenet_splits = load_config(
        root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt",
        None, None,
    )
    assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18]
    assert imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20]
    train_imagenet_loader = torch.utils.data.DataLoader(
        TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True,
        num_workers=workers, pin_memory=True,
    )
    valid_imagenet_loader = torch.utils.data.DataLoader(
        VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid),
        num_workers=workers, pin_memory=True,
    )
    test__imagenet_loader = torch.utils.data.DataLoader(
        VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest),
        num_workers=workers, pin_memory=True,
    )
    print("ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch".format(
        len(train_imagenet_loader), imagenet16_config.batch_size))
    print("ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch".format(
        len(valid_imagenet_loader), imagenet16_config.batch_size))
    print("ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch".format(
        len(test__imagenet_loader), imagenet16_config.batch_size))

    # 'cifar10', 'cifar100', 'ImageNet16-120'
    loaders = {
        "cifar10@trainval": trainval_cifar10_loader,
        "cifar10@train": train_cifar10_loader,
        "cifar10@valid": valid_cifar10_loader,
        "cifar10@test": test__cifar10_loader,
        "cifar100@train": train_cifar100_loader,
        "cifar100@valid": valid_cifar100_loader,
        "cifar100@test": test__cifar100_loader,
        "ImageNet16-120@train": train_imagenet_loader,
        "ImageNet16-120@valid": valid_imagenet_loader,
        "ImageNet16-120@test": test__imagenet_loader,
    }
    return loaders
@@ -1,134 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
import abc
import numpy as np
import torch


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{name}(val={val}, avg={avg}, count={count})".format(
            name=self.__class__.__name__, **self.__dict__
        )
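

# Minimal metric protocol: subclasses implement reset(), __call__(predictions, targets),
# and get_info(), which returns a dict of aggregated values.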
class Metric(abc.ABC):
    """The default meta metric class."""

    def __init__(self):
        self.reset()

    def reset(self):
        raise NotImplementedError

    def __call__(self, predictions, targets):
        raise NotImplementedError

    def get_info(self):
        raise NotImplementedError

    def __repr__(self):
        return "{name}({inner})".format(
            name=self.__class__.__name__, inner=self.inner_repr()
        )

    def inner_repr(self):
        return ""


class ComposeMetric(Metric):
    """The composed metric class."""

    def __init__(self, *metric_list):
        self.reset()
        for metric in metric_list:
            self.append(metric)

    def reset(self):
        self._metric_list = []

    def append(self, metric):
        if not isinstance(metric, Metric):
            raise ValueError(
                "The input metric is not correct: {:}".format(type(metric))
            )
        self._metric_list.append(metric)

    def __len__(self):
        return len(self._metric_list)

    def __call__(self, predictions, targets):
        results = list()
        for metric in self._metric_list:
            results.append(metric(predictions, targets))
        return results

    def get_info(self):
        results = dict()
        for metric in self._metric_list:
            for key, value in metric.get_info().items():
                results[key] = value
        return results

    def inner_repr(self):
        xlist = []
        for metric in self._metric_list:
            xlist.append(str(metric))
        return ",".join(xlist)


class MSEMetric(Metric):
    """The metric for mse."""

    def reset(self):
        self._mse = AverageMeter()

    def __call__(self, predictions, targets):
        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
            batch = predictions.shape[0]
            loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
            loss = loss.item()
            self._mse.update(loss, batch)
            return loss
        else:
            raise NotImplementedError

    def get_info(self):
        return {"mse": self._mse.avg}


class SaveMetric(Metric):
    """The metric that caches the raw predictions."""

    def reset(self):
        self._predicts = []

    def __call__(self, predictions, targets=None):
        if isinstance(predictions, torch.Tensor):
            predicts = predictions.cpu().numpy()
            self._predicts.append(predicts)
            return predicts
        else:
            raise NotImplementedError

    def get_info(self):
        all_predicts = np.concatenate(self._predicts)
        return {"predictions": all_predicts}
@@ -1,263 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer


class _LRScheduler(object):
    def __init__(self, optimizer, warmup_epochs, epochs):
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        self.base_lrs = list(
            map(lambda group: group["initial_lr"], optimizer.param_groups)
        )
        self.max_epochs = epochs
        self.warmup_epochs = warmup_epochs
        self.current_epoch = 0
        self.current_iter = 0

    def extra_repr(self):
        return ""

    def __repr__(self):
        return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
            name=self.__class__.__name__, **self.__dict__
        ) + ", {:})".format(self.extra_repr())

    def state_dict(self):
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        raise NotImplementedError

    def get_min_info(self):
        lrs = self.get_lr()
        return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
            min(lrs), max(lrs), self.current_epoch, self.current_iter
        )

    def get_min_lr(self):
        return min(self.get_lr())

    def update(self, cur_epoch, cur_iter):
        if cur_epoch is not None:
            assert (
                isinstance(cur_epoch, int) and cur_epoch >= 0
            ), "invalid cur-epoch : {:}".format(cur_epoch)
            self.current_epoch = cur_epoch
        if cur_iter is not None:
            assert (
                isinstance(cur_iter, float) and cur_iter >= 0
            ), "invalid cur-iter : {:}".format(cur_iter)
            self.current_iter = cur_iter
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr
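

# All schedulers below share the warm-up convention of _LRScheduler: while
# current_epoch < warmup_epochs, the learning rate ramps up linearly as
# (current_epoch + current_iter) / warmup_epochs * base_lr, and only afterwards does
# the scheduler-specific decay rule (cosine, multi-step, exponential, or linear) apply.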
class CosineAnnealingLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, T-max={:}, eta-min={:}".format(
            "cosine", self.T_max, self.eta_min
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if (
                self.current_epoch >= self.warmup_epochs
                and self.current_epoch < self.max_epochs
            ):
                last_epoch = self.current_epoch - self.warmup_epochs
                # if last_epoch < self.T_max:
                # if last_epoch < self.max_epochs:
                lr = (
                    self.eta_min
                    + (base_lr - self.eta_min)
                    * (1 + math.cos(math.pi * last_epoch / self.T_max))
                    / 2
                )
                # else:
                #     lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
            elif self.current_epoch >= self.max_epochs:
                lr = self.eta_min
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs


class MultiStepLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
        assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
            len(milestones), len(gammas)
        )
        self.milestones = milestones
        self.gammas = gammas
        super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
            "multistep", self.milestones, self.gammas, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                idx = bisect_right(self.milestones, last_epoch)
                lr = base_lr
                for x in self.gammas[:idx]:
                    lr *= x
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs


class ExponentialLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, gamma):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, gamma={:}, base-lrs={:}".format(
            "exponential", self.gamma, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                lr = base_lr * (self.gamma ** last_epoch)
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs


class LinearLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
        self.max_LR = max_LR
        self.min_LR = min_LR
        super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
            "LinearLR", self.max_LR, self.min_LR, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                ratio = (
                    (self.max_LR - self.min_LR)
                    * last_epoch
                    / self.max_epochs
                    / self.max_LR
                )
                lr = base_lr * (1 - ratio)
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs
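

# Cross-entropy with label smoothing: the one-hot target is replaced by
# (1 - epsilon) * one_hot + epsilon / num_classes before taking the expected
# negative log-likelihood.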
class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss
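

# Build the optimizer ("SGD" or "RMSprop"), learning-rate scheduler ("cos", "multistep",
# "exponential", or "linear"), and criterion ("Softmax" or "SmoothSoftmax") described by
# the fields of the given config.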
def get_optim_scheduler(parameters, config):
    assert (
        hasattr(config, "optim")
        and hasattr(config, "scheduler")
        and hasattr(config, "criterion")
    ), "config must have optim / scheduler / criterion keys instead of {:}".format(config)
    if config.optim == "SGD":
        optim = torch.optim.SGD(
            parameters, config.LR, momentum=config.momentum,
            weight_decay=config.decay, nesterov=config.nesterov,
        )
    elif config.optim == "RMSprop":
        optim = torch.optim.RMSprop(
            parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
        )
    else:
        raise ValueError("invalid optim : {:}".format(config.optim))

    if config.scheduler == "cos":
        T_max = getattr(config, "T_max", config.epochs)
        scheduler = CosineAnnealingLR(
            optim, config.warmup, config.epochs, T_max, config.eta_min
        )
    elif config.scheduler == "multistep":
        scheduler = MultiStepLR(
            optim, config.warmup, config.epochs, config.milestones, config.gammas
        )
    elif config.scheduler == "exponential":
        scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
    elif config.scheduler == "linear":
        scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
    else:
        raise ValueError("invalid scheduler : {:}".format(config.scheduler))

    if config.criterion == "Softmax":
        criterion = torch.nn.CrossEntropyLoss()
    elif config.criterion == "SmoothSoftmax":
        criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
    else:
        raise ValueError("invalid criterion : {:}".format(config.criterion))
    return optim, scheduler, criterion
Some files were not shown because too many files have changed in this diff.