From cfabd05de8b3d006d3ebef36d789f5082ff398eb Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Thu, 13 May 2021 21:33:34 +0800 Subject: [PATCH] Update LFNA version 1.0 --- exps/LFNA/lfna.py | 214 ++++++++++++++++++------------- exps/LFNA/lfna_meta_model.py | 128 ++++++++++++++++++ exps/LFNA/lfna_models_v2.py | 72 ----------- exps/LFNA/vis-synthetic.py | 4 +- lib/datasets/synthetic_core.py | 1 + lib/datasets/synthetic_env.py | 73 ++++++++++- lib/datasets/synthetic_utils.py | 1 + lib/xlayers/super_rl_actor.py | 120 ----------------- lib/xlayers/super_transformer.py | 9 +- lib/xlayers/weight_init.py | 5 +- tests/test_synthetic_env.py | 12 +- 11 files changed, 340 insertions(+), 299 deletions(-) create mode 100644 exps/LFNA/lfna_meta_model.py delete mode 100644 exps/LFNA/lfna_models_v2.py delete mode 100644 lib/xlayers/super_rl_actor.py diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py index 8e9cfae..4d69a5c 100644 --- a/exps/LFNA/lfna.py +++ b/exps/LFNA/lfna.py @@ -1,7 +1,7 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### -# python exps/LFNA/lfna.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 50000 +# python exps/LFNA/lfna.py --env_version v1 ##################################################### import sys, time, copy, torch, random, argparse from tqdm import tqdm @@ -19,56 +19,82 @@ from utils import split_str2indexes from procedures.advanced_main import basic_train_fn, basic_eval_fn from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric -from datasets.synthetic_core import get_synthetic_env +from datasets.synthetic_core import get_synthetic_env, EnvSampler from models.xcore import get_model from xlayers import super_core, trunc_normal_ - from lfna_utils import lfna_setup, train_model, TimeData +from lfna_meta_model import LFNA_Meta -from lfna_models_v2 import HyperNet + +def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger): + base_model.train() + meta_model.train() + loss_meter = AverageMeter() + for ibatch, batch_data in enumerate(loader): + timestamps, (batch_seq_inputs, batch_seq_targets) = batch_data + timestamps = timestamps.squeeze(dim=-1).to(device) + batch_seq_inputs = batch_seq_inputs.to(device) + batch_seq_targets = batch_seq_targets.to(device) + + optimizer.zero_grad() + + batch_seq_containers = meta_model(timestamps) + losses = [] + for seq_containers, seq_inputs, seq_targets in zip( + batch_seq_containers, batch_seq_inputs, batch_seq_targets + ): + for container, inputs, targets in zip( + seq_containers, seq_inputs, seq_targets + ): + predictions = base_model.forward_with_container(inputs, container) + loss = criterion(predictions, targets) + losses.append(loss) + final_loss = torch.stack(losses).mean() + final_loss.backward() + optimizer.step() + loss_meter.update(final_loss.item()) + return loss_meter def main(args): logger, env_info, model_kwargs = lfna_setup(args) - dynamic_env = env_info["dynamic_env"] - model = get_model(**model_kwargs) - model = model.to(args.device) + dynamic_env = get_synthetic_env(mode="train", version=args.env_version) + base_model = get_model(**model_kwargs) + base_model = base_model.to(args.device) criterion = torch.nn.MSELoss() - logger.log("There are {:} weights.".format(model.get_w_container().numel())) - # meta_train_range = (dynamic_env.min_timestamp, (dynamic_env.min_timestamp + dynamic_env.max_timestamp) / 2) - # meta_train_interval = dynamic_env.timestamp_interval - - shape_container = model.get_w_container().to_shape_container() + shape_container = base_model.get_w_container().to_shape_container() # pre-train the hypernetwork - timestamps = list( - dynamic_env.get_timestamp(index) for index in range(len(dynamic_env) // 2) + timestamps = dynamic_env.get_timestamp(None) + meta_model = LFNA_Meta(shape_container, args.layer_dim, args.time_dim, timestamps) + meta_model = meta_model.to(args.device) + + logger.log("The base-model has {:} weights.".format(base_model.numel())) + logger.log("The meta-model has {:} weights.".format(meta_model.numel())) + + batch_sampler = EnvSampler(dynamic_env, args.meta_batch, args.sampler_enlarge) + dynamic_env.reset_max_seq_length(args.seq_length) + """ + env_loader = torch.utils.data.DataLoader( + dynamic_env, + batch_size=args.meta_batch, + shuffle=True, + num_workers=args.workers, + pin_memory=True, + ) + """ + env_loader = torch.utils.data.DataLoader( + dynamic_env, + batch_sampler=batch_sampler, + num_workers=args.workers, + pin_memory=True, ) - hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim, timestamps) - hypernet = hypernet.to(args.device) - - import pdb - - pdb.set_trace() - - # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) - total_bar = 16 - task_embeds = [] - for i in range(total_bar): - tensor = torch.Tensor(1, args.task_dim).to(args.device) - task_embeds.append(torch.nn.Parameter(tensor)) - for task_embed in task_embeds: - trunc_normal_(task_embed, std=0.02) - - model.train() - hypernet.train() - - parameters = list(hypernet.parameters()) + task_embeds - # optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True) - optimizer = torch.optim.Adam(parameters, lr=args.init_lr, weight_decay=1e-5) + optimizer = torch.optim.Adam( + meta_model.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True + ) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=[ @@ -77,71 +103,59 @@ def main(args): ], gamma=0.1, ) + logger.log("The base-model is\n{:}".format(base_model)) + logger.log("The meta-model is\n{:}".format(meta_model)) + logger.log("The optimizer is\n{:}".format(optimizer)) + logger.log("Per epoch iterations = {:}".format(len(env_loader))) - # total_bar = env_info["total"] - 1 # LFNA meta-training - loss_meter = AverageMeter() per_epoch_time, start_time = AverageMeter(), time.time() + last_success_epoch = 0 for iepoch in range(args.epochs): - need_time = "Time Left: {:}".format( + head_str = "[{:}] [{:04d}/{:04d}] ".format( + time_string(), iepoch, args.epochs + ) + "Time Left: {:}".format( convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) ) - head_str = ( - "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) - + need_time + + loss_meter = epoch_train( + env_loader, + meta_model, + base_model, + optimizer, + criterion, + args.device, + logger, ) - - losses = [] - # for ibatch in range(args.meta_batch): - for cur_time in range(total_bar): - # cur_time = random.randint(0, total_bar) - cur_task_embed = task_embeds[cur_time] - cur_container = hypernet(cur_task_embed) - cur_x = env_info["{:}-x".format(cur_time)].to(args.device) - cur_y = env_info["{:}-y".format(cur_time)].to(args.device) - cur_dataset = TimeData(cur_time, cur_x, cur_y) - - preds = model.forward_with_container(cur_dataset.x, cur_container) - optimizer.zero_grad() - loss = criterion(preds, cur_dataset.y) - - losses.append(loss) - - final_loss = torch.stack(losses).mean() - final_loss.backward() - optimizer.step() lr_scheduler.step() - - loss_meter.update(final_loss.item()) - if iepoch % 100 == 0: - logger.log( - head_str - + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format( - loss_meter.avg, - loss_meter.val, - min(lr_scheduler.get_last_lr()), - len(losses), - ) - ) - + logger.log( + head_str + + " meta-loss: {meter.avg:.4f} ({meter.count:.0f})".format(meter=loss_meter) + + " :: lr={:.5f}".format(min(lr_scheduler.get_last_lr())) + ) + success, best_score = meta_model.save_best(-loss_meter.avg) + if success: + logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) + last_success_epoch = iepoch save_checkpoint( { - "hypernet": hypernet.state_dict(), - "task_embed": task_embed, + "meta_model": meta_model.state_dict(), + "optimizer": optimizer.state_dict(), "lr_scheduler": lr_scheduler.state_dict(), "iepoch": iepoch, + "args": args, }, logger.path("model"), logger, ) - loss_meter.reset() + if iepoch - last_success_epoch >= args.early_stop_thresh: + logger.log("Early stop at {:}".format(iepoch)) + break + per_epoch_time.update(time.time() - start_time) start_time = time.time() - print(model) - print(hypernet) - w_container_per_epoch = dict() for idx in range(0, total_bar): future_time = env_info["{:}-timestamp".format(idx)] @@ -183,20 +197,26 @@ if __name__ == "__main__": parser.add_argument( "--hidden_dim", type=int, - required=True, + default=16, help="The hidden dimension.", ) parser.add_argument( "--layer_dim", type=int, - required=True, - help="The hidden dimension.", + default=16, + help="The layer chunk dimension.", + ) + parser.add_argument( + "--time_dim", + type=int, + default=16, + help="The timestamp dimension.", ) ##### parser.add_argument( "--init_lr", type=float, - default=0.1, + default=0.01, help="The initial learning rate for the optimizer (default is Adam)", ) parser.add_argument( @@ -206,10 +226,23 @@ if __name__ == "__main__": help="The batch size for the meta-model", ) parser.add_argument( - "--epochs", + "--sampler_enlarge", type=int, - default=2000, - help="The total number of epochs.", + default=5, + help="Enlarge the #iterations for an epoch", + ) + parser.add_argument("--epochs", type=int, default=1000, help="The total #epochs.") + parser.add_argument( + "--early_stop_thresh", + type=int, + default=50, + help="The maximum epochs for early stop.", + ) + parser.add_argument( + "--seq_length", type=int, default=5, help="The sequence length." + ) + parser.add_argument( + "--workers", type=int, default=4, help="The number of workers in parallel." ) parser.add_argument( "--device", @@ -223,8 +256,7 @@ if __name__ == "__main__": if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) assert args.save_dir is not None, "The save dir argument can not be None" - args.task_dim = args.layer_dim - args.save_dir = "{:}-{:}-d{:}".format( - args.save_dir, args.env_version, args.hidden_dim + args.save_dir = "{:}-{:}-d{:}_{:}_{:}".format( + args.save_dir, args.env_version, args.hidden_dim, args.layer_dim, args.time_dim ) main(args) diff --git a/exps/LFNA/lfna_meta_model.py b/exps/LFNA/lfna_meta_model.py new file mode 100644 index 0000000..588c375 --- /dev/null +++ b/exps/LFNA/lfna_meta_model.py @@ -0,0 +1,128 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # +##################################################### +import copy +import torch + +import torch.nn.functional as F + +from xlayers import super_core +from xlayers import trunc_normal_ +from models.xcore import get_model + + +class LFNA_Meta(super_core.SuperModule): + """Learning to Forecast Neural Adaptation (Meta Model Design).""" + + def __init__( + self, + shape_container, + layer_embeding, + time_embedding, + meta_timestamps, + mha_depth: int = 2, + dropout: float = 0.1, + ): + super(LFNA_Meta, self).__init__() + self._shape_container = shape_container + self._num_layers = len(shape_container) + self._numel_per_layer = [] + for ilayer in range(self._num_layers): + self._numel_per_layer.append(shape_container[ilayer].numel()) + self._raw_meta_timestamps = meta_timestamps + + self.register_parameter( + "_super_layer_embed", + torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)), + ) + self.register_parameter( + "_super_meta_embed", + torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)), + ) + self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) + + # build transformer + layers = [] + for ilayer in range(mha_depth): + layers.append( + super_core.SuperTransformerEncoderLayer( + time_embedding, + 4, + True, + 4, + dropout, + norm_affine=False, + order=super_core.LayerOrder.PostNorm, + ) + ) + self.meta_corrector = super_core.SuperSequential(*layers) + + model_kwargs = dict( + config=dict(model_type="dual_norm_mlp"), + input_dim=layer_embeding + time_embedding, + output_dim=max(self._numel_per_layer), + hidden_dims=[(layer_embeding + time_embedding) * 2] * 3, + act_cls="gelu", + norm_cls="layer_norm_1d", + dropout=dropout, + ) + self._generator = get_model(**model_kwargs) + # print("generator: {:}".format(self._generator)) + + # unknown token + self.register_parameter( + "_unknown_token", + torch.nn.Parameter(torch.Tensor(1, time_embedding)), + ) + + # initialization + trunc_normal_( + [self._super_layer_embed, self._super_meta_embed, self._unknown_token], + std=0.02, + ) + + def forward_raw(self, timestamps): + # timestamps is a batch of sequence of timestamps + batch, seq = timestamps.shape + timestamps = timestamps.unsqueeze(dim=-1) + meta_timestamps = self._meta_timestamps.view(1, 1, -1) + time_diffs = timestamps - meta_timestamps + time_match_v, time_match_i = torch.min(torch.abs(time_diffs), dim=-1) + # select corresponding meta-knowledge + meta_match = torch.index_select( + self._super_meta_embed, dim=0, index=time_match_i.view(-1) + ) + meta_match = meta_match.view(batch, seq, -1) + # create the probability + time_probs = (1 / torch.exp(time_match_v * 10)).view(batch, seq, 1) + time_probs[:, -1, :] = 0 + unknown_token = self._unknown_token.view(1, 1, -1) + raw_meta_embed = time_probs * meta_match + (1 - time_probs) * unknown_token + + meta_embed = self.meta_corrector(raw_meta_embed) + # create joint embed + num_layer, _ = self._super_layer_embed.shape + meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1) + layer_embed = self._super_layer_embed.view(1, 1, num_layer, -1).expand( + batch, seq, -1, -1 + ) + joint_embed = torch.cat((meta_embed, layer_embed), dim=-1) + batch_weights = self._generator(joint_embed) + batch_containers = [] + for seq_weights in torch.split(batch_weights, 1): + seq_containers = [] + for weights in torch.split(seq_weights.squeeze(0), 1): + weights = torch.split(weights.squeeze(0), 1) + seq_containers.append(self._shape_container.translate(weights)) + batch_containers.append(seq_containers) + return batch_containers + + def forward_candidate(self, input): + raise NotImplementedError + + def extra_repr(self) -> str: + return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format( + list(self._super_layer_embed.shape), + list(self._super_meta_embed.shape), + list(self._meta_timestamps.shape), + ) diff --git a/exps/LFNA/lfna_models_v2.py b/exps/LFNA/lfna_models_v2.py deleted file mode 100644 index ad1f91f..0000000 --- a/exps/LFNA/lfna_models_v2.py +++ /dev/null @@ -1,72 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # -##################################################### -import copy -import torch - -import torch.nn.functional as F - -from xlayers import super_core -from xlayers import trunc_normal_ -from models.xcore import get_model - - -class HyperNet(super_core.SuperModule): - """The hyper-network.""" - - def __init__( - self, - shape_container, - layer_embeding, - task_embedding, - meta_timestamps, - return_container: bool = True, - ): - super(HyperNet, self).__init__() - self._shape_container = shape_container - self._num_layers = len(shape_container) - self._numel_per_layer = [] - for ilayer in range(self._num_layers): - self._numel_per_layer.append(shape_container[ilayer].numel()) - - self.register_parameter( - "_super_layer_embed", - torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)), - ) - trunc_normal_(self._super_layer_embed, std=0.02) - - model_kwargs = dict( - config=dict(model_type="dual_norm_mlp"), - input_dim=layer_embeding + task_embedding, - output_dim=max(self._numel_per_layer), - hidden_dims=[(layer_embeding + task_embedding) * 2] * 3, - act_cls="gelu", - norm_cls="layer_norm_1d", - dropout=0.2, - ) - import pdb - - pdb.set_trace() - self._generator = get_model(**model_kwargs) - self._return_container = return_container - print("generator: {:}".format(self._generator)) - - def forward_raw(self, task_embed): - # task_embed = F.normalize(task_embed, dim=-1, p=2) - # layer_embed = F.normalize(self._super_layer_embed, dim=-1, p=2) - layer_embed = self._super_layer_embed - task_embed = task_embed.view(1, -1).expand(self._num_layers, -1) - - joint_embed = torch.cat((task_embed, layer_embed), dim=-1) - weights = self._generator(joint_embed) - if self._return_container: - weights = torch.split(weights, 1) - return self._shape_container.translate(weights) - else: - return weights - - def forward_candidate(self, input): - raise NotImplementedError - - def extra_repr(self) -> str: - return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape)) diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 649b890..96c1ba2 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -225,8 +225,8 @@ def visualize_env(save_dir, version): def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): save_dir = Path(str(save_dir)) for substr in ("pdf", "png"): - sub_save_dir = save_dir / substr - sub_save_dir.mkdir(parents=True, exist_ok=True) + sub_save_dir = save_dir / substr + sub_save_dir.mkdir(parents=True, exist_ok=True) dpi, width, height = 30, 3200, 2000 figsize = width / float(dpi), height / float(dpi) diff --git a/lib/datasets/synthetic_core.py b/lib/datasets/synthetic_core.py index 0fb7238..5f2bfee 100644 --- a/lib/datasets/synthetic_core.py +++ b/lib/datasets/synthetic_core.py @@ -2,6 +2,7 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # ##################################################### from .synthetic_utils import TimeStamp +from .synthetic_env import EnvSampler from .synthetic_env import SyntheticDEnv from .math_core import LinearFunc from .math_core import DynamicLinearFunc diff --git a/lib/datasets/synthetic_env.py b/lib/datasets/synthetic_env.py index 506e1f2..9cef5be 100644 --- a/lib/datasets/synthetic_env.py +++ b/lib/datasets/synthetic_env.py @@ -2,7 +2,7 @@ # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # ##################################################### import math -import abc +import random import numpy as np from typing import List, Optional, Dict import torch @@ -11,6 +11,28 @@ import torch.utils.data as data from .synthetic_utils import TimeStamp +def is_list_tuple(x): + return isinstance(x, (tuple, list)) + + +def zip_sequence(sequence): + def _combine(*alist): + if is_list_tuple(alist[0]): + return [_combine(*xlist) for xlist in zip(*alist)] + else: + return torch.cat(alist, dim=0) + + def unsqueeze(a): + if is_list_tuple(a): + return [unsqueeze(x) for x in a] + else: + return a.unsqueeze(dim=0) + + with torch.no_grad(): + sequence = [unsqueeze(a) for a in sequence] + return _combine(*sequence) + + class SyntheticDEnv(data.Dataset): """The synethtic dynamic environment.""" @@ -33,7 +55,7 @@ class SyntheticDEnv(data.Dataset): self._num_per_task = num_per_task if timestamp_config is None: timestamp_config = dict(mode=mode) - else: + elif "mode" not in timestamp_config: timestamp_config["mode"] = mode self._timestamp_generator = TimeStamp(**timestamp_config) @@ -42,6 +64,7 @@ class SyntheticDEnv(data.Dataset): self._cov_functors = cov_functors self._oracle_map = None + self._seq_length = None @property def min_timestamp(self): @@ -55,9 +78,18 @@ class SyntheticDEnv(data.Dataset): def timestamp_interval(self): return self._timestamp_generator.interval + def reset_max_seq_length(self, seq_length): + self._seq_length = seq_length + def get_timestamp(self, index): - index, timestamp = self._timestamp_generator[index] - return timestamp + if index is None: + timestamps = [] + for index in range(len(self._timestamp_generator)): + timestamps.append(self._timestamp_generator[index][1]) + return tuple(timestamps) + else: + index, timestamp = self._timestamp_generator[index] + return timestamp def set_oracle_map(self, functor): self._oracle_map = functor @@ -75,7 +107,14 @@ class SyntheticDEnv(data.Dataset): def __getitem__(self, index): assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) index, timestamp = self._timestamp_generator[index] - return self.__call__(timestamp) + if self._seq_length is None: + return self.__call__(timestamp) + else: + timestamps = [ + timestamp + i * self.timestamp_interval for i in range(self._seq_length) + ] + xdata = [self.__call__(timestamp) for timestamp in timestamps] + return zip_sequence(xdata) def __call__(self, timestamp): mean_list = [functor(timestamp) for functor in self._mean_functors] @@ -88,10 +127,13 @@ class SyntheticDEnv(data.Dataset): mean_list, cov_matrix, size=self._num_per_task ) if self._oracle_map is None: - return timestamp, torch.Tensor(dataset) + return torch.Tensor([timestamp]), torch.Tensor(dataset) else: targets = self._oracle_map.noise_call(dataset, timestamp) - return timestamp, (torch.Tensor(dataset), torch.Tensor(targets)) + return torch.Tensor([timestamp]), ( + torch.Tensor(dataset), + torch.Tensor(targets), + ) def __len__(self): return len(self._timestamp_generator) @@ -104,3 +146,20 @@ class SyntheticDEnv(data.Dataset): ndim=self._ndim, num_per_task=self._num_per_task, ) + + +class EnvSampler: + def __init__(self, env, batch, enlarge): + indexes = list(range(len(env))) + self._indexes = indexes * enlarge + self._batch = batch + self._iterations = len(self._indexes) // self._batch + + def __iter__(self): + random.shuffle(self._indexes) + for it in range(self._iterations): + indexes = self._indexes[it * self._batch : (it + 1) * self._batch] + yield indexes + + def __len__(self): + return self._iterations diff --git a/lib/datasets/synthetic_utils.py b/lib/datasets/synthetic_utils.py index 93e7b2b..a738fca 100644 --- a/lib/datasets/synthetic_utils.py +++ b/lib/datasets/synthetic_utils.py @@ -30,6 +30,7 @@ class UnifiedSplit: self._indexes = all_indexes[num_of_train + num_of_valid :] else: raise ValueError("Unkonwn mode of {:}".format(mode)) + self._all_indexes = all_indexes self._mode = mode @property diff --git a/lib/xlayers/super_rl_actor.py b/lib/xlayers/super_rl_actor.py deleted file mode 100644 index 5725fed..0000000 --- a/lib/xlayers/super_rl_actor.py +++ /dev/null @@ -1,120 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # -##################################################### -# DISABLED / NOT-FINISHED -##################################################### -import torch -import torch.nn as nn -import torch.nn.functional as F - -import math -from typing import Optional, Callable - -import spaces -from .super_container import SuperSequential -from .super_linear import SuperLinear - - -class SuperActor(SuperModule): - """A Actor in RL.""" - - def _distribution(self, obs): - raise NotImplementedError - - def _log_prob_from_distribution(self, pi, act): - raise NotImplementedError - - def forward_candidate(self, **kwargs): - return self.forward_raw(**kwargs) - - def forward_raw(self, obs, act=None): - # Produce action distributions for given observations, and - # optionally compute the log likelihood of given actions under - # those distributions. - pi = self._distribution(obs) - logp_a = None - if act is not None: - logp_a = self._log_prob_from_distribution(pi, act) - return pi, logp_a - - -class SuperLfnaMetaMLP(SuperModule): - def __init__(self, obs_dim, hidden_sizes, act_cls): - super(SuperLfnaMetaMLP).__init__() - self.delta_net = SuperSequential( - SuperLinear(obs_dim, hidden_sizes[0]), - act_cls(), - SuperLinear(hidden_sizes[0], hidden_sizes[1]), - act_cls(), - SuperLinear(hidden_sizes[1], 1), - ) - - -class SuperLfnaMetaMLP(SuperModule): - def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls): - super(SuperLfnaMetaMLP).__init__() - log_std = -0.5 * np.ones(act_dim, dtype=np.float32) - self.log_std = torch.nn.Parameter(torch.as_tensor(log_std)) - self.mu_net = SuperSequential( - SuperLinear(obs_dim, hidden_sizes[0]), - act_cls(), - SuperLinear(hidden_sizes[0], hidden_sizes[1]), - act_cls(), - SuperLinear(hidden_sizes[1], act_dim), - ) - - def _distribution(self, obs): - mu = self.mu_net(obs) - std = torch.exp(self.log_std) - return Normal(mu, std) - - def _log_prob_from_distribution(self, pi, act): - return pi.log_prob(act).sum(axis=-1) - - def forward_candidate(self, **kwargs): - return self.forward_raw(**kwargs) - - def forward_raw(self, obs, act=None): - # Produce action distributions for given observations, and - # optionally compute the log likelihood of given actions under - # those distributions. - pi = self._distribution(obs) - logp_a = None - if act is not None: - logp_a = self._log_prob_from_distribution(pi, act) - return pi, logp_a - - -class SuperMLPGaussianActor(SuperModule): - def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls): - super(SuperMLPGaussianActor).__init__() - log_std = -0.5 * np.ones(act_dim, dtype=np.float32) - self.log_std = torch.nn.Parameter(torch.as_tensor(log_std)) - self.mu_net = SuperSequential( - SuperLinear(obs_dim, hidden_sizes[0]), - act_cls(), - SuperLinear(hidden_sizes[0], hidden_sizes[1]), - act_cls(), - SuperLinear(hidden_sizes[1], act_dim), - ) - - def _distribution(self, obs): - mu = self.mu_net(obs) - std = torch.exp(self.log_std) - return Normal(mu, std) - - def _log_prob_from_distribution(self, pi, act): - return pi.log_prob(act).sum(axis=-1) - - def forward_candidate(self, **kwargs): - return self.forward_raw(**kwargs) - - def forward_raw(self, obs, act=None): - # Produce action distributions for given observations, and - # optionally compute the log likelihood of given actions under - # those distributions. - pi = self._distribution(obs) - logp_a = None - if act is not None: - logp_a = self._log_prob_from_distribution(pi, act) - return pi, logp_a diff --git a/lib/xlayers/super_transformer.py b/lib/xlayers/super_transformer.py index f21ac54..dcec793 100644 --- a/lib/xlayers/super_transformer.py +++ b/lib/xlayers/super_transformer.py @@ -42,6 +42,7 @@ class SuperTransformerEncoderLayer(SuperModule): qkv_bias: BoolSpaceType = False, mlp_hidden_multiplier: IntSpaceType = 4, drop: Optional[float] = None, + norm_affine: bool = True, act_layer: Callable[[], nn.Module] = nn.GELU, order: LayerOrder = LayerOrder.PreNorm, ): @@ -62,19 +63,19 @@ class SuperTransformerEncoderLayer(SuperModule): drop=drop, ) if order is LayerOrder.PreNorm: - self.norm1 = SuperLayerNorm1D(d_model) + self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine) self.mha = mha self.drop1 = nn.Dropout(drop or 0.0) - self.norm2 = SuperLayerNorm1D(d_model) + self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine) self.mlp = mlp self.drop2 = nn.Dropout(drop or 0.0) elif order is LayerOrder.PostNorm: self.mha = mha self.drop1 = nn.Dropout(drop or 0.0) - self.norm1 = SuperLayerNorm1D(d_model) + self.norm1 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine) self.mlp = mlp self.drop2 = nn.Dropout(drop or 0.0) - self.norm2 = SuperLayerNorm1D(d_model) + self.norm2 = SuperLayerNorm1D(d_model, elementwise_affine=norm_affine) else: raise ValueError("Unknown order: {:}".format(order)) self._order = order diff --git a/lib/xlayers/weight_init.py b/lib/xlayers/weight_init.py index 478a462..b9bb504 100644 --- a/lib/xlayers/weight_init.py +++ b/lib/xlayers/weight_init.py @@ -60,4 +60,7 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) + if isinstance(tensor, list): + return [_no_grad_trunc_normal_(x, mean, std, a, b) for x in tensor] + else: + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/tests/test_synthetic_env.py b/tests/test_synthetic_env.py index 8cac2fb..4abaa65 100644 --- a/tests/test_synthetic_env.py +++ b/tests/test_synthetic_env.py @@ -23,8 +23,16 @@ class TestSynethicEnv(unittest.TestCase): def test_simple(self): mean_generator = ComposedSinFunc(constant=0.1) std_generator = ConstantFunc(constant=0.5) - dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) print(dataset) for timestamp, tau in dataset: - assert tau.shape == (5000, 1) + self.assertEqual(tau.shape, (5000, 1)) + + def test_length(self): + mean_generator = ComposedSinFunc(constant=0.1) + std_generator = ConstantFunc(constant=0.5) + dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) + self.assertEqual(len(dataset), 100) + + dataset = SyntheticDEnv([mean_generator], [[std_generator]], mode="train") + self.assertEqual(len(dataset), 60)