Update LFNA version 1.0
This commit is contained in:
		| @@ -2,6 +2,7 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import EnvSampler | ||||
| from .synthetic_env import SyntheticDEnv | ||||
| from .math_core import LinearFunc | ||||
| from .math_core import DynamicLinearFunc | ||||
|   | ||||
| @@ -2,7 +2,7 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import random | ||||
| import numpy as np | ||||
| from typing import List, Optional, Dict | ||||
| import torch | ||||
| @@ -11,6 +11,28 @@ import torch.utils.data as data | ||||
| from .synthetic_utils import TimeStamp | ||||
|  | ||||
|  | ||||
| def is_list_tuple(x): | ||||
|     return isinstance(x, (tuple, list)) | ||||
|  | ||||
|  | ||||
| def zip_sequence(sequence): | ||||
|     def _combine(*alist): | ||||
|         if is_list_tuple(alist[0]): | ||||
|             return [_combine(*xlist) for xlist in zip(*alist)] | ||||
|         else: | ||||
|             return torch.cat(alist, dim=0) | ||||
|  | ||||
|     def unsqueeze(a): | ||||
|         if is_list_tuple(a): | ||||
|             return [unsqueeze(x) for x in a] | ||||
|         else: | ||||
|             return a.unsqueeze(dim=0) | ||||
|  | ||||
|     with torch.no_grad(): | ||||
|         sequence = [unsqueeze(a) for a in sequence] | ||||
|         return _combine(*sequence) | ||||
|  | ||||
|  | ||||
| class SyntheticDEnv(data.Dataset): | ||||
|     """The synethtic dynamic environment.""" | ||||
|  | ||||
| @@ -33,7 +55,7 @@ class SyntheticDEnv(data.Dataset): | ||||
|         self._num_per_task = num_per_task | ||||
|         if timestamp_config is None: | ||||
|             timestamp_config = dict(mode=mode) | ||||
|         else: | ||||
|         elif "mode" not in timestamp_config: | ||||
|             timestamp_config["mode"] = mode | ||||
|  | ||||
|         self._timestamp_generator = TimeStamp(**timestamp_config) | ||||
| @@ -42,6 +64,7 @@ class SyntheticDEnv(data.Dataset): | ||||
|         self._cov_functors = cov_functors | ||||
|  | ||||
|         self._oracle_map = None | ||||
|         self._seq_length = None | ||||
|  | ||||
|     @property | ||||
|     def min_timestamp(self): | ||||
| @@ -55,9 +78,18 @@ class SyntheticDEnv(data.Dataset): | ||||
|     def timestamp_interval(self): | ||||
|         return self._timestamp_generator.interval | ||||
|  | ||||
|     def reset_max_seq_length(self, seq_length): | ||||
|         self._seq_length = seq_length | ||||
|  | ||||
|     def get_timestamp(self, index): | ||||
|         index, timestamp = self._timestamp_generator[index] | ||||
|         return timestamp | ||||
|         if index is None: | ||||
|             timestamps = [] | ||||
|             for index in range(len(self._timestamp_generator)): | ||||
|                 timestamps.append(self._timestamp_generator[index][1]) | ||||
|             return tuple(timestamps) | ||||
|         else: | ||||
|             index, timestamp = self._timestamp_generator[index] | ||||
|             return timestamp | ||||
|  | ||||
|     def set_oracle_map(self, functor): | ||||
|         self._oracle_map = functor | ||||
| @@ -75,7 +107,14 @@ class SyntheticDEnv(data.Dataset): | ||||
|     def __getitem__(self, index): | ||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||
|         index, timestamp = self._timestamp_generator[index] | ||||
|         return self.__call__(timestamp) | ||||
|         if self._seq_length is None: | ||||
|             return self.__call__(timestamp) | ||||
|         else: | ||||
|             timestamps = [ | ||||
|                 timestamp + i * self.timestamp_interval for i in range(self._seq_length) | ||||
|             ] | ||||
|             xdata = [self.__call__(timestamp) for timestamp in timestamps] | ||||
|             return zip_sequence(xdata) | ||||
|  | ||||
|     def __call__(self, timestamp): | ||||
|         mean_list = [functor(timestamp) for functor in self._mean_functors] | ||||
| @@ -88,10 +127,13 @@ class SyntheticDEnv(data.Dataset): | ||||
|             mean_list, cov_matrix, size=self._num_per_task | ||||
|         ) | ||||
|         if self._oracle_map is None: | ||||
|             return timestamp, torch.Tensor(dataset) | ||||
|             return torch.Tensor([timestamp]), torch.Tensor(dataset) | ||||
|         else: | ||||
|             targets = self._oracle_map.noise_call(dataset, timestamp) | ||||
|             return timestamp, (torch.Tensor(dataset), torch.Tensor(targets)) | ||||
|             return torch.Tensor([timestamp]), ( | ||||
|                 torch.Tensor(dataset), | ||||
|                 torch.Tensor(targets), | ||||
|             ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._timestamp_generator) | ||||
| @@ -104,3 +146,20 @@ class SyntheticDEnv(data.Dataset): | ||||
|             ndim=self._ndim, | ||||
|             num_per_task=self._num_per_task, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class EnvSampler: | ||||
|     def __init__(self, env, batch, enlarge): | ||||
|         indexes = list(range(len(env))) | ||||
|         self._indexes = indexes * enlarge | ||||
|         self._batch = batch | ||||
|         self._iterations = len(self._indexes) // self._batch | ||||
|  | ||||
|     def __iter__(self): | ||||
|         random.shuffle(self._indexes) | ||||
|         for it in range(self._iterations): | ||||
|             indexes = self._indexes[it * self._batch : (it + 1) * self._batch] | ||||
|             yield indexes | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self._iterations | ||||
|   | ||||
| @@ -30,6 +30,7 @@ class UnifiedSplit: | ||||
|             self._indexes = all_indexes[num_of_train + num_of_valid :] | ||||
|         else: | ||||
|             raise ValueError("Unkonwn mode of {:}".format(mode)) | ||||
|         self._all_indexes = all_indexes | ||||
|         self._mode = mode | ||||
|  | ||||
|     @property | ||||
|   | ||||
		Reference in New Issue
	
	Block a user