diff --git a/exps/LFNA/backup/lfna-fix-init.py b/exps/LFNA/backup/lfna-fix-init.py
deleted file mode 100644
index c3e8e7b..0000000
--- a/exps/LFNA/backup/lfna-fix-init.py
+++ /dev/null
@@ -1,239 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
-#####################################################
-# python exps/LFNA/lfna-fix-init.py --env_version v1 --hidden_dim 16
-#####################################################
-import sys, time, copy, torch, random, argparse
-from tqdm import tqdm
-from copy import deepcopy
-from pathlib import Path
-
-lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
-if str(lib_dir) not in sys.path:
-    sys.path.insert(0, str(lib_dir))
-from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
-from log_utils import time_string
-from log_utils import AverageMeter, convert_secs2time
-
-from utils import split_str2indexes
-
-from procedures.advanced_main import basic_train_fn, basic_eval_fn
-from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
-from datasets.synthetic_core import get_synthetic_env
-from models.xcore import get_model
-from xlayers import super_core
-
-
-from lfna_utils import lfna_setup, train_model, TimeData
-
-
-class LFNAmlp:
-    """A LFNA meta-model that uses the MLP as delta-net."""
-
-    def __init__(self, obs_dim, hidden_sizes, act_name, criterion):
-        self.delta_net = super_core.SuperSequential(
-            super_core.SuperLinear(obs_dim, hidden_sizes[0]),
-            super_core.super_name2activation[act_name](),
-            super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]),
-            super_core.super_name2activation[act_name](),
-            super_core.SuperLinear(hidden_sizes[1], 1),
-        )
-        self.meta_optimizer = torch.optim.Adam(
-            self.delta_net.parameters(), lr=0.001, amsgrad=True
-        )
-        self.criterion = criterion
-
-    def adapt(self, model, seq_datasets):
-        delta_inputs = []
-        container = model.get_w_container()
-        for iseq, dataset in enumerate(seq_datasets):
-            y_hat = model.forward_with_container(dataset.x, container)
-            loss = self.criterion(y_hat, dataset.y)
-            gradients = torch.autograd.grad(loss, container.parameters())
-            with torch.no_grad():
-                flatten_g = container.flatten(gradients)
-                delta_inputs.append(flatten_g)
-        flatten_w = container.no_grad_clone().flatten()
-        delta_inputs.append(flatten_w)
-        delta_inputs = torch.stack(delta_inputs, dim=-1)
-        delta = self.delta_net(delta_inputs)
-
-        delta = torch.clamp(delta, -0.8, 0.8)
-        unflatten_delta = container.unflatten(delta)
-        future_container = container.no_grad_clone().additive(unflatten_delta)
-        return future_container
-
-    def step(self):
-        torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0)
-        self.meta_optimizer.step()
-
-    def zero_grad(self):
-        self.meta_optimizer.zero_grad()
-        self.delta_net.zero_grad()
-
-    def state_dict(self):
-        return dict(
-            delta_net=self.delta_net.state_dict(),
-            meta_optimizer=self.meta_optimizer.state_dict(),
-        )
-
-
-def main(args):
-    logger, env_info, model_kwargs = lfna_setup(args)
-    dynamic_env = env_info["dynamic_env"]
-    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
-
-    total_time = env_info["total"]
-    for i in range(total_time):
-        for xkey in ("timestamp", "x", "y"):
-            nkey = "{:}-{:}".format(i, xkey)
-            assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
-    train_time_bar = total_time // 2
-    network = get_model(dict(model_type="simple_mlp"), **model_kwargs)
-
-    criterion = torch.nn.MSELoss()
-    logger.log("There are {:} weights.".format(network.get_w_container().numel()))
-
-    adaptor = LFNAmlp(1 + args.meta_seq, (20, 20), "leaky_relu", criterion)
-
-    # pre-train the model
-    init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"])
-    init_loss = train_model(network, init_dataset, args.init_lr, args.epochs)
-    logger.log("The pre-training loss is {:.4f}".format(init_loss))
-
-    # LFNA meta-training
-    meta_loss_meter = AverageMeter()
-    per_epoch_time, start_time = AverageMeter(), time.time()
-    for iepoch in range(args.epochs):
-
-        need_time = "Time Left: {:}".format(
-            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
-        )
-        logger.log(
-            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
-            + need_time
-        )
-
-        adaptor.zero_grad()
-
-        batch_indexes, meta_losses = [], []
-        for ibatch in range(args.meta_batch):
-            sampled_timestamp = random.random() * train_time_bar
-            batch_indexes.append("{:.3f}".format(sampled_timestamp))
-            seq_datasets = []
-            for iseq in range(args.meta_seq + 1):
-                cur_time = sampled_timestamp + iseq * dynamic_env.timestamp_interval
-                cur_time, (x, y) = dynamic_env(cur_time)
-                seq_datasets.append(TimeData(cur_time, x, y))
-            history_datasets, future_dataset = seq_datasets[:-1], seq_datasets[-1]
-            future_container = adaptor.adapt(network, history_datasets)
-            future_y_hat = network.forward_with_container(
-                future_dataset.x, future_container
-            )
-            future_loss = adaptor.criterion(future_y_hat, future_dataset.y)
-            meta_losses.append(future_loss)
-        meta_loss = torch.stack(meta_losses).mean()
-        meta_loss.backward()
-        adaptor.step()
-
-        meta_loss_meter.update(meta_loss.item())
-
-        logger.log(
-            "meta-loss: {:.4f} ({:.4f}) batch: {:}".format(
-                meta_loss_meter.avg, meta_loss_meter.val, ",".join(batch_indexes[:5])
-            )
-        )
-        if iepoch % 200 == 0:
-            save_checkpoint(
-                {"adaptor": adaptor.state_dict(), "iepoch": iepoch},
-                logger.path("model"),
-                logger,
-            )
-        per_epoch_time.update(time.time() - start_time)
-        start_time = time.time()
-
-    w_container_per_epoch = dict()
-    for idx in range(1, env_info["total"]):
-        future_time = env_info["{:}-timestamp".format(idx)]
-        future_x = env_info["{:}-x".format(idx)]
-        future_y = env_info["{:}-y".format(idx)]
-        seq_datasets = []
-        for iseq in range(1, args.meta_seq + 1):
-            cur_time = future_time - iseq * dynamic_env.timestamp_interval
-            cur_time, (x, y) = dynamic_env(cur_time)
-            seq_datasets.append(TimeData(cur_time, x, y))
-        seq_datasets.reverse()
-        future_container = adaptor.adapt(network, seq_datasets)
-        w_container_per_epoch[idx] = future_container.no_grad_clone()
-        with torch.no_grad():
-            future_y_hat = network.forward_with_container(
-                future_x, w_container_per_epoch[idx]
-            )
-            future_loss = adaptor.criterion(future_y_hat, future_y)
-        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
-
-    save_checkpoint(
-        {"w_container_per_epoch": w_container_per_epoch},
-        logger.path(None) / "final-ckp.pth",
-        logger,
-    )
-
-    logger.log("-" * 200 + "\n")
-    logger.close()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser("Use the data in the past.")
-    parser.add_argument(
-        "--save_dir",
-        type=str,
-        default="./outputs/lfna-synthetic/lfna-fix-init",
-        help="The checkpoint directory.",
-    )
-    parser.add_argument(
-        "--env_version",
-        type=str,
-        required=True,
-        help="The synthetic enviornment version.",
-    )
-    parser.add_argument(
-        "--hidden_dim",
-        type=int,
-        required=True,
-        help="The hidden dimension.",
-    )
-    #####
-    parser.add_argument(
-        "--init_lr",
-        type=float,
-        default=0.1,
-        help="The initial learning rate for the optimizer (default is Adam)",
help="The initial learning rate for the optimizer (default is Adam)", - ) - parser.add_argument( - "--meta_batch", - type=int, - default=32, - help="The batch size for the meta-model", - ) - parser.add_argument( - "--meta_seq", - type=int, - default=10, - help="The length of the sequence for meta-model.", - ) - parser.add_argument( - "--epochs", - type=int, - default=1000, - help="The total number of epochs.", - ) - # Random Seed - parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") - args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: - args.rand_seed = random.randint(1, 100000) - assert args.save_dir is not None, "The save dir argument can not be None" - args.save_dir = "{:}-{:}-d{:}".format( - args.save_dir, args.env_version, args.hidden_dim - ) - main(args) diff --git a/exps/LFNA/backup/lfna-test-hpnet.py b/exps/LFNA/backup/lfna-test-hpnet.py deleted file mode 100644 index 334c8ac..0000000 --- a/exps/LFNA/backup/lfna-test-hpnet.py +++ /dev/null @@ -1,239 +0,0 @@ -##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # -##################################################### -# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 -# python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 16 --epochs 10000 --init_lr 0.01 --device cuda -##################################################### -import sys, time, copy, torch, random, argparse -from tqdm import tqdm -from copy import deepcopy -from pathlib import Path - -lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint -from log_utils import time_string -from log_utils import AverageMeter, convert_secs2time - -from utils import split_str2indexes - -from procedures.advanced_main import basic_train_fn, basic_eval_fn -from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric -from datasets.synthetic_core import get_synthetic_env -from models.xcore import get_model -from xlayers import super_core, trunc_normal_ - - -from lfna_utils import lfna_setup, train_model, TimeData - -# from lfna_models import HyperNet_VX as HyperNet -from lfna_models import HyperNet - - -def main(args): - logger, env_info, model_kwargs = lfna_setup(args) - dynamic_env = env_info["dynamic_env"] - model = get_model(**model_kwargs) - model = model.to(args.device) - criterion = torch.nn.MSELoss() - - shape_container = model.get_w_container().to_shape_container() - total_bar = 100 - hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim, total_bar) - hypernet = hypernet.to(args.device) - - logger.log( - "{:} There are {:} weights in the base-model.".format( - time_string(), model.numel() - ) - ) - logger.log( - "{:} There are {:} weights in the meta-model.".format( - time_string(), hypernet.numel() - ) - ) - for i in range(total_bar): - env_info["{:}-x".format(i)] = env_info["{:}-x".format(i)].to(args.device) - env_info["{:}-y".format(i)] = env_info["{:}-y".format(i)].to(args.device) - - model.train() - hypernet.train() - - optimizer = torch.optim.Adam( - hypernet.parameters(), lr=args.init_lr, weight_decay=1e-5, amsgrad=True - ) - lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, - milestones=[ - int(args.epochs * 0.8), - int(args.epochs * 0.9), - ], - gamma=0.1, - ) - - # 
total_bar = env_info["total"] - 1 - # LFNA meta-training - loss_meter = AverageMeter() - per_epoch_time, start_time = AverageMeter(), time.time() - last_success = 0 - for iepoch in range(args.epochs): - - need_time = "Time Left: {:}".format( - convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) - ) - head_str = ( - "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) - + need_time - ) - - losses = [] - # for ibatch in range(args.meta_batch): - for cur_time in range(total_bar): - # cur_time = random.randint(0, total_bar) - # cur_task_embed = task_embeds[cur_time] - cur_container = hypernet(cur_time) - cur_x = env_info["{:}-x".format(cur_time)] - cur_y = env_info["{:}-y".format(cur_time)] - cur_dataset = TimeData(cur_time, cur_x, cur_y) - - preds = model.forward_with_container(cur_dataset.x, cur_container) - optimizer.zero_grad() - loss = criterion(preds, cur_dataset.y) - - losses.append(loss) - - final_loss = torch.stack(losses).mean() - final_loss.backward() - optimizer.step() - lr_scheduler.step() - - loss_meter.update(final_loss.item()) - success, best_score = hypernet.save_best(-loss_meter.val) - if success: - logger.log("Achieve the best with best_score = {:.3f}".format(best_score)) - last_success = iepoch - if iepoch - last_success >= args.early_stop_thresh: - logger.log("Early stop at {:}".format(iepoch)) - break - if iepoch % 20 == 0: - logger.log( - head_str - + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format( - loss_meter.avg, - loss_meter.val, - min(lr_scheduler.get_last_lr()), - len(losses), - ) - ) - - save_checkpoint( - { - "hypernet": hypernet.state_dict(), - "lr_scheduler": lr_scheduler.state_dict(), - "iepoch": iepoch, - }, - logger.path("model"), - logger, - ) - loss_meter.reset() - per_epoch_time.update(time.time() - start_time) - start_time = time.time() - - print(model) - print(hypernet) - hypernet.load_best() - - w_container_per_epoch = dict() - for idx in range(0, total_bar): - future_time = env_info["{:}-timestamp".format(idx)] - future_x = env_info["{:}-x".format(idx)] - future_y = env_info["{:}-y".format(idx)] - # future_container = hypernet(task_embeds[idx]) - future_container = hypernet(idx) - w_container_per_epoch[idx] = future_container.no_grad_clone() - with torch.no_grad(): - future_y_hat = model.forward_with_container( - future_x, w_container_per_epoch[idx] - ) - future_loss = criterion(future_y_hat, future_y) - logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())) - - save_checkpoint( - {"w_container_per_epoch": w_container_per_epoch}, - logger.path(None) / "final-ckp.pth", - logger, - ) - - logger.log("-" * 200 + "\n") - logger.close() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser("Use the data in the past.") - parser.add_argument( - "--save_dir", - type=str, - default="./outputs/lfna-synthetic/lfna-test-hpnet", - help="The checkpoint directory.", - ) - parser.add_argument( - "--env_version", - type=str, - required=True, - help="The synthetic enviornment version.", - ) - parser.add_argument( - "--hidden_dim", - type=int, - required=True, - help="The hidden dimension.", - ) - parser.add_argument( - "--layer_dim", - type=int, - required=True, - help="The hidden dimension.", - ) - parser.add_argument( - "--early_stop_thresh", - type=int, - default=100, - help="The maximum epochs for early stop.", - ) - ##### - parser.add_argument( - "--init_lr", - type=float, - default=0.1, - help="The initial learning rate for the optimizer (default is Adam)", - ) - 
-        "--meta_batch",
-        type=int,
-        default=64,
-        help="The batch size for the meta-model",
-    )
-    parser.add_argument(
-        "--epochs",
-        type=int,
-        default=2000,
-        help="The total number of epochs.",
-    )
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cpu",
-        help="",
-    )
-    # Random Seed
-    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
-    args = parser.parse_args()
-    if args.rand_seed is None or args.rand_seed < 0:
-        args.rand_seed = random.randint(1, 100000)
-    assert args.save_dir is not None, "The save dir argument can not be None"
-    args.task_dim = args.layer_dim
-    args.save_dir = "{:}-{:}-d{:}".format(
-        args.save_dir, args.env_version, args.hidden_dim
-    )
-    main(args)
diff --git a/exps/LFNA/backup/lfna-ttss-hpnet.py b/exps/LFNA/backup/lfna-ttss-hpnet.py
deleted file mode 100644
index a3e85a7..0000000
--- a/exps/LFNA/backup/lfna-ttss-hpnet.py
+++ /dev/null
@@ -1,134 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
-#####################################################
-# python exps/LFNA/lfna-ttss-hpnet.py --env_version v1 --hidden_dim 16
-#####################################################
-import sys, time, copy, torch, random, argparse
-from tqdm import tqdm
-from copy import deepcopy
-from pathlib import Path
-
-lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
-if str(lib_dir) not in sys.path:
-    sys.path.insert(0, str(lib_dir))
-from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
-from log_utils import time_string
-from log_utils import AverageMeter, convert_secs2time
-
-from utils import split_str2indexes
-
-from procedures.advanced_main import basic_train_fn, basic_eval_fn
-from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
-from datasets.synthetic_core import get_synthetic_env
-from models.xcore import get_model
-from xlayers import super_core
-
-
-from lfna_utils import lfna_setup, train_model, TimeData
-from lfna_models import HyperNet_VX as HyperNet
-
-
-def main(args):
-    logger, env_info, model_kwargs = lfna_setup(args)
-    dynamic_env = env_info["dynamic_env"]
-    model = get_model(**model_kwargs)
-
-    total_time = env_info["total"]
-    for i in range(total_time):
-        for xkey in ("timestamp", "x", "y"):
-            nkey = "{:}-{:}".format(i, xkey)
-            assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
-    train_time_bar = total_time // 2
-
-    criterion = torch.nn.MSELoss()
-    logger.log("There are {:} weights.".format(model.get_w_container().numel()))
-
-    # pre-train the model
-    dataset = init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"])
-
-    shape_container = model.get_w_container().to_shape_container()
-    hypernet = HyperNet(shape_container, 16)
-    print(hypernet)
-
-    optimizer = torch.optim.Adam(hypernet.parameters(), lr=args.init_lr, amsgrad=True)
-
-    best_loss, best_param = None, None
-    for _iepoch in range(args.epochs):
-        container = hypernet(None)
-
-        preds = model.forward_with_container(dataset.x, container)
-        optimizer.zero_grad()
-        loss = criterion(preds, dataset.y)
-        loss.backward()
-        optimizer.step()
-        # save best
-        if best_loss is None or best_loss > loss.item():
-            best_loss = loss.item()
-            best_param = copy.deepcopy(model.state_dict())
-        print("hyper-net : best={:.4f}".format(best_loss))
-
-    init_loss = train_model(model, init_dataset, args.init_lr, args.epochs)
-    logger.log("The pre-training loss is {:.4f}".format(init_loss))
-
-    print(model)
-    print(hypernet)
-
-    logger.log("-" * 200 + "\n")
-    logger.close()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser("Use the data in the past.")
-    parser.add_argument(
-        "--save_dir",
-        type=str,
-        default="./outputs/lfna-synthetic/lfna-debug",
-        help="The checkpoint directory.",
-    )
-    parser.add_argument(
-        "--env_version",
-        type=str,
-        required=True,
-        help="The synthetic enviornment version.",
-    )
-    parser.add_argument(
-        "--hidden_dim",
-        type=int,
-        required=True,
-        help="The hidden dimension.",
-    )
-    #####
-    parser.add_argument(
-        "--init_lr",
-        type=float,
-        default=0.1,
-        help="The initial learning rate for the optimizer (default is Adam)",
-    )
-    parser.add_argument(
-        "--meta_batch",
-        type=int,
-        default=32,
-        help="The batch size for the meta-model",
-    )
-    parser.add_argument(
-        "--meta_seq",
-        type=int,
-        default=10,
-        help="The length of the sequence for meta-model.",
-    )
-    parser.add_argument(
-        "--epochs",
-        type=int,
-        default=2000,
-        help="The total number of epochs.",
-    )
-    # Random Seed
-    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
-    args = parser.parse_args()
-    if args.rand_seed is None or args.rand_seed < 0:
-        args.rand_seed = random.randint(1, 100000)
-    assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-{:}-d{:}".format(
-        args.save_dir, args.env_version, args.hidden_dim
-    )
-    main(args)
diff --git a/exps/LFNA/backup/lfna-v1.py b/exps/LFNA/backup/lfna-v1.py
deleted file mode 100644
index 60dd62a..0000000
--- a/exps/LFNA/backup/lfna-v1.py
+++ /dev/null
@@ -1,272 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
-#####################################################
-# python exps/LFNA/lfna-v1.py
-#####################################################
-import sys, time, copy, torch, random, argparse
-from tqdm import tqdm
-from copy import deepcopy
-from pathlib import Path
-
-lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
/ "lib").resolve() -if str(lib_dir) not in sys.path: - sys.path.insert(0, str(lib_dir)) -from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint -from log_utils import time_string -from log_utils import AverageMeter, convert_secs2time - -from utils import split_str2indexes - -from procedures.advanced_main import basic_train_fn, basic_eval_fn -from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric -from datasets.synthetic_core import get_synthetic_env -from models.xcore import get_model -from xlayers import super_core - - -class LFNAmlp: - """A LFNA meta-model that uses the MLP as delta-net.""" - - def __init__(self, obs_dim, hidden_sizes, act_name): - self.delta_net = super_core.SuperSequential( - super_core.SuperLinear(obs_dim, hidden_sizes[0]), - super_core.super_name2activation[act_name](), - super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]), - super_core.super_name2activation[act_name](), - super_core.SuperLinear(hidden_sizes[1], 1), - ) - self.meta_optimizer = torch.optim.Adam( - self.delta_net.parameters(), lr=0.01, amsgrad=True - ) - - def adapt(self, model, criterion, w_container, seq_datasets): - w_container.requires_grad_(True) - containers = [w_container] - for idx, dataset in enumerate(seq_datasets): - x, y = dataset.x, dataset.y - y_hat = model.forward_with_container(x, containers[-1]) - loss = criterion(y_hat, y) - gradients = torch.autograd.grad(loss, containers[-1].tensors) - with torch.no_grad(): - flatten_w = containers[-1].flatten().view(-1, 1) - flatten_g = containers[-1].flatten(gradients).view(-1, 1) - input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2) - input_statistics = input_statistics.expand(flatten_w.numel(), -1) - delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) - delta = self.delta_net(delta_inputs).view(-1) - delta = torch.clamp(delta, -0.5, 0.5) - unflatten_delta = containers[-1].unflatten(delta) - future_container = containers[-1].no_grad_clone().additive(unflatten_delta) - # future_container = containers[-1].additive(unflatten_delta) - containers.append(future_container) - # containers = containers[1:] - meta_loss = [] - temp_containers = [] - for idx, dataset in enumerate(seq_datasets): - if idx == 0: - continue - current_container = containers[idx] - y_hat = model.forward_with_container(dataset.x, current_container) - loss = criterion(y_hat, dataset.y) - meta_loss.append(loss) - temp_containers.append((dataset.timestamp, current_container, -loss.item())) - meta_loss = sum(meta_loss) - w_container.requires_grad_(False) - # meta_loss.backward() - # self.meta_optimizer.step() - return meta_loss, temp_containers - - def step(self): - torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) - self.meta_optimizer.step() - - def zero_grad(self): - self.meta_optimizer.zero_grad() - self.delta_net.zero_grad() - - -class TimeData: - def __init__(self, timestamp, xs, ys): - self._timestamp = timestamp - self._xs = xs - self._ys = ys - - @property - def x(self): - return self._xs - - @property - def y(self): - return self._ys - - @property - def timestamp(self): - return self._timestamp - - -class Population: - """A population used to maintain models at different timestamps.""" - - def __init__(self): - self._time2model = dict() - self._time2score = dict() # higher is better - - def append(self, timestamp, model, score): - if timestamp in self._time2model: - if self._time2score[timestamp] > score: - return - self._time2model[timestamp] = model.no_grad_clone() - 
-
-    def query(self, timestamp):
-        closet_timestamp = None
-        for xtime, model in self._time2model.items():
-            if closet_timestamp is None or (
-                xtime < timestamp and timestamp - closet_timestamp >= timestamp - xtime
-            ):
-                closet_timestamp = xtime
-        return self._time2model[closet_timestamp], closet_timestamp
-
-    def debug_info(self, timestamps):
-        xstrs = []
-        for timestamp in timestamps:
-            if timestamp in self._time2score:
-                xstrs.append(
-                    "{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp])
-                )
-        return ", ".join(xstrs)
-
-
-def main(args):
-    prepare_seed(args.rand_seed)
-    logger = prepare_logger(args)
-
-    cache_path = (logger.path(None) / ".." / "env-info.pth").resolve()
-    if cache_path.exists():
-        env_info = torch.load(cache_path)
-    else:
-        env_info = dict()
-        dynamic_env = get_synthetic_env()
-        env_info["total"] = len(dynamic_env)
-        for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
-            env_info["{:}-timestamp".format(idx)] = timestamp
-            env_info["{:}-x".format(idx)] = _allx
-            env_info["{:}-y".format(idx)] = _ally
-        env_info["dynamic_env"] = dynamic_env
-        torch.save(env_info, cache_path)
-
-    total_time = env_info["total"]
-    for i in range(total_time):
-        for xkey in ("timestamp", "x", "y"):
-            nkey = "{:}-{:}".format(i, xkey)
-            assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
-    train_time_bar = total_time // 2
-    base_model = get_model(
-        dict(model_type="simple_mlp"),
-        act_cls="leaky_relu",
-        norm_cls="identity",
-        input_dim=1,
-        output_dim=1,
-    )
-
-    w_container = base_model.get_w_container()
-    criterion = torch.nn.MSELoss()
-    print("There are {:} weights.".format(w_container.numel()))
-
-    adaptor = LFNAmlp(4, (50, 20), "leaky_relu")
-
-    pool = Population()
-    pool.append(0, w_container, -100)
-
-    # LFNA meta-training
-    per_epoch_time, start_time = AverageMeter(), time.time()
-    for iepoch in range(args.epochs):
-
-        need_time = "Time Left: {:}".format(
-            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
-        )
-        logger.log(
-            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
-            + need_time
-        )
-
-        adaptor.zero_grad()
-
-        debug_timestamp = set()
-        all_meta_losses = []
-        for ibatch in range(args.meta_batch):
-            sampled_timestamp = random.randint(0, train_time_bar)
-            query_w_container, query_timestamp = pool.query(sampled_timestamp)
-            # def adapt(self, model, w_container, xs, ys):
-            seq_datasets = []
-            # xs, ys = [], []
-            for it in range(sampled_timestamp, sampled_timestamp + args.max_seq):
-                xs = env_info["{:}-x".format(it)]
-                ys = env_info["{:}-y".format(it)]
-                seq_datasets.append(TimeData(it, xs, ys))
-            temp_meta_loss, temp_containers = adaptor.adapt(
-                base_model, criterion, query_w_container, seq_datasets
-            )
-            all_meta_losses.append(temp_meta_loss)
-            for temp_time, temp_container, temp_score in temp_containers:
-                pool.append(temp_time, temp_container, temp_score)
-                debug_timestamp.add(temp_time)
-        meta_loss = torch.stack(all_meta_losses).mean()
-        meta_loss.backward()
-        adaptor.step()
-
-        debug_str = pool.debug_info(debug_timestamp)
-        logger.log("meta-loss: {:.4f}".format(meta_loss.item()))
-
-        per_epoch_time.update(time.time() - start_time)
-        start_time = time.time()
-
-    logger.log("-" * 200 + "\n")
-    logger.close()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser("Use the data in the past.")
-    parser.add_argument(
-        "--save_dir",
-        type=str,
-        default="./outputs/lfna-synthetic/lfna-v1",
-        help="The checkpoint directory.",
-    )
-    parser.add_argument(
-        "--init_lr",
-        type=float,
-        default=0.1,
-        help="The initial learning rate for the optimizer (default is Adam)",
-    )
-    parser.add_argument(
-        "--meta_batch",
-        type=int,
-        default=5,
-        help="The batch size for the meta-model",
-    )
-    parser.add_argument(
-        "--epochs",
-        type=int,
-        default=1000,
-        help="The total number of epochs.",
-    )
-    parser.add_argument(
-        "--max_seq",
-        type=int,
-        default=5,
-        help="The maximum length of the sequence.",
-    )
-    parser.add_argument(
-        "--workers",
-        type=int,
-        default=4,
-        help="The number of data loading workers (default: 4)",
-    )
-    # Random Seed
-    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
-    args = parser.parse_args()
-    if args.rand_seed is None or args.rand_seed < 0:
-        args.rand_seed = random.randint(1, 100000)
-    assert args.save_dir is not None, "The save dir argument can not be None"
-    main(args)
diff --git a/exps/LFNA/backup/lfna_models.py b/exps/LFNA/backup/lfna_models.py
deleted file mode 100644
index 2e163c7..0000000
--- a/exps/LFNA/backup/lfna_models.py
+++ /dev/null
@@ -1,50 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
-#####################################################
-import copy
-import torch
-
-from xlayers import super_core
-from xlayers import trunc_normal_
-from models.xcore import get_model
-
-
-class HyperNet(super_core.SuperModule):
-    def __init__(self, shape_container, input_embeding, return_container=True):
-        super(HyperNet, self).__init__()
-        self._shape_container = shape_container
-        self._num_layers = len(shape_container)
-        self._numel_per_layer = []
-        for ilayer in range(self._num_layers):
-            self._numel_per_layer.append(shape_container[ilayer].numel())
-
-        self.register_parameter(
-            "_super_layer_embed",
-            torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)),
-        )
-        trunc_normal_(self._super_layer_embed, std=0.02)
-
-        model_kwargs = dict(
-            input_dim=input_embeding,
-            output_dim=max(self._numel_per_layer),
-            hidden_dim=input_embeding * 4,
-            act_cls="sigmoid",
-            norm_cls="identity",
-        )
-        self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs)
-        self._return_container = return_container
-        print("generator: {:}".format(self._generator))
-
-    def forward_raw(self, input):
-        weights = self._generator(self._super_layer_embed)
-        if self._return_container:
-            weights = torch.split(weights, 1)
-            return self._shape_container.translate(weights)
-        else:
-            return weights
-
-    def forward_candidate(self, input):
-        raise NotImplementedError
-
-    def extra_repr(self) -> str:
-        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
diff --git a/exps/LFNA/basic-maml.py b/exps/LFNA/basic-maml.py
index b86adff..b3fcce3 100644
--- a/exps/LFNA/basic-maml.py
+++ b/exps/LFNA/basic-maml.py
@@ -1,7 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/basic-maml.py --env_version v1 --hidden_dim 16 --inner_step 5
+# python exps/LFNA/basic-maml.py --env_version v1 --inner_step 5
 # python exps/LFNA/basic-maml.py --env_version v2
 #####################################################
 import sys, time, copy, torch, random, argparse
@@ -20,7 +20,7 @@ from utils import split_str2indexes
 
 from procedures.advanced_main import basic_train_fn, basic_eval_fn
 from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
-from datasets.synthetic_core import get_synthetic_env
+from datasets.synthetic_core import get_synthetic_env, EnvSampler
 from models.xcore import get_model
 from xlayers import super_core
 
@@ -42,11 +42,10 @@ class MAML:
         self.meta_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
             self.meta_optimizer,
             milestones=[
-                int(epochs * 0.25),
-                int(epochs * 0.5),
-                int(epochs * 0.75),
+                int(epochs * 0.8),
+                int(epochs * 0.9),
             ],
-            gamma=0.3,
+            gamma=0.1,
         )
         self.inner_lr = inner_lr
         self.inner_step = inner_step
@@ -85,33 +84,27 @@ class MAML:
         self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])
         self.meta_lr_scheduler.load_state_dict(state_dict["meta_lr_scheduler"])
 
-    def save_best(self, iepoch, score):
-        if self._best_info["score"] is None or self._best_info["score"] < score:
-            state_dict = dict(
-                criterion=self.criterion.state_dict(),
-                network=self.network.state_dict(),
-                meta_optimizer=self.meta_optimizer.state_dict(),
-                meta_lr_scheduler=self.meta_lr_scheduler.state_dict(),
-            )
-            self._best_info["state_dict"] = state_dict
-            self._best_info["score"] = score
-            self._best_info["iepoch"] = iepoch
-            is_best = True
-        else:
-            is_best = False
-        return self._best_info, is_best
+    def state_dict(self):
+        state_dict = dict()
+        state_dict["criterion"] = self.criterion.state_dict()
+        state_dict["network"] = self.network.state_dict()
+        state_dict["meta_optimizer"] = self.meta_optimizer.state_dict()
+        state_dict["meta_lr_scheduler"] = self.meta_lr_scheduler.state_dict()
+        return state_dict
+
+    def save_best(self, score):
+        success, best_score = self.network.save_best(score)
+        return success, best_score
+
+    def load_best(self):
+        self.network.load_best()
 
 
 def main(args):
     logger, env_info, model_kwargs = lfna_setup(args)
-    model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+    model = get_model(**model_kwargs)
 
-    total_time = env_info["total"]
-    for i in range(total_time):
-        for xkey in ("timestamp", "x", "y"):
-            nkey = "{:}-{:}".format(i, xkey)
-            assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
-    train_time_bar = total_time // 2
+    dynamic_env = get_synthetic_env(mode="train", version=args.env_version)
 
     criterion = torch.nn.MSELoss()
 
@@ -120,83 +113,65 @@ def main(args):
     )
     # meta-training
+    last_success_epoch = 0
     per_epoch_time, start_time = AverageMeter(), time.time()
-    # for iepoch in range(args.epochs):
-    iepoch = 0
-    while iepoch < args.epochs:
+    for iepoch in range(args.epochs):
 
         need_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
-        logger.log(
+        head_str = (
             "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
             + need_time
         )
 
         maml.zero_grad()
 
-        batch_indexes, meta_losses = [], []
+        meta_losses = []
         for ibatch in range(args.meta_batch):
-            sampled_timestamp = random.randint(0, train_time_bar)
-            batch_indexes.append("{:5d}".format(sampled_timestamp))
-            past_dataset = TimeData(
-                sampled_timestamp,
-                env_info["{:}-x".format(sampled_timestamp)],
-                env_info["{:}-y".format(sampled_timestamp)],
+            future_timestamp = dynamic_env.random_timestamp()
+            _, (future_x, future_y) = dynamic_env(future_timestamp)
+            past_timestamp = (
+                future_timestamp - args.prev_time * dynamic_env.timestamp_interval
             )
-            future_dataset = TimeData(
-                sampled_timestamp + 1,
-                env_info["{:}-x".format(sampled_timestamp + 1)],
-                env_info["{:}-y".format(sampled_timestamp + 1)],
-            )
-            future_container = maml.adapt(past_dataset)
-            future_y_hat = maml.predict(future_dataset.x, future_container)
-            future_loss = maml.criterion(future_y_hat, future_dataset.y)
+            _, (past_x, past_y) = dynamic_env(past_timestamp)
+
+            future_container = maml.adapt(TimeData(past_timestamp, past_x, past_y))
+            future_y_hat = maml.predict(future_x, future_container)
+            future_loss = maml.criterion(future_y_hat, future_y)
             meta_losses.append(future_loss)
         meta_loss = torch.stack(meta_losses).mean()
         meta_loss.backward()
         maml.step()
-        logger.log(
-            "meta-loss: {:.4f} batch: {:}".format(
-                meta_loss.item(), ",".join(batch_indexes)
-            )
-        )
-        best_info, is_best = maml.save_best(iepoch, -meta_loss.item())
-        if is_best:
-            save_checkpoint(best_info, logger.path("best"), logger)
-            logger.log("Save the best into {:}".format(logger.path("best")))
-        if iepoch >= 10 and (
-            torch.isnan(meta_loss).item() or meta_loss.item() >= args.fail_thresh
-        ):
-            xdata = torch.load(logger.path("best"))
-            maml.load_state_dict(xdata["state_dict"])
-            iepoch = xdata["iepoch"]
-            logger.log(
-                "The training failed, re-use the previous best epoch [{:}]".format(
-                    iepoch
-                )
-            )
-        else:
-            iepoch = iepoch + 1
+        logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
+        success, best_score = maml.save_best(-meta_loss.item())
+        if success:
+            logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
+            save_checkpoint(maml.state_dict(), logger.path("model"), logger)
+            last_success_epoch = iepoch
+        if iepoch - last_success_epoch >= args.early_stop_thresh:
+            logger.log("Early stop at {:}".format(iepoch))
+            break
+
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
 
+    # meta-test
+    maml.load_best()
+    eval_env = env_info["dynamic_env"]
+    assert eval_env.timestamp_interval == dynamic_env.timestamp_interval
     w_container_per_epoch = dict()
-    for idx in range(1, env_info["total"]):
-        past_dataset = TimeData(
-            idx - 1,
-            env_info["{:}-x".format(idx - 1)],
-            env_info["{:}-y".format(idx - 1)],
+    for idx in range(args.prev_time, len(eval_env)):
+        future_timestamp, (future_x, future_y) = eval_env[idx]
+        past_timestamp = (
+            future_timestamp.item() - args.prev_time * eval_env.timestamp_interval
        )
-        current_container = maml.adapt(past_dataset)
-        w_container_per_epoch[idx] = current_container.no_grad_clone()
+        _, (past_x, past_y) = eval_env(past_timestamp)
+        future_container = maml.adapt(TimeData(past_timestamp, past_x, past_y))
+        w_container_per_epoch[idx] = future_container.no_grad_clone()
         with torch.no_grad():
-            current_x = env_info["{:}-x".format(idx)]
-            current_y = env_info["{:}-y".format(idx)]
-            current_y_hat = maml.predict(current_x, w_container_per_epoch[idx])
-            current_loss = maml.criterion(current_y_hat, current_y)
-        logger.log(
-            "meta-test: [{:03d}] -> loss={:.4f}".format(idx, current_loss.item())
-        )
+            future_y_hat = maml.predict(future_x, w_container_per_epoch[idx])
+            future_loss = maml.criterion(future_y_hat, future_y)
+        logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item()))
     save_checkpoint(
         {"w_container_per_epoch": w_container_per_epoch},
         logger.path(None) / "final-ckp.pth",
@@ -224,13 +199,13 @@ if __name__ == "__main__":
     parser.add_argument(
         "--hidden_dim",
         type=int,
-        required=True,
+        default=16,
         help="The hidden dimension.",
     )
     parser.add_argument(
         "--meta_lr",
         type=float,
-        default=0.05,
+        default=0.01,
         help="The learning rate for the MAML optimizer (default is Adam)",
     )
     parser.add_argument(
@@ -242,24 +217,36 @@ if __name__ == "__main__":
     parser.add_argument(
         "--inner_lr",
         type=float,
-        default=0.01,
+        default=0.005,
         help="The learning rate for the inner optimization",
     )
     parser.add_argument(
         "--inner_step", type=int, default=1, help="The inner loop steps for MAML."
     )
+    parser.add_argument(
+        "--prev_time",
+        type=int,
+        default=5,
+        help="The gap between prev_time and current_timestamp",
+    )
     parser.add_argument(
         "--meta_batch",
         type=int,
-        default=10,
+        default=64,
         help="The batch size for the meta-model",
     )
     parser.add_argument(
         "--epochs",
         type=int,
-        default=1000,
+        default=2000,
         help="The total number of epochs.",
     )
+    parser.add_argument(
+        "--early_stop_thresh",
+        type=int,
+        default=50,
+        help="The maximum epochs for early stop.",
+    )
     parser.add_argument(
         "--workers",
         type=int,
@@ -272,7 +259,13 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-s{:}-{:}-d{:}".format(
-        args.save_dir, args.inner_step, args.env_version, args.hidden_dim
+    args.save_dir = "{:}-s{:}-mlr{:}-d{:}-prev{:}-e{:}-env{:}".format(
+        args.save_dir,
+        args.inner_step,
+        args.meta_lr,
+        args.hidden_dim,
+        args.prev_time,
+        args.epochs,
+        args.env_version,
     )
     main(args)
diff --git a/exps/LFNA/basic-prev.py b/exps/LFNA/basic-prev.py
index a1dc1c5..96756c0 100644
--- a/exps/LFNA/basic-prev.py
+++ b/exps/LFNA/basic-prev.py
@@ -1,7 +1,7 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
 #####################################################
-# python exps/LFNA/basic-prev.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1
+# python exps/LFNA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1
 # python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
 #####################################################
 import sys, time, copy, torch, random, argparse
@@ -41,7 +41,7 @@ def main(args):
 
     w_container_per_epoch = dict()
     per_timestamp_time, start_time = AverageMeter(), time.time()
-    for idx in range(1, env_info["total"]):
+    for idx in range(args.prev_time, env_info["total"]):
 
         need_time = "Time Left: {:}".format(
             convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
@@ -53,8 +53,8 @@ def main(args):
             + need_time
         )
         # train the same data
-        historical_x = env_info["{:}-x".format(idx - 1)]
-        historical_y = env_info["{:}-y".format(idx - 1)]
+        historical_x = env_info["{:}-x".format(idx - args.prev_time)]
+        historical_y = env_info["{:}-y".format(idx - args.prev_time)]
         # build model
         model = get_model(**model_kwargs)
         print(model)
@@ -160,6 +160,12 @@ if __name__ == "__main__":
         default=0.1,
         help="The initial learning rate for the optimizer (default is Adam)",
     )
+    parser.add_argument(
+        "--prev_time",
+        type=int,
+        default=5,
+        help="The gap between prev_time and current_timestamp",
+    )
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -184,7 +190,12 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-{:}-d{:}".format(
-        args.save_dir, args.env_version, args.hidden_dim
+    args.save_dir = "{:}-d{:}_e{:}_lr{:}-prev{:}-env{:}".format(
+        args.save_dir,
+        args.hidden_dim,
+        args.epochs,
+        args.init_lr,
+        args.prev_time,
+        args.env_version,
     )
     main(args)
diff --git a/exps/LFNA/basic-same.py b/exps/LFNA/basic-same.py
index d7dc9b2..3f53528 100644
--- a/exps/LFNA/basic-same.py
+++ b/exps/LFNA/basic-same.py
@@ -41,7 +41,7 @@ def main(args):
 
     w_container_per_epoch = dict()
     per_timestamp_time, start_time = AverageMeter(), time.time()
-    for idx in range(env_info["total"]):
+    for idx in range(1, env_info["total"]):
 
         need_time = "Time Left: {:}".format(
             convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
@@ -184,7 +184,7 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-{:}-d{:}".format(
-        args.save_dir, args.env_version, args.hidden_dim
+    args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format(
+        args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version
     )
     main(args)
diff --git a/exps/LFNA/lfna.py b/exps/LFNA/lfna.py
index 422af1a..6d08ed3 100644
--- a/exps/LFNA/lfna.py
+++ b/exps/LFNA/lfna.py
@@ -157,11 +157,11 @@ def main(args):
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
 
-    # meta-training
+    # meta-test
     meta_model.load_best()
     eval_env = env_info["dynamic_env"]
     w_container_per_epoch = dict()
-    for idx in range(args.seq_length, env_info["total"]):
+    for idx in range(args.seq_length, len(eval_env)):
         # build-timestamp
         future_time = env_info["{:}-timestamp".format(idx)]
         time_seqs = []
@@ -176,8 +176,8 @@ def main(args):
         future_container = seq_containers[-1]
         w_container_per_epoch[idx] = future_container.no_grad_clone()
         # evaluation
-        future_x = env_info["{:}-x".format(idx)]
-        future_y = env_info["{:}-y".format(idx)]
+        future_x = env_info["{:}-x".format(idx)].to(args.device)
+        future_y = env_info["{:}-y".format(idx)].to(args.device)
         future_y_hat = base_model.forward_with_container(
             future_x, w_container_per_epoch[idx]
         )
@@ -299,12 +299,12 @@ if __name__ == "__main__":
     if args.rand_seed is None or args.rand_seed < 0:
         args.rand_seed = random.randint(1, 100000)
     assert args.save_dir is not None, "The save dir argument can not be None"
-    args.save_dir = "{:}-{:}-d{:}_{:}_{:}-e{:}".format(
+    args.save_dir = "{:}-d{:}_{:}_{:}-e{:}-env{:}".format(
         args.save_dir,
-        args.env_version,
         args.hidden_dim,
         args.layer_dim,
         args.time_dim,
         args.epochs,
+        args.env_version,
     )
     main(args)
diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py
index 96c1ba2..027776e 100644
--- a/exps/LFNA/vis-synthetic.py
+++ b/exps/LFNA/vis-synthetic.py
@@ -237,18 +237,20 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
     env_info = torch.load(cache_path)
     alg_name2dir = OrderedDict()
-    alg_name2dir["Optimal"] = "use-same-timestamp"
     # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
     # alg_name2dir["MAML"] = "use-maml-s1"
     # alg_name2dir["LFNA (fix init)"] = "lfna-fix-init"
-    alg_name2dir["LFNA (debug)"] = "lfna-tall-hpnet"
-    alg_name2all_containers = OrderedDict()
     if version == "v1":
-        poststr = "v1-d16"
+        # alg_name2dir["Optimal"] = "use-same-timestamp"
+        alg_name2dir["LFNA"] = "lfna-battle-v1-d16_16_16-e200"
+        alg_name2dir[
+            "Previous Timestamp"
+        ] = "use-prev-timestamp-d16_e500_lr0.1-prev5-envv1"
     else:
         raise ValueError("Invalid version: {:}".format(version))
 
+    alg_name2all_containers = OrderedDict()
     for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
-        ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth"
+        ckp_path = Path(alg_dir) / str(xdir) / "final-ckp.pth"
         xdata = torch.load(ckp_path, map_location="cpu")
         alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
     # load the basic model
@@ -267,11 +269,11 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
     dynamic_env = env_info["dynamic_env"]
     min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
 
-    linewidths = 10
+    linewidths, skip = 10, 5
     for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
         tqdm(dynamic_env, ncols=50)
     ):
-        if idx == 0:
+        if idx <= skip:
             continue
         fig = plt.figure(figsize=figsize)
         cur_ax = fig.add_subplot(2, 1, 1)
@@ -335,9 +337,9 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
         cur_ax.set_ylim(0, 10)
         cur_ax.legend(loc=1, fontsize=LegendFontsize)
 
-        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
+        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx - skip)
         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
-        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
+        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx - skip)
         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
         plt.close("all")
     save_dir = save_dir.resolve()
diff --git a/lib/datasets/synthetic_env.py b/lib/datasets/synthetic_env.py
index d4fcf24..e8e519c 100644
--- a/lib/datasets/synthetic_env.py
+++ b/lib/datasets/synthetic_env.py
@@ -80,6 +80,12 @@ class SyntheticDEnv(data.Dataset):
     def timestamp_interval(self):
         return self._timestamp_generator.interval
 
+    def random_timestamp(self):
+        return (
+            random.random() * (self.max_timestamp - self.min_timestamp)
+            + self.min_timestamp
+        )
+
     def reset_max_seq_length(self, seq_length):
         self._seq_length = seq_length
 
diff --git a/lib/datasets/synthetic_utils.py b/lib/datasets/synthetic_utils.py
index a738fca..14d32a0 100644
--- a/lib/datasets/synthetic_utils.py
+++ b/lib/datasets/synthetic_utils.py
@@ -56,11 +56,11 @@ class TimeStamp(UnifiedSplit, data.Dataset):
 
     @property
     def min_timestamp(self):
-        return self._min_timestamp
+        return self._min_timestamp + self._interval * min(self._indexes)
 
     @property
     def max_timestamp(self):
-        return self._max_timestamp
+        return self._min_timestamp + self._interval * max(self._indexes)
 
     @property
     def interval(self):