#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy

import torch
from tqdm import tqdm

from procedures import prepare_seed, prepare_logger
from datasets.synthetic_core import get_synthetic_env


def lfna_setup(args):
    """Prepare the random seed, the logger, and the (cached) synthetic-env info."""
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)

    cache_path = (
        logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version)
    ).resolve()
    if cache_path.exists():
        env_info = torch.load(cache_path)
    else:
        # Materialize every timestamp of the dynamic environment once and
        # cache the result, so that later runs can skip this loop.
        env_info = dict()
        dynamic_env = get_synthetic_env(version=args.env_version)
        env_info["total"] = len(dynamic_env)
        for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
            env_info["{:}-timestamp".format(idx)] = timestamp
            env_info["{:}-x".format(idx)] = _allx
            env_info["{:}-y".format(idx)] = _ally
        env_info["dynamic_env"] = dynamic_env
        torch.save(env_info, cache_path)

    model_kwargs = dict(
        input_dim=1,
        output_dim=1,
        hidden_dim=args.hidden_dim,
        act_cls="leaky_relu",
        norm_cls="identity",
    )
    return logger, env_info, model_kwargs


def train_model(model, dataset, lr, epochs):
    """Fit `model` on `dataset` with Adam + MSE; restore and return the best loss."""
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
    best_loss, best_param = None, None
    for _iepoch in range(epochs):
        preds = model(dataset.x)
        optimizer.zero_grad()
        loss = criterion(preds, dataset.y)
        loss.backward()
        optimizer.step()
        # Track the best checkpoint seen so far.
        if best_loss is None or best_loss > loss.item():
            best_loss = loss.item()
            best_param = copy.deepcopy(model.state_dict())
    # Roll the model back to its best-performing parameters.
    model.load_state_dict(best_param)
    return best_loss


class TimeData:
    """A single timestamped batch of (x, y) samples from the dynamic environment."""

    def __init__(self, timestamp, xs, ys):
        self._timestamp = timestamp
        self._xs = xs
        self._ys = ys

    @property
    def x(self):
        return self._xs

    @property
    def y(self):
        return self._ys

    @property
    def timestamp(self):
        return self._timestamp

    def __repr__(self):
        return "{name}(timestamp={timestamp}, with {num} samples)".format(
            name=self.__class__.__name__, timestamp=self._timestamp, num=len(self._xs)
        )
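

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original module):
# wraps synthetic sine data in TimeData and fits a tiny stand-in MLP with
# train_model. The Sequential model below is an assumption for demonstration
# only; the real experiments build their model elsewhere from the
# model_kwargs returned by lfna_setup.
if __name__ == "__main__":
    # 128 one-dimensional inputs in [-1, 1] and their sine targets.
    _xs = torch.linspace(-1, 1, steps=128).unsqueeze(dim=1)
    _ys = torch.sin(3.1416 * _xs)
    demo_data = TimeData(timestamp=0.0, xs=_xs, ys=_ys)
    # A minimal MLP standing in for the project's model (hypothetical choice).
    demo_model = torch.nn.Sequential(
        torch.nn.Linear(1, 16), torch.nn.LeakyReLU(), torch.nn.Linear(16, 1)
    )
    final_loss = train_model(demo_model, demo_data, lr=0.01, epochs=200)
    print(demo_data)  # -> TimeData(timestamp=0.0, with 128 samples)
    print("best training loss: {:.6f}".format(final_loss))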