diff --git a/exps/LFNA/basic-his.py b/exps/LFNA/basic-his.py index 216ea66..3352687 100644 --- a/exps/LFNA/basic-his.py +++ b/exps/LFNA/basic-his.py @@ -23,6 +23,9 @@ from datasets.synthetic_core import get_synthetic_env from models.xcore import get_model +from lfna_utils import lfna_setup + + def subsample(historical_x, historical_y, maxn=10000): total = historical_x.size(0) if total <= maxn: @@ -33,24 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000): def main(args): - prepare_seed(args.rand_seed) - logger = prepare_logger(args) - - cache_path = ( - logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) - ).resolve() - if cache_path.exists(): - env_info = torch.load(cache_path) - else: - env_info = dict() - dynamic_env = get_synthetic_env(version=args.env_version) - env_info["total"] = len(dynamic_env) - for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): - env_info["{:}-timestamp".format(idx)] = timestamp - env_info["{:}-x".format(idx)] = _allx - env_info["{:}-y".format(idx)] = _ally - env_info["dynamic_env"] = dynamic_env - torch.save(env_info, cache_path) + logger, env_info = lfna_setup(args) # check indexes to be evaluated to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) @@ -60,6 +46,8 @@ def main(args): ) ) + w_container_per_epoch = dict() + per_timestamp_time, start_time = AverageMeter(), time.time() for i, idx in enumerate(to_evaluate_indexes): @@ -89,9 +77,6 @@ def main(args): output_dim=1, act_cls="leaky_relu", norm_cls="identity", - # norm_cls="simple_norm", - # mean=mean, - # std=std, ) model = get_model(dict(model_type="simple_mlp"), **model_kwargs) # build optimizer @@ -144,6 +129,7 @@ def main(args): save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( idx, env_info["total"] ) + w_container_per_epoch[idx] = model.get_w_container().no_grad_clone() save_checkpoint( { "model_state_dict": model.state_dict(), @@ -155,10 +141,14 @@ def main(args): logger, ) logger.log("") - per_timestamp_time.update(time.time() - start_time) start_time = time.time() + save_checkpoint( + {"w_container_per_epoch": w_container_per_epoch}, + logger.path(None) / "final-ckp.pth", + logger, + ) logger.log("-" * 200 + "\n") logger.close() @@ -210,5 +200,7 @@ if __name__ == "__main__": if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) assert args.save_dir is not None, "The save dir argument can not be None" - args.save_dir = "{:}-{:}".format(args.save_dir, args.env_version) + args.save_dir = "{:}-{:}-d{:}".format( + args.save_dir, args.env_version, args.hidden_dim + ) main(args) diff --git a/exps/LFNA/basic-maml.py b/exps/LFNA/basic-maml.py new file mode 100644 index 0000000..970800c --- /dev/null +++ b/exps/LFNA/basic-maml.py @@ -0,0 +1,220 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # +##################################################### +# python exps/LFNA/basic-maml.py --env_version v1 # +# python exps/LFNA/basic-maml.py --env_version v2 # +##################################################### +import sys, time, copy, torch, random, argparse +from tqdm import tqdm +from copy import deepcopy +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from log_utils import time_string +from log_utils import AverageMeter, convert_secs2time + +from utils import split_str2indexes + +from procedures.advanced_main import basic_train_fn, basic_eval_fn +from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric +from datasets.synthetic_core import get_synthetic_env +from models.xcore import get_model +from xlayers import super_core + +from lfna_utils import lfna_setup, TimeData + + +class MAML: + """A LFNA meta-model that uses the MLP as delta-net.""" + + def __init__(self, container, criterion, meta_lr, inner_lr=0.01, inner_step=1): + self.criterion = criterion + self.container = container + self.meta_optimizer = torch.optim.Adam( + self.container.parameters(), lr=meta_lr, amsgrad=True + ) + self.inner_lr = inner_lr + self.inner_step = inner_step + + def adapt(self, model, dataset): + # create a container for the future timestamp + y_hat = model.forward_with_container(dataset.x, self.container) + loss = self.criterion(y_hat, dataset.y) + grads = torch.autograd.grad(loss, self.container.parameters()) + + fast_container = self.container.additive( + [-self.inner_lr * grad for grad in grads] + ) + import pdb + + pdb.set_trace() + w_container.requires_grad_(True) + containers = [w_container] + for idx, dataset in enumerate(seq_datasets): + x, y = dataset.x, dataset.y + y_hat = model.forward_with_container(x, containers[-1]) + loss = criterion(y_hat, y) + gradients = torch.autograd.grad(loss, containers[-1].tensors) + with torch.no_grad(): + flatten_w = containers[-1].flatten().view(-1, 1) + flatten_g = containers[-1].flatten(gradients).view(-1, 1) + input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2) + input_statistics = input_statistics.expand(flatten_w.numel(), -1) + delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) + delta = self.delta_net(delta_inputs).view(-1) + delta = torch.clamp(delta, -0.5, 0.5) + unflatten_delta = containers[-1].unflatten(delta) + future_container = containers[-1].no_grad_clone().additive(unflatten_delta) + # future_container = containers[-1].additive(unflatten_delta) + containers.append(future_container) + # containers = containers[1:] + meta_loss = [] + temp_containers = [] + for idx, dataset in enumerate(seq_datasets): + if idx == 0: + continue + current_container = containers[idx] + y_hat = model.forward_with_container(dataset.x, current_container) + loss = criterion(y_hat, dataset.y) + meta_loss.append(loss) + temp_containers.append((dataset.timestamp, current_container, -loss.item())) + meta_loss = sum(meta_loss) + w_container.requires_grad_(False) + # meta_loss.backward() + # self.meta_optimizer.step() + return meta_loss, temp_containers + + def step(self): + torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) + self.meta_optimizer.step() + + def zero_grad(self): + self.meta_optimizer.zero_grad() + + +def main(args): + logger, env_info = lfna_setup(args) + + total_time = env_info["total"] + for i in range(total_time): + for xkey in ("timestamp", "x", "y"): + nkey = "{:}-{:}".format(i, xkey) + assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys())) + train_time_bar = total_time // 2 + base_model = get_model( + dict(model_type="simple_mlp"), + act_cls="leaky_relu", + norm_cls="identity", + input_dim=1, + output_dim=1, + ) + + w_container = base_model.get_w_container() + criterion = torch.nn.MSELoss() + print("There are {:} weights.".format(w_container.numel())) + + maml = MAML(w_container, criterion, args.meta_lr, args.inner_lr, args.inner_step) + + # meta-training + per_epoch_time, start_time = AverageMeter(), time.time() + for iepoch in range(args.epochs): + + need_time = "Time Left: {:}".format( + convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) + ) + logger.log( + "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) + + need_time + ) + + maml.zero_grad() + + all_meta_losses = [] + for ibatch in range(args.meta_batch): + sampled_timestamp = random.randint(0, train_time_bar) + past_dataset = TimeData( + sampled_timestamp, + env_info["{:}-x".format(sampled_timestamp)], + env_info["{:}-y".format(sampled_timestamp)], + ) + future_dataset = TimeData( + sampled_timestamp + 1, + env_info["{:}-x".format(sampled_timestamp + 1)], + env_info["{:}-y".format(sampled_timestamp + 1)], + ) + maml.adapt(base_model, past_dataset) + import pdb + + pdb.set_trace() + + meta_loss = torch.stack(all_meta_losses).mean() + meta_loss.backward() + adaptor.step() + + debug_str = pool.debug_info(debug_timestamp) + logger.log("meta-loss: {:.4f}".format(meta_loss.item())) + + per_epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("-" * 200 + "\n") + logger.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Use the data in the past.") + parser.add_argument( + "--save_dir", + type=str, + default="./outputs/lfna-synthetic/maml", + help="The checkpoint directory.", + ) + parser.add_argument( + "--env_version", + type=str, + required=True, + help="The synthetic enviornment version.", + ) + parser.add_argument( + "--meta_lr", + type=float, + default=0.01, + help="The learning rate for the MAML optimizer (default is Adam)", + ) + parser.add_argument( + "--inner_lr", + type=float, + default=0.01, + help="The learning rate for the inner optimization", + ) + parser.add_argument( + "--inner_step", type=int, default=1, help="The inner loop steps for MAML." + ) + parser.add_argument( + "--meta_batch", + type=int, + default=5, + help="The batch size for the meta-model", + ) + parser.add_argument( + "--epochs", + type=int, + default=1000, + help="The total number of epochs.", + ) + parser.add_argument( + "--workers", + type=int, + default=4, + help="The number of data loading workers (default: 4)", + ) + # Random Seed + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "The save dir argument can not be None" + main(args) diff --git a/exps/LFNA/basic-same.py b/exps/LFNA/basic-same.py index a1bb87b..f70265d 100644 --- a/exps/LFNA/basic-same.py +++ b/exps/LFNA/basic-same.py @@ -1,7 +1,8 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/LFNA/basic-same.py --srange 1-999 +# python exps/LFNA/basic-same.py --srange 1-999 --env_version v1 --hidden_dim 16 +# python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -22,6 +23,8 @@ from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric from datasets.synthetic_core import get_synthetic_env from models.xcore import get_model +from lfna_utils import lfna_setup + def subsample(historical_x, historical_y, maxn=10000): total = historical_x.size(0) @@ -33,22 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000): def main(args): - prepare_seed(args.rand_seed) - logger = prepare_logger(args) - - cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() - if cache_path.exists(): - env_info = torch.load(cache_path) - else: - env_info = dict() - dynamic_env = get_synthetic_env() - env_info["total"] = len(dynamic_env) - for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): - env_info["{:}-timestamp".format(idx)] = timestamp - env_info["{:}-x".format(idx)] = _allx - env_info["{:}-y".format(idx)] = _ally - env_info["dynamic_env"] = dynamic_env - torch.save(env_info, cache_path) + logger, env_info, model_kwargs = lfna_setup(args) # check indexes to be evaluated to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) @@ -78,16 +66,6 @@ def main(args): historical_x = env_info["{:}-x".format(idx)] historical_y = env_info["{:}-y".format(idx)] # build model - mean, std = historical_x.mean().item(), historical_x.std().item() - model_kwargs = dict( - input_dim=1, - output_dim=1, - act_cls="leaky_relu", - norm_cls="identity", - # norm_cls="simple_norm", - # mean=mean, - # std=std, - ) model = get_model(dict(model_type="simple_mlp"), **model_kwargs) # build optimizer optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) @@ -151,9 +129,9 @@ def main(args): logger, ) logger.log("") - per_timestamp_time.update(time.time() - start_time) start_time = time.time() + save_checkpoint( {"w_container_per_epoch": w_container_per_epoch}, logger.path(None) / "final-ckp.pth", @@ -172,6 +150,18 @@ if __name__ == "__main__": default="./outputs/lfna-synthetic/use-same-timestamp", help="The checkpoint directory.", ) + parser.add_argument( + "--env_version", + type=str, + required=True, + help="The synthetic enviornment version.", + ) + parser.add_argument( + "--hidden_dim", + type=int, + required=True, + help="The hidden dimension.", + ) parser.add_argument( "--init_lr", type=float, @@ -205,4 +195,7 @@ if __name__ == "__main__": if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) assert args.save_dir is not None, "The save dir argument can not be None" + args.save_dir = "{:}-{:}-d{:}".format( + args.save_dir, args.env_version, args.hidden_dim + ) main(args) diff --git a/exps/LFNA/lfna-v0.py b/exps/LFNA/lfna-v0.py new file mode 100644 index 0000000..e3f937b --- /dev/null +++ b/exps/LFNA/lfna-v0.py @@ -0,0 +1,272 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # +##################################################### +# python exps/LFNA/lfna-v0.py +##################################################### +import sys, time, copy, torch, random, argparse +from tqdm import tqdm +from copy import deepcopy +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) +from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint +from log_utils import time_string +from log_utils import AverageMeter, convert_secs2time + +from utils import split_str2indexes + +from procedures.advanced_main import basic_train_fn, basic_eval_fn +from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric +from datasets.synthetic_core import get_synthetic_env +from models.xcore import get_model +from xlayers import super_core + + +class LFNAmlp: + """A LFNA meta-model that uses the MLP as delta-net.""" + + def __init__(self, obs_dim, hidden_sizes, act_name): + self.delta_net = super_core.SuperSequential( + super_core.SuperLinear(obs_dim, hidden_sizes[0]), + super_core.super_name2activation[act_name](), + super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]), + super_core.super_name2activation[act_name](), + super_core.SuperLinear(hidden_sizes[1], 1), + ) + self.meta_optimizer = torch.optim.Adam( + self.delta_net.parameters(), lr=0.01, amsgrad=True + ) + + def adapt(self, model, criterion, w_container, seq_datasets): + w_container.requires_grad_(True) + containers = [w_container] + for idx, dataset in enumerate(seq_datasets): + x, y = dataset.x, dataset.y + y_hat = model.forward_with_container(x, containers[-1]) + loss = criterion(y_hat, y) + gradients = torch.autograd.grad(loss, containers[-1].tensors) + with torch.no_grad(): + flatten_w = containers[-1].flatten().view(-1, 1) + flatten_g = containers[-1].flatten(gradients).view(-1, 1) + input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2) + input_statistics = input_statistics.expand(flatten_w.numel(), -1) + delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) + delta = self.delta_net(delta_inputs).view(-1) + delta = torch.clamp(delta, -0.5, 0.5) + unflatten_delta = containers[-1].unflatten(delta) + future_container = containers[-1].no_grad_clone().additive(unflatten_delta) + # future_container = containers[-1].additive(unflatten_delta) + containers.append(future_container) + # containers = containers[1:] + meta_loss = [] + temp_containers = [] + for idx, dataset in enumerate(seq_datasets): + if idx == 0: + continue + current_container = containers[idx] + y_hat = model.forward_with_container(dataset.x, current_container) + loss = criterion(y_hat, dataset.y) + meta_loss.append(loss) + temp_containers.append((dataset.timestamp, current_container, -loss.item())) + meta_loss = sum(meta_loss) + w_container.requires_grad_(False) + # meta_loss.backward() + # self.meta_optimizer.step() + return meta_loss, temp_containers + + def step(self): + torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) + self.meta_optimizer.step() + + def zero_grad(self): + self.meta_optimizer.zero_grad() + self.delta_net.zero_grad() + + +class TimeData: + def __init__(self, timestamp, xs, ys): + self._timestamp = timestamp + self._xs = xs + self._ys = ys + + @property + def x(self): + return self._xs + + @property + def y(self): + return self._ys + + @property + def timestamp(self): + return self._timestamp + + +class Population: + """A population used to maintain models at different timestamps.""" + + def __init__(self): + self._time2model = dict() + self._time2score = dict() # higher is better + + def append(self, timestamp, model, score): + if timestamp in self._time2model: + if self._time2score[timestamp] > score: + return + self._time2model[timestamp] = model.no_grad_clone() + self._time2score[timestamp] = score + + def query(self, timestamp): + closet_timestamp = None + for xtime, model in self._time2model.items(): + if closet_timestamp is None or ( + xtime < timestamp and timestamp - closet_timestamp >= timestamp - xtime + ): + closet_timestamp = xtime + return self._time2model[closet_timestamp], closet_timestamp + + def debug_info(self, timestamps): + xstrs = [] + for timestamp in timestamps: + if timestamp in self._time2score: + xstrs.append( + "{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp]) + ) + return ", ".join(xstrs) + + +def main(args): + prepare_seed(args.rand_seed) + logger = prepare_logger(args) + + cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() + if cache_path.exists(): + env_info = torch.load(cache_path) + else: + env_info = dict() + dynamic_env = get_synthetic_env() + env_info["total"] = len(dynamic_env) + for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): + env_info["{:}-timestamp".format(idx)] = timestamp + env_info["{:}-x".format(idx)] = _allx + env_info["{:}-y".format(idx)] = _ally + env_info["dynamic_env"] = dynamic_env + torch.save(env_info, cache_path) + + total_time = env_info["total"] + for i in range(total_time): + for xkey in ("timestamp", "x", "y"): + nkey = "{:}-{:}".format(i, xkey) + assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys())) + train_time_bar = total_time // 2 + base_model = get_model( + dict(model_type="simple_mlp"), + act_cls="leaky_relu", + norm_cls="identity", + input_dim=1, + output_dim=1, + ) + + w_container = base_model.get_w_container() + criterion = torch.nn.MSELoss() + print("There are {:} weights.".format(w_container.numel())) + + adaptor = LFNAmlp(4, (50, 20), "leaky_relu") + + pool = Population() + pool.append(0, w_container, -100) + + # LFNA meta-training + per_epoch_time, start_time = AverageMeter(), time.time() + for iepoch in range(args.epochs): + + need_time = "Time Left: {:}".format( + convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) + ) + logger.log( + "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) + + need_time + ) + + adaptor.zero_grad() + + debug_timestamp = set() + all_meta_losses = [] + for ibatch in range(args.meta_batch): + sampled_timestamp = random.randint(0, train_time_bar) + query_w_container, query_timestamp = pool.query(sampled_timestamp) + # def adapt(self, model, w_container, xs, ys): + seq_datasets = [] + # xs, ys = [], [] + for it in range(sampled_timestamp, sampled_timestamp + args.max_seq): + xs = env_info["{:}-x".format(it)] + ys = env_info["{:}-y".format(it)] + seq_datasets.append(TimeData(it, xs, ys)) + temp_meta_loss, temp_containers = adaptor.adapt( + base_model, criterion, query_w_container, seq_datasets + ) + all_meta_losses.append(temp_meta_loss) + for temp_time, temp_container, temp_score in temp_containers: + pool.append(temp_time, temp_container, temp_score) + debug_timestamp.add(temp_time) + meta_loss = torch.stack(all_meta_losses).mean() + meta_loss.backward() + adaptor.step() + + debug_str = pool.debug_info(debug_timestamp) + logger.log("meta-loss: {:.4f}".format(meta_loss.item())) + + per_epoch_time.update(time.time() - start_time) + start_time = time.time() + + logger.log("-" * 200 + "\n") + logger.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Use the data in the past.") + parser.add_argument( + "--save_dir", + type=str, + default="./outputs/lfna-synthetic/lfna-v1", + help="The checkpoint directory.", + ) + parser.add_argument( + "--init_lr", + type=float, + default=0.1, + help="The initial learning rate for the optimizer (default is Adam)", + ) + parser.add_argument( + "--meta_batch", + type=int, + default=5, + help="The batch size for the meta-model", + ) + parser.add_argument( + "--epochs", + type=int, + default=1000, + help="The total number of epochs.", + ) + parser.add_argument( + "--max_seq", + type=int, + default=5, + help="The maximum length of the sequence.", + ) + parser.add_argument( + "--workers", + type=int, + default=4, + help="The number of data loading workers (default: 4)", + ) + # Random Seed + parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") + args = parser.parse_args() + if args.rand_seed is None or args.rand_seed < 0: + args.rand_seed = random.randint(1, 100000) + assert args.save_dir is not None, "The save dir argument can not be None" + main(args) diff --git a/exps/LFNA/lfna_utils.py b/exps/LFNA/lfna_utils.py new file mode 100644 index 0000000..a46854c --- /dev/null +++ b/exps/LFNA/lfna_utils.py @@ -0,0 +1,61 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # +##################################################### +import torch +from tqdm import tqdm +from procedures import prepare_seed, prepare_logger +from datasets.synthetic_core import get_synthetic_env + + +def lfna_setup(args): + prepare_seed(args.rand_seed) + logger = prepare_logger(args) + + cache_path = ( + logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) + ).resolve() + if cache_path.exists(): + env_info = torch.load(cache_path) + else: + env_info = dict() + dynamic_env = get_synthetic_env(version=args.env_version) + env_info["total"] = len(dynamic_env) + for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): + env_info["{:}-timestamp".format(idx)] = timestamp + env_info["{:}-x".format(idx)] = _allx + env_info["{:}-y".format(idx)] = _ally + env_info["dynamic_env"] = dynamic_env + torch.save(env_info, cache_path) + + model_kwargs = dict( + input_dim=1, + output_dim=1, + hidden_dim=args.hidden_dim, + act_cls="leaky_relu", + norm_cls="identity", + ) + return logger, env_info, model_kwargs + + +class TimeData: + def __init__(self, timestamp, xs, ys): + self._timestamp = timestamp + self._xs = xs + self._ys = ys + + @property + def x(self): + return self._xs + + @property + def y(self): + return self._ys + + @property + def timestamp(self): + return self._timestamp + + def __repr__(self): + return "{name}(timestamp={:}, with {num} samples)".format( + name=self.__class__.__name__, timestamp=self._timestamp, num=len(self._xs) + ) diff --git a/lib/models/CifarDenseNet.py b/lib/models/CifarDenseNet.py index 1d5dd5b..eaf8e98 100644 --- a/lib/models/CifarDenseNet.py +++ b/lib/models/CifarDenseNet.py @@ -8,98 +8,110 @@ from .initialization import initialize_resnet class Bottleneck(nn.Module): - def __init__(self, nChannels, growthRate): - super(Bottleneck, self).__init__() - interChannels = 4*growthRate - self.bn1 = nn.BatchNorm2d(nChannels) - self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) - self.bn2 = nn.BatchNorm2d(interChannels) - self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False) + def __init__(self, nChannels, growthRate): + super(Bottleneck, self).__init__() + interChannels = 4 * growthRate + self.bn1 = nn.BatchNorm2d(nChannels) + self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) + self.bn2 = nn.BatchNorm2d(interChannels) + self.conv2 = nn.Conv2d( + interChannels, growthRate, kernel_size=3, padding=1, bias=False + ) - def forward(self, x): - out = self.conv1(F.relu(self.bn1(x))) - out = self.conv2(F.relu(self.bn2(out))) - out = torch.cat((x, out), 1) - return out + def forward(self, x): + out = self.conv1(F.relu(self.bn1(x))) + out = self.conv2(F.relu(self.bn2(out))) + out = torch.cat((x, out), 1) + return out class SingleLayer(nn.Module): - def __init__(self, nChannels, growthRate): - super(SingleLayer, self).__init__() - self.bn1 = nn.BatchNorm2d(nChannels) - self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False) + def __init__(self, nChannels, growthRate): + super(SingleLayer, self).__init__() + self.bn1 = nn.BatchNorm2d(nChannels) + self.conv1 = nn.Conv2d( + nChannels, growthRate, kernel_size=3, padding=1, bias=False + ) - def forward(self, x): - out = self.conv1(F.relu(self.bn1(x))) - out = torch.cat((x, out), 1) - return out + def forward(self, x): + out = self.conv1(F.relu(self.bn1(x))) + out = torch.cat((x, out), 1) + return out class Transition(nn.Module): - def __init__(self, nChannels, nOutChannels): - super(Transition, self).__init__() - self.bn1 = nn.BatchNorm2d(nChannels) - self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) + def __init__(self, nChannels, nOutChannels): + super(Transition, self).__init__() + self.bn1 = nn.BatchNorm2d(nChannels) + self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) - def forward(self, x): - out = self.conv1(F.relu(self.bn1(x))) - out = F.avg_pool2d(out, 2) - return out + def forward(self, x): + out = self.conv1(F.relu(self.bn1(x))) + out = F.avg_pool2d(out, 2) + return out class DenseNet(nn.Module): - def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): - super(DenseNet, self).__init__() + def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): + super(DenseNet, self).__init__() - if bottleneck: nDenseBlocks = int( (depth-4) / 6 ) - else : nDenseBlocks = int( (depth-4) / 3 ) + if bottleneck: + nDenseBlocks = int((depth - 4) / 6) + else: + nDenseBlocks = int((depth - 4) / 3) - self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses) + self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format( + "bottleneck" if bottleneck else "basic", + depth, + reduction, + growthRate, + nClasses, + ) - nChannels = 2*growthRate - self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) + nChannels = 2 * growthRate + self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) - self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) - nChannels += nDenseBlocks*growthRate - nOutChannels = int(math.floor(nChannels*reduction)) - self.trans1 = Transition(nChannels, nOutChannels) + self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) + nChannels += nDenseBlocks * growthRate + nOutChannels = int(math.floor(nChannels * reduction)) + self.trans1 = Transition(nChannels, nOutChannels) - nChannels = nOutChannels - self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) - nChannels += nDenseBlocks*growthRate - nOutChannels = int(math.floor(nChannels*reduction)) - self.trans2 = Transition(nChannels, nOutChannels) + nChannels = nOutChannels + self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) + nChannels += nDenseBlocks * growthRate + nOutChannels = int(math.floor(nChannels * reduction)) + self.trans2 = Transition(nChannels, nOutChannels) - nChannels = nOutChannels - self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) - nChannels += nDenseBlocks*growthRate + nChannels = nOutChannels + self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) + nChannels += nDenseBlocks * growthRate - self.act = nn.Sequential( - nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), - nn.AvgPool2d(8)) - self.fc = nn.Linear(nChannels, nClasses) + self.act = nn.Sequential( + nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8) + ) + self.fc = nn.Linear(nChannels, nClasses) - self.apply(initialize_resnet) + self.apply(initialize_resnet) - def get_message(self): - return self.message + def get_message(self): + return self.message - def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): - layers = [] - for i in range(int(nDenseBlocks)): - if bottleneck: - layers.append(Bottleneck(nChannels, growthRate)) - else: - layers.append(SingleLayer(nChannels, growthRate)) - nChannels += growthRate - return nn.Sequential(*layers) + def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): + layers = [] + for i in range(int(nDenseBlocks)): + if bottleneck: + layers.append(Bottleneck(nChannels, growthRate)) + else: + layers.append(SingleLayer(nChannels, growthRate)) + nChannels += growthRate + return nn.Sequential(*layers) - def forward(self, inputs): - out = self.conv1( inputs ) - out = self.trans1(self.dense1(out)) - out = self.trans2(self.dense2(out)) - out = self.dense3(out) - features = self.act(out) - features = features.view(features.size(0), -1) - out = self.fc(features) - return features, out + def forward(self, inputs): + out = self.conv1(inputs) + out = self.trans1(self.dense1(out)) + out = self.trans2(self.dense2(out)) + out = self.dense3(out) + features = self.act(out) + features = features.view(features.size(0), -1) + out = self.fc(features) + return features, out diff --git a/lib/models/CifarResNet.py b/lib/models/CifarResNet.py index 36f7f57..7ab777f 100644 --- a/lib/models/CifarResNet.py +++ b/lib/models/CifarResNet.py @@ -2,156 +2,179 @@ import torch import torch.nn as nn import torch.nn.functional as F from .initialization import initialize_resnet -from .SharedUtils import additive_func +from .SharedUtils import additive_func -class Downsample(nn.Module): +class Downsample(nn.Module): + def __init__(self, nIn, nOut, stride): + super(Downsample, self).__init__() + assert stride == 2 and nOut == 2 * nIn, "stride:{} IO:{},{}".format( + stride, nIn, nOut + ) + self.in_dim = nIn + self.out_dim = nOut + self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False) - def __init__(self, nIn, nOut, stride): - super(Downsample, self).__init__() - assert stride == 2 and nOut == 2*nIn, 'stride:{} IO:{},{}'.format(stride, nIn, nOut) - self.in_dim = nIn - self.out_dim = nOut - self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False) - - def forward(self, x): - x = self.avg(x) - out = self.conv(x) - return out + def forward(self, x): + x = self.avg(x) + out = self.conv(x) + return out class ConvBNReLU(nn.Module): - - def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu): - super(ConvBNReLU, self).__init__() - self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias) - self.bn = nn.BatchNorm2d(nOut) - if relu: self.relu = nn.ReLU(inplace=True) - else : self.relu = None - self.out_dim = nOut - self.num_conv = 1 + def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d( + nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias + ) + self.bn = nn.BatchNorm2d(nOut) + if relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None + self.out_dim = nOut + self.num_conv = 1 - def forward(self, x): - conv = self.conv( x ) - bn = self.bn( conv ) - if self.relu: return self.relu( bn ) - else : return bn + def forward(self, x): + conv = self.conv(x) + bn = self.bn(conv) + if self.relu: + return self.relu(bn) + else: + return bn class ResNetBasicblock(nn.Module): - expansion = 1 - def __init__(self, inplanes, planes, stride): - super(ResNetBasicblock, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True) - self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, False) - if stride == 2: - self.downsample = Downsample(inplanes, planes, stride) - elif inplanes != planes: - self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False) - else: - self.downsample = None - self.out_dim = planes - self.num_conv = 2 + expansion = 1 - def forward(self, inputs): + def __init__(self, inplanes, planes, stride): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True) + self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False) + if stride == 2: + self.downsample = Downsample(inplanes, planes, stride) + elif inplanes != planes: + self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False) + else: + self.downsample = None + self.out_dim = planes + self.num_conv = 2 - basicblock = self.conv_a(inputs) - basicblock = self.conv_b(basicblock) + def forward(self, inputs): - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = additive_func(residual, basicblock) - return F.relu(out, inplace=True) + basicblock = self.conv_a(inputs) + basicblock = self.conv_b(basicblock) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, basicblock) + return F.relu(out, inplace=True) class ResNetBottleneck(nn.Module): - expansion = 4 - def __init__(self, inplanes, planes, stride): - super(ResNetBottleneck, self).__init__() - assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) - self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True) - self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, True) - self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, False) - if stride == 2: - self.downsample = Downsample(inplanes, planes*self.expansion, stride) - elif inplanes != planes*self.expansion: - self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, False) - else: - self.downsample = None - self.out_dim = planes * self.expansion - self.num_conv = 3 + expansion = 4 - def forward(self, inputs): + def __init__(self, inplanes, planes, stride): + super(ResNetBottleneck, self).__init__() + assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) + self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True) + self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True) + self.conv_1x4 = ConvBNReLU( + planes, planes * self.expansion, 1, 1, 0, False, False + ) + if stride == 2: + self.downsample = Downsample(inplanes, planes * self.expansion, stride) + elif inplanes != planes * self.expansion: + self.downsample = ConvBNReLU( + inplanes, planes * self.expansion, 1, 1, 0, False, False + ) + else: + self.downsample = None + self.out_dim = planes * self.expansion + self.num_conv = 3 - bottleneck = self.conv_1x1(inputs) - bottleneck = self.conv_3x3(bottleneck) - bottleneck = self.conv_1x4(bottleneck) + def forward(self, inputs): - if self.downsample is not None: - residual = self.downsample(inputs) - else: - residual = inputs - out = additive_func(residual, bottleneck) - return F.relu(out, inplace=True) + bottleneck = self.conv_1x1(inputs) + bottleneck = self.conv_3x3(bottleneck) + bottleneck = self.conv_1x4(bottleneck) + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + out = additive_func(residual, bottleneck) + return F.relu(out, inplace=True) class CifarResNet(nn.Module): + def __init__(self, block_name, depth, num_classes, zero_init_residual): + super(CifarResNet, self).__init__() - def __init__(self, block_name, depth, num_classes, zero_init_residual): - super(CifarResNet, self).__init__() + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + if block_name == "ResNetBasicblock": + block = ResNetBasicblock + assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" + layer_blocks = (depth - 2) // 6 + elif block_name == "ResNetBottleneck": + block = ResNetBottleneck + assert (depth - 2) % 9 == 0, "depth should be one of 164" + layer_blocks = (depth - 2) // 9 + else: + raise ValueError("invalid block : {:}".format(block_name)) - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - if block_name == 'ResNetBasicblock': - block = ResNetBasicblock - assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' - layer_blocks = (depth - 2) // 6 - elif block_name == 'ResNetBottleneck': - block = ResNetBottleneck - assert (depth - 2) % 9 == 0, 'depth should be one of 164' - layer_blocks = (depth - 2) // 9 - else: - raise ValueError('invalid block : {:}'.format(block_name)) + self.message = "CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}".format( + block_name, depth, layer_blocks + ) + self.num_classes = num_classes + self.channels = [16] + self.layers = nn.ModuleList([ConvBNReLU(3, 16, 3, 1, 1, False, True)]) + for stage in range(3): + for iL in range(layer_blocks): + iC = self.channels[-1] + planes = 16 * (2 ** stage) + stride = 2 if stage > 0 and iL == 0 else 1 + module = block(iC, planes, stride) + self.channels.append(module.out_dim) + self.layers.append(module) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( + stage, + iL, + layer_blocks, + len(self.layers) - 1, + iC, + module.out_dim, + stride, + ) - self.message = 'CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}'.format(block_name, depth, layer_blocks) - self.num_classes = num_classes - self.channels = [16] - self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, True) ] ) - for stage in range(3): - for iL in range(layer_blocks): - iC = self.channels[-1] - planes = 16 * (2**stage) - stride = 2 if stage > 0 and iL == 0 else 1 - module = block(iC, planes, stride) - self.channels.append( module.out_dim ) - self.layers.append ( module ) - self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(module.out_dim, num_classes) + assert ( + sum(x.num_conv for x in self.layers) + 1 == depth + ), "invalid depth check {:} vs {:}".format( + sum(x.num_conv for x in self.layers) + 1, depth + ) - self.avgpool = nn.AvgPool2d(8) - self.classifier = nn.Linear(module.out_dim, num_classes) - assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) + self.apply(initialize_resnet) + if zero_init_residual: + for m in self.modules(): + if isinstance(m, ResNetBasicblock): + nn.init.constant_(m.conv_b.bn.weight, 0) + elif isinstance(m, ResNetBottleneck): + nn.init.constant_(m.conv_1x4.bn.weight, 0) - self.apply(initialize_resnet) - if zero_init_residual: - for m in self.modules(): - if isinstance(m, ResNetBasicblock): - nn.init.constant_(m.conv_b.bn.weight, 0) - elif isinstance(m, ResNetBottleneck): - nn.init.constant_(m.conv_1x4.bn.weight, 0) + def get_message(self): + return self.message - def get_message(self): - return self.message - - def forward(self, inputs): - x = inputs - for i, layer in enumerate(self.layers): - x = layer( x ) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.classifier(features) - return features, logits + def forward(self, inputs): + x = inputs + for i, layer in enumerate(self.layers): + x = layer(x) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/CifarWideResNet.py b/lib/models/CifarWideResNet.py index e441f12..62e97c3 100644 --- a/lib/models/CifarWideResNet.py +++ b/lib/models/CifarWideResNet.py @@ -5,90 +5,111 @@ from .initialization import initialize_resnet class WideBasicblock(nn.Module): - def __init__(self, inplanes, planes, stride, dropout=False): - super(WideBasicblock, self).__init__() + def __init__(self, inplanes, planes, stride, dropout=False): + super(WideBasicblock, self).__init__() - self.bn_a = nn.BatchNorm2d(inplanes) - self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn_a = nn.BatchNorm2d(inplanes) + self.conv_a = nn.Conv2d( + inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) - self.bn_b = nn.BatchNorm2d(planes) - if dropout: - self.dropout = nn.Dropout2d(p=0.5, inplace=True) - else: - self.dropout = None - self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_b = nn.BatchNorm2d(planes) + if dropout: + self.dropout = nn.Dropout2d(p=0.5, inplace=True) + else: + self.dropout = None + self.conv_b = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) - if inplanes != planes: - self.downsample = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False) - else: - self.downsample = None + if inplanes != planes: + self.downsample = nn.Conv2d( + inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False + ) + else: + self.downsample = None - def forward(self, x): + def forward(self, x): - basicblock = self.bn_a(x) - basicblock = F.relu(basicblock) - basicblock = self.conv_a(basicblock) + basicblock = self.bn_a(x) + basicblock = F.relu(basicblock) + basicblock = self.conv_a(basicblock) - basicblock = self.bn_b(basicblock) - basicblock = F.relu(basicblock) - if self.dropout is not None: - basicblock = self.dropout(basicblock) - basicblock = self.conv_b(basicblock) + basicblock = self.bn_b(basicblock) + basicblock = F.relu(basicblock) + if self.dropout is not None: + basicblock = self.dropout(basicblock) + basicblock = self.conv_b(basicblock) - if self.downsample is not None: - x = self.downsample(x) - - return x + basicblock + if self.downsample is not None: + x = self.downsample(x) + + return x + basicblock class CifarWideResNet(nn.Module): - """ - ResNet optimized for the Cifar dataset, as specified in - https://arxiv.org/abs/1512.03385.pdf - """ - def __init__(self, depth, widen_factor, num_classes, dropout): - super(CifarWideResNet, self).__init__() + """ + ResNet optimized for the Cifar dataset, as specified in + https://arxiv.org/abs/1512.03385.pdf + """ - #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model - assert (depth - 4) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' - layer_blocks = (depth - 4) // 6 - print ('CifarPreResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks)) + def __init__(self, depth, widen_factor, num_classes, dropout): + super(CifarWideResNet, self).__init__() - self.num_classes = num_classes - self.dropout = dropout - self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + assert (depth - 4) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" + layer_blocks = (depth - 4) // 6 + print( + "CifarPreResNet : Depth : {} , Layers for each block : {}".format( + depth, layer_blocks + ) + ) - self.message = 'Wide ResNet : depth={:}, widen_factor={:}, class={:}'.format(depth, widen_factor, num_classes) - self.inplanes = 16 - self.stage_1 = self._make_layer(WideBasicblock, 16*widen_factor, layer_blocks, 1) - self.stage_2 = self._make_layer(WideBasicblock, 32*widen_factor, layer_blocks, 2) - self.stage_3 = self._make_layer(WideBasicblock, 64*widen_factor, layer_blocks, 2) - self.lastact = nn.Sequential(nn.BatchNorm2d(64*widen_factor), nn.ReLU(inplace=True)) - self.avgpool = nn.AvgPool2d(8) - self.classifier = nn.Linear(64*widen_factor, num_classes) + self.num_classes = num_classes + self.dropout = dropout + self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) - self.apply(initialize_resnet) + self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format( + depth, widen_factor, num_classes + ) + self.inplanes = 16 + self.stage_1 = self._make_layer( + WideBasicblock, 16 * widen_factor, layer_blocks, 1 + ) + self.stage_2 = self._make_layer( + WideBasicblock, 32 * widen_factor, layer_blocks, 2 + ) + self.stage_3 = self._make_layer( + WideBasicblock, 64 * widen_factor, layer_blocks, 2 + ) + self.lastact = nn.Sequential( + nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True) + ) + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(64 * widen_factor, num_classes) - def get_message(self): - return self.message + self.apply(initialize_resnet) - def _make_layer(self, block, planes, blocks, stride): + def get_message(self): + return self.message - layers = [] - layers.append(block(self.inplanes, planes, stride, self.dropout)) - self.inplanes = planes - for i in range(1, blocks): - layers.append(block(self.inplanes, planes, 1, self.dropout)) + def _make_layer(self, block, planes, blocks, stride): - return nn.Sequential(*layers) + layers = [] + layers.append(block(self.inplanes, planes, stride, self.dropout)) + self.inplanes = planes + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, 1, self.dropout)) - def forward(self, x): - x = self.conv_3x3(x) - x = self.stage_1(x) - x = self.stage_2(x) - x = self.stage_3(x) - x = self.lastact(x) - x = self.avgpool(x) - features = x.view(x.size(0), -1) - outs = self.classifier(features) - return features, outs + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv_3x3(x) + x = self.stage_1(x) + x = self.stage_2(x) + x = self.stage_3(x) + x = self.lastact(x) + x = self.avgpool(x) + features = x.view(x.size(0), -1) + outs = self.classifier(features) + return features, outs diff --git a/lib/models/ImageNet_MobileNetV2.py b/lib/models/ImageNet_MobileNetV2.py index ec7e341..814ab39 100644 --- a/lib/models/ImageNet_MobileNetV2.py +++ b/lib/models/ImageNet_MobileNetV2.py @@ -4,98 +4,114 @@ from .initialization import initialize_resnet class ConvBNReLU(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): - super(ConvBNReLU, self).__init__() - padding = (kernel_size - 1) // 2 - self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False) - self.bn = nn.BatchNorm2d(out_planes) - self.relu = nn.ReLU6(inplace=True) - - def forward(self, x): - out = self.conv( x ) - out = self.bn ( out ) - out = self.relu( out ) - return out + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + self.conv = nn.Conv2d( + in_planes, + out_planes, + kernel_size, + stride, + padding, + groups=groups, + bias=False, + ) + self.bn = nn.BatchNorm2d(out_planes) + self.relu = nn.ReLU6(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.bn(out) + out = self.relu(out) + return out class InvertedResidual(nn.Module): - def __init__(self, inp, oup, stride, expand_ratio): - super(InvertedResidual, self).__init__() - self.stride = stride - assert stride in [1, 2] + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] - hidden_dim = int(round(inp * expand_ratio)) - self.use_res_connect = self.stride == 1 and inp == oup + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup - layers = [] - if expand_ratio != 1: - # pw - layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) - layers.extend([ - # dw - ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), - # pw-linear - nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - ]) - self.conv = nn.Sequential(*layers) + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend( + [ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ] + ) + self.conv = nn.Sequential(*layers) - def forward(self, x): - if self.use_res_connect: - return x + self.conv(x) - else: - return self.conv(x) + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) class MobileNetV2(nn.Module): - def __init__(self, num_classes, width_mult, input_channel, last_channel, block_name, dropout): - super(MobileNetV2, self).__init__() - if block_name == 'InvertedResidual': - block = InvertedResidual - else: - raise ValueError('invalid block name : {:}'.format(block_name)) - inverted_residual_setting = [ - # t, c, n, s - [1, 16 , 1, 1], - [6, 24 , 2, 2], - [6, 32 , 3, 2], - [6, 64 , 4, 2], - [6, 96 , 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] + def __init__( + self, num_classes, width_mult, input_channel, last_channel, block_name, dropout + ): + super(MobileNetV2, self).__init__() + if block_name == "InvertedResidual": + block = InvertedResidual + else: + raise ValueError("invalid block name : {:}".format(block_name)) + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] - # building first layer - input_channel = int(input_channel * width_mult) - self.last_channel = int(last_channel * max(1.0, width_mult)) - features = [ConvBNReLU(3, input_channel, stride=2)] - # building inverted residual blocks - for t, c, n, s in inverted_residual_setting: - output_channel = int(c * width_mult) - for i in range(n): - stride = s if i == 0 else 1 - features.append(block(input_channel, output_channel, stride, expand_ratio=t)) - input_channel = output_channel - # building last several layers - features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) - # make it nn.Sequential - self.features = nn.Sequential(*features) + # building first layer + input_channel = int(input_channel * width_mult) + self.last_channel = int(last_channel * max(1.0, width_mult)) + features = [ConvBNReLU(3, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = int(c * width_mult) + for i in range(n): + stride = s if i == 0 else 1 + features.append( + block(input_channel, output_channel, stride, expand_ratio=t) + ) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) + # make it nn.Sequential + self.features = nn.Sequential(*features) - # building classifier - self.classifier = nn.Sequential( - nn.Dropout(dropout), - nn.Linear(self.last_channel, num_classes), - ) - self.message = 'MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}'.format(width_mult, input_channel, last_channel, block_name, dropout) + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(self.last_channel, num_classes), + ) + self.message = "MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}".format( + width_mult, input_channel, last_channel, block_name, dropout + ) - # weight initialization - self.apply( initialize_resnet ) + # weight initialization + self.apply(initialize_resnet) - def get_message(self): - return self.message + def get_message(self): + return self.message - def forward(self, inputs): - features = self.features(inputs) - vectors = features.mean([2, 3]) - predicts = self.classifier(vectors) - return features, predicts + def forward(self, inputs): + features = self.features(inputs) + vectors = features.mean([2, 3]) + predicts = self.classifier(vectors) + return features, predicts diff --git a/lib/models/ImageNet_ResNet.py b/lib/models/ImageNet_ResNet.py index 9042db5..66d830a 100644 --- a/lib/models/ImageNet_ResNet.py +++ b/lib/models/ImageNet_ResNet.py @@ -2,171 +2,216 @@ import torch.nn as nn from .initialization import initialize_resnet + def conv3x3(in_planes, out_planes, stride=1, groups=1): - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + groups=groups, + bias=False, + ) def conv1x1(in_planes, out_planes, stride=1): - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): - expansion = 1 + expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64): - super(BasicBlock, self).__init__() - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes) - self.downsample = downsample - self.stride = stride + def __init__( + self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 + ): + super(BasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride - def forward(self, x): - identity = x + def forward(self, x): + identity = x - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) - out = self.conv2(out) - out = self.bn2(out) + out = self.conv2(out) + out = self.bn2(out) - if self.downsample is not None: - identity = self.downsample(x) + if self.downsample is not None: + identity = self.downsample(x) - out += identity - out = self.relu(out) + out += identity + out = self.relu(out) - return out + return out class Bottleneck(nn.Module): - expansion = 4 + expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64): - super(Bottleneck, self).__init__() - width = int(planes * (base_width / 64.)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = nn.BatchNorm2d(width) - self.conv2 = conv3x3(width, width, stride, groups) - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride + def __init__( + self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 + ): + super(Bottleneck, self).__init__() + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = conv3x3(width, width, stride, groups) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride - def forward(self, x): - identity = x + def forward(self, x): + identity = x - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) - out = self.conv3(out) - out = self.bn3(out) + out = self.conv3(out) + out = self.bn3(out) - if self.downsample is not None: - identity = self.downsample(x) + if self.downsample is not None: + identity = self.downsample(x) - out += identity - out = self.relu(out) + out += identity + out = self.relu(out) - return out + return out class ResNet(nn.Module): + def __init__( + self, + block_name, + layers, + deep_stem, + num_classes, + zero_init_residual, + groups, + width_per_group, + ): + super(ResNet, self).__init__() - def __init__(self, block_name, layers, deep_stem, num_classes, zero_init_residual, groups, width_per_group): - super(ResNet, self).__init__() + # planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] + if block_name == "BasicBlock": + block = BasicBlock + elif block_name == "Bottleneck": + block = Bottleneck + else: + raise ValueError("invalid block-name : {:}".format(block_name)) - #planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] - if block_name == 'BasicBlock' : block= BasicBlock - elif block_name == 'Bottleneck': block= Bottleneck - else : raise ValueError('invalid block-name : {:}'.format(block_name)) - - if not deep_stem: - self.conv = nn.Sequential( - nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), - nn.BatchNorm2d(64), nn.ReLU(inplace=True)) - else: - self.conv = nn.Sequential( - nn.Conv2d( 3, 32, kernel_size=3, stride=2, padding=1, bias=False), - nn.BatchNorm2d(32), nn.ReLU(inplace=True), - nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), - nn.BatchNorm2d(32), nn.ReLU(inplace=True), - nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False), - nn.BatchNorm2d(64), nn.ReLU(inplace=True)) - self.inplanes = 64 - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64 , layers[0], stride=1, groups=groups, base_width=width_per_group) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) - self.message = 'block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}'.format(block, layers, deep_stem, num_classes) - - self.apply( initialize_resnet ) - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) - elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) - - def _make_layer(self, block, planes, blocks, stride, groups, base_width): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - if stride == 2: - downsample = nn.Sequential( - nn.AvgPool2d(kernel_size=2, stride=2, padding=0), - conv1x1(self.inplanes, planes * block.expansion, 1), - nn.BatchNorm2d(planes * block.expansion), + if not deep_stem: + self.conv = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ) + else: + self.conv = nn.Sequential( + nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ) + self.inplanes = 64 + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer( + block, 64, layers[0], stride=1, groups=groups, base_width=width_per_group ) - elif stride == 1: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - nn.BatchNorm2d(planes * block.expansion), + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + self.message = ( + "block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}".format( + block, layers, deep_stem, num_classes + ) ) - else: raise ValueError('invalid stride [{:}] for downsample'.format(stride)) - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, groups, base_width)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, 1, None, groups, base_width)) + self.apply(initialize_resnet) - return nn.Sequential(*layers) + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) - def get_message(self): - return self.message + def _make_layer(self, block, planes, blocks, stride, groups, base_width): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + if stride == 2: + downsample = nn.Sequential( + nn.AvgPool2d(kernel_size=2, stride=2, padding=0), + conv1x1(self.inplanes, planes * block.expansion, 1), + nn.BatchNorm2d(planes * block.expansion), + ) + elif stride == 1: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion), + ) + else: + raise ValueError("invalid stride [{:}] for downsample".format(stride)) - def forward(self, x): - x = self.conv(x) - x = self.maxpool(x) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, groups, base_width) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, 1, None, groups, base_width)) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) + return nn.Sequential(*layers) - features = self.avgpool(x) - features = features.view(features.size(0), -1) - logits = self.fc(features) + def get_message(self): + return self.message - return features, logits + def forward(self, x): + x = self.conv(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.fc(features) + + return features, logits diff --git a/lib/models/SharedUtils.py b/lib/models/SharedUtils.py index 8938752..adcdf8b 100644 --- a/lib/models/SharedUtils.py +++ b/lib/models/SharedUtils.py @@ -6,29 +6,32 @@ import torch.nn as nn def additive_func(A, B): - assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size()) - C = min(A.size(1), B.size(1)) - if A.size(1) == B.size(1): - return A + B - elif A.size(1) < B.size(1): - out = B.clone() - out[:,:C] += A - return out - else: - out = A.clone() - out[:,:C] += B - return out + assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format( + A.size(), B.size() + ) + C = min(A.size(1), B.size(1)) + if A.size(1) == B.size(1): + return A + B + elif A.size(1) < B.size(1): + out = B.clone() + out[:, :C] += A + return out + else: + out = A.clone() + out[:, :C] += B + return out def change_key(key, value): - def func(m): - if hasattr(m, key): - setattr(m, key, value) - return func + def func(m): + if hasattr(m, key): + setattr(m, key, value) + + return func def parse_channel_info(xstring): - blocks = xstring.split(' ') - blocks = [x.split('-') for x in blocks] - blocks = [[int(_) for _ in x] for x in blocks] - return blocks + blocks = xstring.split(" ") + blocks = [x.split("-") for x in blocks] + blocks = [[int(_) for _ in x] for x in blocks] + return blocks diff --git a/lib/models/__init__.py b/lib/models/__init__.py index 13f2632..b4b4aed 100644 --- a/lib/models/__init__.py +++ b/lib/models/__init__.py @@ -5,10 +5,18 @@ from os import path as osp from typing import List, Text import torch -__all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ - 'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \ - 'CellStructure', 'CellArchitectures' - ] +__all__ = [ + "change_key", + "get_cell_based_tiny_net", + "get_search_spaces", + "get_cifar_models", + "get_imagenet_models", + "obtain_model", + "obtain_search_model", + "load_net_from_checkpoint", + "CellStructure", + "CellArchitectures", +] # useful modules from config_utils import dict2config @@ -18,178 +26,301 @@ from models.cell_searchs import CellStructure, CellArchitectures # Cell-based NAS Models def get_cell_based_tiny_net(config): - if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict - super_type = getattr(config, 'super_type', 'basic') - group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM', 'generic'] - if super_type == 'basic' and config.name in group_names: - from .cell_searchs import nas201_super_nets as nas_super_nets - try: - return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats) - except: - return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) - elif super_type == 'search-shape': - from .shape_searchs import GenericNAS301Model - genotype = CellStructure.str2structure(config.genotype) - return GenericNAS301Model(config.candidate_Cs, config.max_num_Cs, genotype, config.num_classes, config.affine, config.track_running_stats) - elif super_type == 'nasnet-super': - from .cell_searchs import nasnet_super_nets as nas_super_nets - return nas_super_nets[config.name](config.C, config.N, config.steps, config.multiplier, \ - config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats) - elif config.name == 'infer.tiny': - from .cell_infers import TinyNetwork - if hasattr(config, 'genotype'): - genotype = config.genotype - elif hasattr(config, 'arch_str'): - genotype = CellStructure.str2structure(config.arch_str) - else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) - return TinyNetwork(config.C, config.N, genotype, config.num_classes) - elif config.name == 'infer.shape.tiny': - from .shape_infers import DynamicShapeTinyNet - if isinstance(config.channels, str): - channels = tuple([int(x) for x in config.channels.split(':')]) - else: channels = config.channels - genotype = CellStructure.str2structure(config.genotype) - return DynamicShapeTinyNet(channels, genotype, config.num_classes) - elif config.name == 'infer.nasnet-cifar': - from .cell_infers import NASNetonCIFAR - raise NotImplementedError - else: - raise ValueError('invalid network name : {:}'.format(config.name)) + if isinstance(config, dict): + config = dict2config(config, None) # to support the argument being a dict + super_type = getattr(config, "super_type", "basic") + group_names = ["DARTS-V1", "DARTS-V2", "GDAS", "SETN", "ENAS", "RANDOM", "generic"] + if super_type == "basic" and config.name in group_names: + from .cell_searchs import nas201_super_nets as nas_super_nets + + try: + return nas_super_nets[config.name]( + config.C, + config.N, + config.max_nodes, + config.num_classes, + config.space, + config.affine, + config.track_running_stats, + ) + except: + return nas_super_nets[config.name]( + config.C, config.N, config.max_nodes, config.num_classes, config.space + ) + elif super_type == "search-shape": + from .shape_searchs import GenericNAS301Model + + genotype = CellStructure.str2structure(config.genotype) + return GenericNAS301Model( + config.candidate_Cs, + config.max_num_Cs, + genotype, + config.num_classes, + config.affine, + config.track_running_stats, + ) + elif super_type == "nasnet-super": + from .cell_searchs import nasnet_super_nets as nas_super_nets + + return nas_super_nets[config.name]( + config.C, + config.N, + config.steps, + config.multiplier, + config.stem_multiplier, + config.num_classes, + config.space, + config.affine, + config.track_running_stats, + ) + elif config.name == "infer.tiny": + from .cell_infers import TinyNetwork + + if hasattr(config, "genotype"): + genotype = config.genotype + elif hasattr(config, "arch_str"): + genotype = CellStructure.str2structure(config.arch_str) + else: + raise ValueError( + "Can not find genotype from this config : {:}".format(config) + ) + return TinyNetwork(config.C, config.N, genotype, config.num_classes) + elif config.name == "infer.shape.tiny": + from .shape_infers import DynamicShapeTinyNet + + if isinstance(config.channels, str): + channels = tuple([int(x) for x in config.channels.split(":")]) + else: + channels = config.channels + genotype = CellStructure.str2structure(config.genotype) + return DynamicShapeTinyNet(channels, genotype, config.num_classes) + elif config.name == "infer.nasnet-cifar": + from .cell_infers import NASNetonCIFAR + + raise NotImplementedError + else: + raise ValueError("invalid network name : {:}".format(config.name)) # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op def get_search_spaces(xtype, name) -> List[Text]: - if xtype == 'cell' or xtype == 'tss': # The topology search space. - from .cell_operations import SearchSpaceNames - assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) - return SearchSpaceNames[name] - elif xtype == 'sss': # The size search space. - if name in ['nats-bench', 'nats-bench-size']: - return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64], - 'numbers': 5} + if xtype == "cell" or xtype == "tss": # The topology search space. + from .cell_operations import SearchSpaceNames + + assert name in SearchSpaceNames, "invalid name [{:}] in {:}".format( + name, SearchSpaceNames.keys() + ) + return SearchSpaceNames[name] + elif xtype == "sss": # The size search space. + if name in ["nats-bench", "nats-bench-size"]: + return {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5} + else: + raise ValueError("Invalid name : {:}".format(name)) else: - raise ValueError('Invalid name : {:}'.format(name)) - else: - raise ValueError('invalid search-space type is {:}'.format(xtype)) + raise ValueError("invalid search-space type is {:}".format(xtype)) def get_cifar_models(config, extra_path=None): - super_type = getattr(config, 'super_type', 'basic') - if super_type == 'basic': - from .CifarResNet import CifarResNet - from .CifarDenseNet import DenseNet - from .CifarWideResNet import CifarWideResNet - if config.arch == 'resnet': - return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual) - elif config.arch == 'densenet': - return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck) - elif config.arch == 'wideresnet': - return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout) + super_type = getattr(config, "super_type", "basic") + if super_type == "basic": + from .CifarResNet import CifarResNet + from .CifarDenseNet import DenseNet + from .CifarWideResNet import CifarWideResNet + + if config.arch == "resnet": + return CifarResNet( + config.module, config.depth, config.class_num, config.zero_init_residual + ) + elif config.arch == "densenet": + return DenseNet( + config.growthRate, + config.depth, + config.reduction, + config.class_num, + config.bottleneck, + ) + elif config.arch == "wideresnet": + return CifarWideResNet( + config.depth, config.wide_factor, config.class_num, config.dropout + ) + else: + raise ValueError("invalid module type : {:}".format(config.arch)) + elif super_type.startswith("infer"): + from .shape_infers import InferWidthCifarResNet + from .shape_infers import InferDepthCifarResNet + from .shape_infers import InferCifarResNet + from .cell_infers import NASNetonCIFAR + + assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( + super_type + ) + infer_mode = super_type.split("-")[1] + if infer_mode == "width": + return InferWidthCifarResNet( + config.module, + config.depth, + config.xchannels, + config.class_num, + config.zero_init_residual, + ) + elif infer_mode == "depth": + return InferDepthCifarResNet( + config.module, + config.depth, + config.xblocks, + config.class_num, + config.zero_init_residual, + ) + elif infer_mode == "shape": + return InferCifarResNet( + config.module, + config.depth, + config.xblocks, + config.xchannels, + config.class_num, + config.zero_init_residual, + ) + elif infer_mode == "nasnet.cifar": + genotype = config.genotype + if extra_path is not None: # reload genotype by extra_path + if not osp.isfile(extra_path): + raise ValueError("invalid extra_path : {:}".format(extra_path)) + xdata = torch.load(extra_path) + current_epoch = xdata["epoch"] + genotype = xdata["genotypes"][current_epoch - 1] + C = config.C if hasattr(config, "C") else config.ichannel + N = config.N if hasattr(config, "N") else config.layers + return NASNetonCIFAR( + C, N, config.stem_multi, config.class_num, genotype, config.auxiliary + ) + else: + raise ValueError("invalid infer-mode : {:}".format(infer_mode)) else: - raise ValueError('invalid module type : {:}'.format(config.arch)) - elif super_type.startswith('infer'): - from .shape_infers import InferWidthCifarResNet - from .shape_infers import InferDepthCifarResNet - from .shape_infers import InferCifarResNet - from .cell_infers import NASNetonCIFAR - assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) - infer_mode = super_type.split('-')[1] - if infer_mode == 'width': - return InferWidthCifarResNet(config.module, config.depth, config.xchannels, config.class_num, config.zero_init_residual) - elif infer_mode == 'depth': - return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual) - elif infer_mode == 'shape': - return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual) - elif infer_mode == 'nasnet.cifar': - genotype = config.genotype - if extra_path is not None: # reload genotype by extra_path - if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path)) - xdata = torch.load(extra_path) - current_epoch = xdata['epoch'] - genotype = xdata['genotypes'][current_epoch-1] - C = config.C if hasattr(config, 'C') else config.ichannel - N = config.N if hasattr(config, 'N') else config.layers - return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary) - else: - raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) - else: - raise ValueError('invalid super-type : {:}'.format(super_type)) + raise ValueError("invalid super-type : {:}".format(super_type)) def get_imagenet_models(config): - super_type = getattr(config, 'super_type', 'basic') - if super_type == 'basic': - from .ImageNet_ResNet import ResNet - from .ImageNet_MobileNetV2 import MobileNetV2 - if config.arch == 'resnet': - return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group) - elif config.arch == 'mobilenet_v2': - return MobileNetV2(config.class_num, config.width_multi, config.input_channel, config.last_channel, 'InvertedResidual', config.dropout) + super_type = getattr(config, "super_type", "basic") + if super_type == "basic": + from .ImageNet_ResNet import ResNet + from .ImageNet_MobileNetV2 import MobileNetV2 + + if config.arch == "resnet": + return ResNet( + config.block_name, + config.layers, + config.deep_stem, + config.class_num, + config.zero_init_residual, + config.groups, + config.width_per_group, + ) + elif config.arch == "mobilenet_v2": + return MobileNetV2( + config.class_num, + config.width_multi, + config.input_channel, + config.last_channel, + "InvertedResidual", + config.dropout, + ) + else: + raise ValueError("invalid arch : {:}".format(config.arch)) + elif super_type.startswith("infer"): # NAS searched architecture + assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( + super_type + ) + infer_mode = super_type.split("-")[1] + if infer_mode == "shape": + from .shape_infers import InferImagenetResNet + from .shape_infers import InferMobileNetV2 + + if config.arch == "resnet": + return InferImagenetResNet( + config.block_name, + config.layers, + config.xblocks, + config.xchannels, + config.deep_stem, + config.class_num, + config.zero_init_residual, + ) + elif config.arch == "MobileNetV2": + return InferMobileNetV2( + config.class_num, config.xchannels, config.xblocks, config.dropout + ) + else: + raise ValueError("invalid arch-mode : {:}".format(config.arch)) + else: + raise ValueError("invalid infer-mode : {:}".format(infer_mode)) else: - raise ValueError('invalid arch : {:}'.format( config.arch )) - elif super_type.startswith('infer'): # NAS searched architecture - assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) - infer_mode = super_type.split('-')[1] - if infer_mode == 'shape': - from .shape_infers import InferImagenetResNet - from .shape_infers import InferMobileNetV2 - if config.arch == 'resnet': - return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual) - elif config.arch == "MobileNetV2": - return InferMobileNetV2(config.class_num, config.xchannels, config.xblocks, config.dropout) - else: - raise ValueError('invalid arch-mode : {:}'.format(config.arch)) - else: - raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) - else: - raise ValueError('invalid super-type : {:}'.format(super_type)) + raise ValueError("invalid super-type : {:}".format(super_type)) # Try to obtain the network by config. def obtain_model(config, extra_path=None): - if config.dataset == 'cifar': - return get_cifar_models(config, extra_path) - elif config.dataset == 'imagenet': - return get_imagenet_models(config) - else: - raise ValueError('invalid dataset in the model config : {:}'.format(config)) + if config.dataset == "cifar": + return get_cifar_models(config, extra_path) + elif config.dataset == "imagenet": + return get_imagenet_models(config) + else: + raise ValueError("invalid dataset in the model config : {:}".format(config)) def obtain_search_model(config): - if config.dataset == 'cifar': - if config.arch == 'resnet': - from .shape_searchs import SearchWidthCifarResNet - from .shape_searchs import SearchDepthCifarResNet - from .shape_searchs import SearchShapeCifarResNet - if config.search_mode == 'width': - return SearchWidthCifarResNet(config.module, config.depth, config.class_num) - elif config.search_mode == 'depth': - return SearchDepthCifarResNet(config.module, config.depth, config.class_num) - elif config.search_mode == 'shape': - return SearchShapeCifarResNet(config.module, config.depth, config.class_num) - else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) - elif config.arch == 'simres': - from .shape_searchs import SearchWidthSimResNet - if config.search_mode == 'width': - return SearchWidthSimResNet(config.depth, config.class_num) - else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) + if config.dataset == "cifar": + if config.arch == "resnet": + from .shape_searchs import SearchWidthCifarResNet + from .shape_searchs import SearchDepthCifarResNet + from .shape_searchs import SearchShapeCifarResNet + + if config.search_mode == "width": + return SearchWidthCifarResNet( + config.module, config.depth, config.class_num + ) + elif config.search_mode == "depth": + return SearchDepthCifarResNet( + config.module, config.depth, config.class_num + ) + elif config.search_mode == "shape": + return SearchShapeCifarResNet( + config.module, config.depth, config.class_num + ) + else: + raise ValueError("invalid search mode : {:}".format(config.search_mode)) + elif config.arch == "simres": + from .shape_searchs import SearchWidthSimResNet + + if config.search_mode == "width": + return SearchWidthSimResNet(config.depth, config.class_num) + else: + raise ValueError("invalid search mode : {:}".format(config.search_mode)) + else: + raise ValueError( + "invalid arch : {:} for dataset [{:}]".format( + config.arch, config.dataset + ) + ) + elif config.dataset == "imagenet": + from .shape_searchs import SearchShapeImagenetResNet + + assert config.search_mode == "shape", "invalid search-mode : {:}".format( + config.search_mode + ) + if config.arch == "resnet": + return SearchShapeImagenetResNet( + config.block_name, config.layers, config.deep_stem, config.class_num + ) + else: + raise ValueError("invalid model config : {:}".format(config)) else: - raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset)) - elif config.dataset == 'imagenet': - from .shape_searchs import SearchShapeImagenetResNet - assert config.search_mode == 'shape', 'invalid search-mode : {:}'.format( config.search_mode ) - if config.arch == 'resnet': - return SearchShapeImagenetResNet(config.block_name, config.layers, config.deep_stem, config.class_num) - else: - raise ValueError('invalid model config : {:}'.format(config)) - else: - raise ValueError('invalid dataset in the model config : {:}'.format(config)) + raise ValueError("invalid dataset in the model config : {:}".format(config)) def load_net_from_checkpoint(checkpoint): - assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint) - checkpoint = torch.load(checkpoint) - model_config = dict2config(checkpoint['model-config'], None) - model = obtain_model(model_config) - model.load_state_dict(checkpoint['base-model']) - return model + assert osp.isfile(checkpoint), "checkpoint {:} does not exist".format(checkpoint) + checkpoint = torch.load(checkpoint) + model_config = dict2config(checkpoint["model-config"], None) + model = obtain_model(model_config) + model.load_state_dict(checkpoint["base-model"]) + return model diff --git a/lib/models/xcore.py b/lib/models/xcore.py index a8196a0..91a0498 100644 --- a/lib/models/xcore.py +++ b/lib/models/xcore.py @@ -21,8 +21,12 @@ def get_model(config: Dict[Text, Any], **kwargs): act_cls = super_name2activation[kwargs["act_cls"]] norm_cls = super_name2norm[kwargs["norm_cls"]] mean, std = kwargs.get("mean", None), kwargs.get("std", None) - hidden_dim1 = kwargs.get("hidden_dim1", 200) - hidden_dim2 = kwargs.get("hidden_dim2", 100) + if "hidden_dim" in kwargs: + hidden_dim1 = kwargs.get("hidden_dim") + hidden_dim2 = kwargs.get("hidden_dim") + else: + hidden_dim1 = kwargs.get("hidden_dim1", 200) + hidden_dim2 = kwargs.get("hidden_dim2", 100) model = SuperSequential( norm_cls(mean=mean, std=std), SuperLinear(kwargs["input_dim"], hidden_dim1), @@ -34,4 +38,3 @@ def get_model(config: Dict[Text, Any], **kwargs): else: raise TypeError("Unkonwn model type: {:}".format(model_type)) return model - diff --git a/lib/xlayers/super_module.py b/lib/xlayers/super_module.py index aeed535..091be02 100644 --- a/lib/xlayers/super_module.py +++ b/lib/xlayers/super_module.py @@ -59,6 +59,9 @@ class TensorContainer: for tensor in self._tensors: tensor.requires_grad_(requires_grad) + def parameters(self): + return self._tensors + @property def tensors(self): return self._tensors