Re-organize GeMOSA
This commit is contained in:
		| @@ -1,10 +1,9 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Learning to Generate Model One Step Ahead         # | # Learning to Generate Model One Step Ahead         # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/GeMOSA/lfna.py --env_version v1 --workers 0 | # python exps/GeMOSA/main.py --env_version v1 --workers 0 | ||||||
| # python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.001 | # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 8 --meta_batch 256 | ||||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 16 --meta_batch 128 | # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 | ||||||
| # python exps/GeMOSA/lfna.py --env_version v1 --device cuda --lr 0.002 --seq_length 24 --time_dim 32 --meta_batch 128 |  | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -38,7 +37,9 @@ from lfna_utils import lfna_setup, train_model, TimeData | |||||||
| from meta_model import MetaModelV1 | from meta_model import MetaModelV1 | ||||||
|  |  | ||||||
|  |  | ||||||
| def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): | def online_evaluate( | ||||||
|  |     env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False | ||||||
|  | ): | ||||||
|     logger.log("Online evaluate: {:}".format(env)) |     logger.log("Online evaluate: {:}".format(env)) | ||||||
|     loss_meter = AverageMeter() |     loss_meter = AverageMeter() | ||||||
|     w_containers = dict() |     w_containers = dict() | ||||||
| @@ -46,25 +47,30 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=F | |||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
|             meta_model.eval() |             meta_model.eval() | ||||||
|             base_model.eval() |             base_model.eval() | ||||||
|             [future_container], time_embeds = meta_model( |             future_time_embed = meta_model.gen_time_embed( | ||||||
|                 future_time.to(args.device).view(-1), None, False |                 future_time.to(args.device).view(-1) | ||||||
|             ) |             ) | ||||||
|  |             [future_container] = meta_model.gen_model(future_time_embed) | ||||||
|             if save: |             if save: | ||||||
|                 w_containers[idx] = future_container.no_grad_clone() |                 w_containers[idx] = future_container.no_grad_clone() | ||||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) |             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||||
|             future_y_hat = base_model.forward_with_container(future_x, future_container) |             future_y_hat = base_model.forward_with_container(future_x, future_container) | ||||||
|             future_loss = criterion(future_y_hat, future_y) |             future_loss = criterion(future_y_hat, future_y) | ||||||
|             loss_meter.update(future_loss.item()) |             loss_meter.update(future_loss.item()) | ||||||
|         refine, post_refine_loss = meta_model.adapt( |         if easy_adapt: | ||||||
|             base_model, |             meta_model.easy_adapt(future_time.item(), future_time_embed) | ||||||
|             criterion, |             refine, post_refine_loss = False, -1 | ||||||
|             future_time.item(), |         else: | ||||||
|             future_x, |             refine, post_refine_loss = meta_model.adapt( | ||||||
|             future_y, |                 base_model, | ||||||
|             args.refine_lr, |                 criterion, | ||||||
|             args.refine_epochs, |                 future_time.item(), | ||||||
|             {"param": time_embeds, "loss": future_loss.item()}, |                 future_x, | ||||||
|         ) |                 future_y, | ||||||
|  |                 args.refine_lr, | ||||||
|  |                 args.refine_epochs, | ||||||
|  |                 {"param": future_time_embed, "loss": future_loss.item()}, | ||||||
|  |             ) | ||||||
|         logger.log( |         logger.log( | ||||||
|             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( |             "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format( | ||||||
|                 idx, len(env), future_loss.item() |                 idx, len(env), future_loss.item() | ||||||
| @@ -106,7 +112,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|         ) |         ) | ||||||
|         optimizer.zero_grad() |         optimizer.zero_grad() | ||||||
|  |  | ||||||
|         generated_time_embeds = gen_time_embed(meta_model.meta_timestamps) |         generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps) | ||||||
|  |  | ||||||
|         batch_indexes = random.choices(total_indexes, k=args.meta_batch) |         batch_indexes = random.choices(total_indexes, k=args.meta_batch) | ||||||
|  |  | ||||||
| @@ -117,11 +123,9 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|         ) |         ) | ||||||
|         # future loss |         # future loss | ||||||
|         total_future_losses, total_present_losses = [], [] |         total_future_losses, total_present_losses = [], [] | ||||||
|         future_containers, _ = meta_model( |         future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes]) | ||||||
|             None, generated_time_embeds[batch_indexes], False |         present_containers = meta_model.gen_model( | ||||||
|         ) |             meta_model.super_meta_embed[batch_indexes] | ||||||
|         present_containers, _ = meta_model( |  | ||||||
|             None, meta_model.super_meta_embed[batch_indexes], False |  | ||||||
|         ) |         ) | ||||||
|         for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): |         for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()): | ||||||
|             _, (inputs, targets) = xenv(time_step) |             _, (inputs, targets) = xenv(time_step) | ||||||
| @@ -216,13 +220,34 @@ def main(args): | |||||||
|     # try to evaluate once |     # try to evaluate once | ||||||
|     # online_evaluate(train_env, meta_model, base_model, criterion, args, logger) |     # online_evaluate(train_env, meta_model, base_model, criterion, args, logger) | ||||||
|     # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) |     # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) | ||||||
|  |     """ | ||||||
|     w_containers, loss_meter = online_evaluate( |     w_containers, loss_meter = online_evaluate( | ||||||
|         all_env, meta_model, base_model, criterion, args, logger, True |         all_env, meta_model, base_model, criterion, args, logger, True | ||||||
|     ) |     ) | ||||||
|     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) |     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) | ||||||
|  |     """ | ||||||
|  |     _, test_loss_meter_adapt_v1 = online_evaluate( | ||||||
|  |         valid_env, meta_model, base_model, criterion, args, logger, False, False | ||||||
|  |     ) | ||||||
|  |     _, test_loss_meter_adapt_v2 = online_evaluate( | ||||||
|  |         valid_env, meta_model, base_model, criterion, args, logger, False, True | ||||||
|  |     ) | ||||||
|  |     logger.log( | ||||||
|  |         "In the online test enviornment, the total loss for refine-adapt is {:}".format( | ||||||
|  |             test_loss_meter_adapt_v1 | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |     logger.log( | ||||||
|  |         "In the online test enviornment, the total loss for easy-adapt is {:}".format( | ||||||
|  |             test_loss_meter_adapt_v2 | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     save_checkpoint( |     save_checkpoint( | ||||||
|         {"all_w_containers": w_containers}, |         { | ||||||
|  |             "test_loss_adapt_v1": test_loss_meter_adapt_v1.avg, | ||||||
|  |             "test_loss_adapt_v2": test_loss_meter_adapt_v2.avg, | ||||||
|  |         }, | ||||||
|         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed), |         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed), | ||||||
|         logger, |         logger, | ||||||
|     ) |     ) | ||||||
|   | |||||||
| @@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|             batch_containers.append( |             batch_containers.append( | ||||||
|                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) |                 self._shape_container.translate(torch.split(weights.squeeze(0), 1)) | ||||||
|             ) |             ) | ||||||
|         return batch_containers, time_embeds |         return batch_containers | ||||||
|  |  | ||||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): |     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| @@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|     def forward_candidate(self, input): |     def forward_candidate(self, input): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def easy_adapt(self, timestamp, time_embed): | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device) | ||||||
|  |             self.replace_append_learnt(None, None) | ||||||
|  |             self.append_fixed(timestamp, time_embed) | ||||||
|  |  | ||||||
|     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): |     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info): | ||||||
|         distance = self.get_closest_meta_distance(timestamp) |         distance = self.get_closest_meta_distance(timestamp) | ||||||
|         if distance + self._interval * 1e-2 <= self._interval: |         if distance + self._interval * 1e-2 <= self._interval: | ||||||
| @@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|                 best_new_param = new_param.detach().clone() |                 best_new_param = new_param.detach().clone() | ||||||
|             for iepoch in range(epochs): |             for iepoch in range(epochs): | ||||||
|                 optimizer.zero_grad() |                 optimizer.zero_grad() | ||||||
|                 _, time_embed = self(timestamp.view(1), None) |                 time_embed = self.gen_time_embed(timestamp.view(1)) | ||||||
|                 match_loss = criterion(new_param, time_embed) |                 match_loss = criterion(new_param, time_embed) | ||||||
|  |  | ||||||
|                 [container], time_embed = self(None, new_param.view(1, -1)) |                 [container] = self.gen_model(new_param.view(1, -1)) | ||||||
|                 y_hat = base_model.forward_with_container(x, container) |                 y_hat = base_model.forward_with_container(x, container) | ||||||
|                 meta_loss = criterion(y_hat, y) |                 meta_loss = criterion(y_hat, y) | ||||||
|                 loss = meta_loss + match_loss |                 loss = meta_loss + match_loss | ||||||
|                 loss.backward() |                 loss.backward() | ||||||
|                 optimizer.step() |                 optimizer.step() | ||||||
|                 # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item())) |  | ||||||
|                 if meta_loss.item() < best_loss: |                 if meta_loss.item() < best_loss: | ||||||
|                     with torch.no_grad(): |                     with torch.no_grad(): | ||||||
|                         best_loss = meta_loss.item() |                         best_loss = meta_loss.item() | ||||||
|                         best_new_param = new_param.detach().clone() |                         best_new_param = new_param.detach().clone() | ||||||
|         with torch.no_grad(): |         self.easy_adapt(timestamp, best_new_param) | ||||||
|             self.replace_append_learnt(None, None) |  | ||||||
|             self.append_fixed(timestamp, best_new_param) |  | ||||||
|         return True, best_loss |         return True, best_loss | ||||||
|  |  | ||||||
|     def extra_repr(self) -> str: |     def extra_repr(self) -> str: | ||||||
|   | |||||||
| @@ -191,6 +191,8 @@ def visualize_env(save_dir, version): | |||||||
|         allxs.append(allx) |         allxs.append(allx) | ||||||
|         allys.append(ally) |         allys.append(ally) | ||||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) |     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||||
|  |     print("env: {:}".format(dynamic_env)) | ||||||
|  |     print("oracle_map: {:}".format(dynamic_env.oracle_map)) | ||||||
|     print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) |     print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) | ||||||
|     print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) |     print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) | ||||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): |     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|   | |||||||
| @@ -1,92 +0,0 @@ | |||||||
| ##################################################### |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # |  | ||||||
| ##################################################### |  | ||||||
| import math |  | ||||||
| import abc |  | ||||||
| import copy |  | ||||||
| import numpy as np |  | ||||||
| from typing import Optional |  | ||||||
| import torch |  | ||||||
| import torch.utils.data as data |  | ||||||
|  |  | ||||||
| from .math_base_funcs import FitFunc |  | ||||||
| from .math_base_funcs import QuadraticFunc |  | ||||||
| from .math_base_funcs import QuarticFunc |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ConstantFunc(FitFunc): |  | ||||||
|     """The constant function: f(x) = c.""" |  | ||||||
|  |  | ||||||
|     def __init__(self, constant=None, xstr="x"): |  | ||||||
|         param = dict() |  | ||||||
|         param[0] = constant |  | ||||||
|         super(ConstantFunc, self).__init__(0, None, param, xstr) |  | ||||||
|  |  | ||||||
|     def __call__(self, x): |  | ||||||
|         self.check_valid() |  | ||||||
|         return self._params[0] |  | ||||||
|  |  | ||||||
|     def fit(self, **kwargs): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ComposedSinFunc(FitFunc): |  | ||||||
|     """The composed sin function that outputs: |  | ||||||
|     f(x) = a * sin( b*x ) + c |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     def __init__(self, params, xstr="x"): |  | ||||||
|         super(ComposedSinFunc, self).__init__(3, None, params, xstr) |  | ||||||
|  |  | ||||||
|     def __call__(self, x): |  | ||||||
|         self.check_valid() |  | ||||||
|         a = self._params[0] |  | ||||||
|         b = self._params[1] |  | ||||||
|         c = self._params[2] |  | ||||||
|         return a * math.sin(b * x) + c |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( |  | ||||||
|             name=self.__class__.__name__, |  | ||||||
|             a=self._params[0], |  | ||||||
|             b=self._params[1], |  | ||||||
|             c=self._params[2], |  | ||||||
|             x=self.xstr, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ComposedCosFunc(FitFunc): |  | ||||||
|     """The composed sin function that outputs: |  | ||||||
|     f(x) = a * cos( b*x ) + c |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     def __init__(self, params, xstr="x"): |  | ||||||
|         super(ComposedCosFunc, self).__init__(3, None, params, xstr) |  | ||||||
|  |  | ||||||
|     def __call__(self, x): |  | ||||||
|         self.check_valid() |  | ||||||
|         a = self._params[0] |  | ||||||
|         b = self._params[1] |  | ||||||
|         c = self._params[2] |  | ||||||
|         return a * math.cos(b * x) + c |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( |  | ||||||
|             name=self.__class__.__name__, |  | ||||||
|             a=self._params[0], |  | ||||||
|             b=self._params[1], |  | ||||||
|             c=self._params[2], |  | ||||||
|             x=self.xstr, |  | ||||||
|         ) |  | ||||||
| @@ -5,34 +5,33 @@ import math | |||||||
| import abc | import abc | ||||||
| import copy | import copy | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class FitFunc(abc.ABC): | class MathFunc(abc.ABC): | ||||||
|     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" |     """The math function -- a virtual class defining some APIs.""" | ||||||
|  |  | ||||||
|     def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"): |     def __init__(self, freedom: int, params=None, xstr="x"): | ||||||
|  |         # initialize as empty | ||||||
|         self._params = dict() |         self._params = dict() | ||||||
|         for i in range(freedom): |         for i in range(freedom): | ||||||
|             self._params[i] = None |             self._params[i] = None | ||||||
|         self._freedom = freedom |         self._freedom = freedom | ||||||
|         if list_of_points is not None and params is not None: |  | ||||||
|             raise ValueError("list_of_points and params can not be set simultaneously") |  | ||||||
|         if list_of_points is not None: |  | ||||||
|             self.fit(list_of_points=list_of_points) |  | ||||||
|         if params is not None: |         if params is not None: | ||||||
|             self.set(params) |             self.set(params) | ||||||
|         self._xstr = str(xstr) |         self._xstr = str(xstr) | ||||||
|  |         self._skip_check = True | ||||||
|  |  | ||||||
|     def set(self, params): |     def set(self, params): | ||||||
|         self._params = copy.deepcopy(params) |         for key in range(self._freedom): | ||||||
|  |             param = copy.deepcopy(params[key]) | ||||||
|  |             self._params[key] = param | ||||||
|  |  | ||||||
|     def check_valid(self): |     def check_valid(self): | ||||||
|         # for key, value in self._params.items(): |         if not self._skip_check: | ||||||
|         for key in range(self._freedom): |             for key in range(self._freedom): | ||||||
|             value = self._params[key] |                 value = self._params[key] | ||||||
|             if value is None: |                 if value is None: | ||||||
|                 raise ValueError("The {:} is None".format(key)) |                     raise ValueError("The {:} is None".format(key)) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def xstr(self): |     def xstr(self): | ||||||
| @@ -45,7 +44,8 @@ class FitFunc(abc.ABC): | |||||||
|     def __call__(self, x): |     def __call__(self, x): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def noise_call(self, x, std=0.1): |     @abc.abstractmethod | ||||||
|  |     def noise_call(self, x, std): | ||||||
|         clean_y = self.__call__(x) |         clean_y = self.__call__(x) | ||||||
|         if isinstance(clean_y, np.ndarray): |         if isinstance(clean_y, np.ndarray): | ||||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) |             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||||
| @@ -53,169 +53,7 @@ class FitFunc(abc.ABC): | |||||||
|             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) |             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||||
|         return noise_y |         return noise_y | ||||||
|  |  | ||||||
|     @abc.abstractmethod |  | ||||||
|     def _getitem(self, x): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def fit(self, **kwargs): |  | ||||||
|         list_of_points = kwargs["list_of_points"] |  | ||||||
|         max_iter, lr_max, verbose = ( |  | ||||||
|             kwargs.get("max_iter", 900), |  | ||||||
|             kwargs.get("lr_max", 1.0), |  | ||||||
|             kwargs.get("verbose", False), |  | ||||||
|         ) |  | ||||||
|         with torch.no_grad(): |  | ||||||
|             data = torch.Tensor(list_of_points).type(torch.float32) |  | ||||||
|             assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format( |  | ||||||
|                 data.shape |  | ||||||
|             ) |  | ||||||
|             x, y = data[:, 0], data[:, 1] |  | ||||||
|         weights = torch.nn.Parameter(torch.Tensor(self._freedom)) |  | ||||||
|         torch.nn.init.normal_(weights, mean=0.0, std=1.0) |  | ||||||
|         optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) |  | ||||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |  | ||||||
|             optimizer, |  | ||||||
|             milestones=[ |  | ||||||
|                 int(max_iter * 0.25), |  | ||||||
|                 int(max_iter * 0.5), |  | ||||||
|                 int(max_iter * 0.75), |  | ||||||
|             ], |  | ||||||
|             gamma=0.1, |  | ||||||
|         ) |  | ||||||
|         if verbose: |  | ||||||
|             print("The optimizer: {:}".format(optimizer)) |  | ||||||
|  |  | ||||||
|         best_loss = None |  | ||||||
|         for _iter in range(max_iter): |  | ||||||
|             y_hat = self._getitem(x, weights) |  | ||||||
|             loss = torch.mean(torch.abs(y - y_hat)) |  | ||||||
|             optimizer.zero_grad() |  | ||||||
|             loss.backward() |  | ||||||
|             optimizer.step() |  | ||||||
|             lr_scheduler.step() |  | ||||||
|             if verbose: |  | ||||||
|                 print( |  | ||||||
|                     "In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format( |  | ||||||
|                         _iter, max_iter, loss.item() |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|             # Update the params |  | ||||||
|             if best_loss is None or best_loss > loss.item(): |  | ||||||
|                 best_loss = loss.item() |  | ||||||
|                 for i in range(self._freedom): |  | ||||||
|                     self._params[i] = weights[i].item() |  | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}(freedom={freedom})".format( |         return "{name}(freedom={freedom})".format( | ||||||
|             name=self.__class__.__name__, freedom=freedom |             name=self.__class__.__name__, freedom=freedom | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class LinearFunc(FitFunc): |  | ||||||
|     """The linear function that outputs f(x) = a * x + b.""" |  | ||||||
|  |  | ||||||
|     def __init__(self, list_of_points=None, params=None, xstr="x"): |  | ||||||
|         super(LinearFunc, self).__init__(2, list_of_points, params, xstr) |  | ||||||
|  |  | ||||||
|     def __call__(self, x): |  | ||||||
|         self.check_valid() |  | ||||||
|         return self._params[0] * x + self._params[1] |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |  | ||||||
|         return weights[0] * x + weights[1] |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return "{name}({a} * {x} + {b})".format( |  | ||||||
|             name=self.__class__.__name__, |  | ||||||
|             a=self._params[0], |  | ||||||
|             b=self._params[1], |  | ||||||
|             x=self.xstr, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuadraticFunc(FitFunc): |  | ||||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" |  | ||||||
|  |  | ||||||
|     def __init__(self, list_of_points=None, params=None, xstr="x"): |  | ||||||
|         super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr) |  | ||||||
|  |  | ||||||
|     def __call__(self, x): |  | ||||||
|         self.check_valid() |  | ||||||
|         return self._params[0] * x * x + self._params[1] * x + self._params[2] |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |  | ||||||
|         return weights[0] * x * x + weights[1] * x + weights[2] |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( |  | ||||||
|             name=self.__class__.__name__, |  | ||||||
|             a=self._params[0], |  | ||||||
|             b=self._params[1], |  | ||||||
|             c=self._params[2], |  | ||||||
|             x=self.xstr, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class CubicFunc(FitFunc): |  | ||||||
|     """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" |  | ||||||
|  |  | ||||||
|     def __init__(self, list_of_points=None): |  | ||||||
|         super(CubicFunc, self).__init__(4, list_of_points) |  | ||||||
|  |  | ||||||
|     def __call__(self, x): |  | ||||||
|         self.check_valid() |  | ||||||
|         return ( |  | ||||||
|             self._params[0] * x ** 3 |  | ||||||
|             + self._params[1] * x ** 2 |  | ||||||
|             + self._params[2] * x |  | ||||||
|             + self._params[3] |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |  | ||||||
|         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( |  | ||||||
|             name=self.__class__.__name__, |  | ||||||
|             a=self._params[0], |  | ||||||
|             b=self._params[1], |  | ||||||
|             c=self._params[2], |  | ||||||
|             d=self._params[3], |  | ||||||
|             x=self.xstr, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuarticFunc(FitFunc): |  | ||||||
|     """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" |  | ||||||
|  |  | ||||||
|     def __init__(self, list_of_points=None): |  | ||||||
|         super(QuarticFunc, self).__init__(5, list_of_points) |  | ||||||
|  |  | ||||||
|     def __call__(self, x): |  | ||||||
|         self.check_valid() |  | ||||||
|         return ( |  | ||||||
|             self._params[0] * x ** 4 |  | ||||||
|             + self._params[1] * x ** 3 |  | ||||||
|             + self._params[2] * x ** 2 |  | ||||||
|             + self._params[3] * x |  | ||||||
|             + self._params[4] |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |  | ||||||
|         return ( |  | ||||||
|             weights[0] * x ** 4 |  | ||||||
|             + weights[1] * x ** 3 |  | ||||||
|             + weights[2] * x ** 2 |  | ||||||
|             + weights[3] * x |  | ||||||
|             + weights[4] |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( |  | ||||||
|             name=self.__class__.__name__, |  | ||||||
|             a=self._params[0], |  | ||||||
|             b=self._params[1], |  | ||||||
|             c=self._params[2], |  | ||||||
|             d=self._params[3], |  | ||||||
|             e=self._params[3], |  | ||||||
|         ) |  | ||||||
|   | |||||||
| @@ -1,10 +1,14 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||||
| ##################################################### | ##################################################### | ||||||
| from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc | from .math_static_funcs import ( | ||||||
| from .math_dynamic_funcs import DynamicLinearFunc |     LinearSFunc, | ||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc |     QuadraticSFunc, | ||||||
| from .math_dynamic_funcs import DynamicSinQuadraticFunc |     CubicSFunc, | ||||||
| from .math_adv_funcs import ConstantFunc |     QuarticSFunc, | ||||||
| from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc |     ConstantFunc, | ||||||
|  |     ComposedSinSFunc, | ||||||
|  |     ComposedCosSFunc, | ||||||
|  | ) | ||||||
|  | from .math_dynamic_funcs import LinearDFunc, QuadraticDFunc, SinQuadraticDFunc | ||||||
| from .math_dynamic_generator import GaussianDGenerator | from .math_dynamic_generator import GaussianDGenerator | ||||||
|   | |||||||
| @@ -6,23 +6,17 @@ import abc | |||||||
| import copy | import copy | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
| from .math_base_funcs import FitFunc | from .math_base_funcs import MathFunc | ||||||
|  |  | ||||||
|  |  | ||||||
| class DynamicFunc(FitFunc): | class DynamicFunc(MathFunc): | ||||||
|     """The dynamic quadratic function, where each param is a function.""" |     """The dynamic function, where each param is a function.""" | ||||||
|  |  | ||||||
|     def __init__(self, freedom: int, params=None, xstr="x"): |     def __init__(self, freedom: int, params=None, xstr="x"): | ||||||
|         if params is not None: |         if params is not None: | ||||||
|             for param in params: |             for key, param in params.items(): | ||||||
|                 param.reset_xstr("t") if isinstance(param, FitFunc) else None |                 param.reset_xstr("t") if isinstance(param, MathFunc) else None | ||||||
|         super(DynamicFunc, self).__init__(freedom, None, params, xstr) |         super(DynamicFunc, self).__init__(freedom, params, xstr) | ||||||
|  |  | ||||||
|     def __call__(self, x, timestamp): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def noise_call(self, x, timestamp, std): |     def noise_call(self, x, timestamp, std): | ||||||
|         clean_y = self.__call__(x, timestamp) |         clean_y = self.__call__(x, timestamp) | ||||||
| @@ -33,13 +27,13 @@ class DynamicFunc(FitFunc): | |||||||
|         return noise_y |         return noise_y | ||||||
|  |  | ||||||
|  |  | ||||||
| class DynamicLinearFunc(DynamicFunc): | class LinearDFunc(DynamicFunc): | ||||||
|     """The dynamic linear function that outputs f(x) = a * x + b. |     """The dynamic linear function that outputs f(x) = a * x + b. | ||||||
|     The a and b is a function of timestamp. |     The a and b is a function of timestamp. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, params=None, xstr="x"): |     def __init__(self, params, xstr="x"): | ||||||
|         super(DynamicLinearFunc, self).__init__(3, params, xstr) |         super(LinearDFunc, self).__init__(2, params, xstr) | ||||||
|  |  | ||||||
|     def __call__(self, x, timestamp): |     def __call__(self, x, timestamp): | ||||||
|         a = self._params[0](timestamp) |         a = self._params[0](timestamp) | ||||||
| @@ -57,18 +51,15 @@ class DynamicLinearFunc(DynamicFunc): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DynamicQuadraticFunc(DynamicFunc): | class QuadraticDFunc(DynamicFunc): | ||||||
|     """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. |     """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. | ||||||
|     The a, b, and c is a function of timestamp. |     The a, b, and c is a function of timestamp. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, params=None): |     def __init__(self, params, xstr="x"): | ||||||
|         super(DynamicQuadraticFunc, self).__init__(3, params) |         super(QuadraticDFunc, self).__init__(3, params) | ||||||
|  |  | ||||||
|     def __call__( |     def __call__(self, x, timestamp): | ||||||
|         self, |  | ||||||
|         x, |  | ||||||
|     ): |  | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
|         a = self._params[0](timestamp) |         a = self._params[0](timestamp) | ||||||
|         b = self._params[1](timestamp) |         b = self._params[1](timestamp) | ||||||
| @@ -78,38 +69,37 @@ class DynamicQuadraticFunc(DynamicFunc): | |||||||
|         return a * x * x + b * x + c |         return a * x * x + b * x + c | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * x^2 + {b} * x + {c})".format( |         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
|             c=self._params[2], |             c=self._params[2], | ||||||
|  |             x=self.xstr, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DynamicSinQuadraticFunc(DynamicFunc): | class SinQuadraticDFunc(DynamicFunc): | ||||||
|     """The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c). |     """The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c). | ||||||
|     The a, b, and c is a function of timestamp. |     The a, b, and c is a function of timestamp. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, params=None): |     def __init__(self, params=None): | ||||||
|         super(DynamicSinQuadraticFunc, self).__init__(3, params) |         super(SinQuadraticDFunc, self).__init__(3, params) | ||||||
|  |  | ||||||
|     def __call__( |     def __call__(self, x, timestamp): | ||||||
|         self, |  | ||||||
|         x, |  | ||||||
|     ): |  | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
|         a = self._params[0](timestamp) |         a = self._params[0](timestamp) | ||||||
|         b = self._params[1](timestamp) |         b = self._params[1](timestamp) | ||||||
|         c = self._params[2](timestamp) |         c = self._params[2](timestamp) | ||||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x |         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||||
|         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) |         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||||
|         return math.sin(a * x * x + b * x + c) |         return np.sin(a * x * x + b * x + c) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * x^2 + {b} * x + {c})".format( |         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
|             c=self._params[2], |             c=self._params[2], | ||||||
|  |             x=self.xstr, | ||||||
|         ) |         ) | ||||||
|   | |||||||
							
								
								
									
										225
									
								
								xautodl/datasets/math_static_funcs.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								xautodl/datasets/math_static_funcs.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,225 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | import math | ||||||
|  | import abc | ||||||
|  | import copy | ||||||
|  | import numpy as np | ||||||
|  |  | ||||||
|  | from .math_base_funcs import MathFunc | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class StaticFunc(MathFunc): | ||||||
|  |     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, freedom: int, params=None, xstr="x"): | ||||||
|  |         super(StaticFunc, self).__init__(freedom, params, xstr) | ||||||
|  |  | ||||||
|  |     @abc.abstractmethod | ||||||
|  |     def __call__(self, x): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def noise_call(self, x, std): | ||||||
|  |         clean_y = self.__call__(x) | ||||||
|  |         if isinstance(clean_y, np.ndarray): | ||||||
|  |             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||||
|  |         else: | ||||||
|  |             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||||
|  |         return noise_y | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}(freedom={freedom})".format( | ||||||
|  |             name=self.__class__.__name__, freedom=freedom | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LinearSFunc(StaticFunc): | ||||||
|  |     """The linear function that outputs f(x) = a * x + b.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, params=None, xstr="x"): | ||||||
|  |         super(LinearSFunc, self).__init__(2, params, xstr) | ||||||
|  |  | ||||||
|  |     def __call__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return self._params[0] * x + self._params[1] | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         return weights[0] * x + weights[1] | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * {x} + {b})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             x=self.xstr, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class QuadraticSFunc(StaticFunc): | ||||||
|  |     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, params=None, xstr="x"): | ||||||
|  |         super(QuadraticSFunc, self).__init__(3, params, xstr) | ||||||
|  |  | ||||||
|  |     def __call__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return self._params[0] * x * x + self._params[1] * x + self._params[2] | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         return weights[0] * x * x + weights[1] * x + weights[2] | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             x=self.xstr, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CubicSFunc(StaticFunc): | ||||||
|  |     """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, params=None, xstr="x"): | ||||||
|  |         super(CubicSFunc, self).__init__(4, params, xstr) | ||||||
|  |  | ||||||
|  |     def __call__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return ( | ||||||
|  |             self._params[0] * x ** 3 | ||||||
|  |             + self._params[1] * x ** 2 | ||||||
|  |             + self._params[2] * x | ||||||
|  |             + self._params[3] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             d=self._params[3], | ||||||
|  |             x=self.xstr, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class QuarticSFunc(StaticFunc): | ||||||
|  |     """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, params=None, xstr="x"): | ||||||
|  |         super(QuarticSFunc, self).__init__(5, params, xstr) | ||||||
|  |  | ||||||
|  |     def __call__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return ( | ||||||
|  |             self._params[0] * x ** 4 | ||||||
|  |             + self._params[1] * x ** 3 | ||||||
|  |             + self._params[2] * x ** 2 | ||||||
|  |             + self._params[3] * x | ||||||
|  |             + self._params[4] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         return ( | ||||||
|  |             weights[0] * x ** 4 | ||||||
|  |             + weights[1] * x ** 3 | ||||||
|  |             + weights[2] * x ** 2 | ||||||
|  |             + weights[3] * x | ||||||
|  |             + weights[4] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return ( | ||||||
|  |             "{name}({a} * {x}^4 + {b} * {x}^3 + {c} * {x}^2 + {d} * {x} + {e})".format( | ||||||
|  |                 name=self.__class__.__name__, | ||||||
|  |                 a=self._params[0], | ||||||
|  |                 b=self._params[1], | ||||||
|  |                 c=self._params[2], | ||||||
|  |                 d=self._params[3], | ||||||
|  |                 e=self._params[3], | ||||||
|  |                 x=self.xstr, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ### advanced functions | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ConstantFunc(StaticFunc): | ||||||
|  |     """The constant function: f(x) = c.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, constant, xstr="x"): | ||||||
|  |         super(ConstantFunc, self).__init__(1, {0: constant}, xstr) | ||||||
|  |  | ||||||
|  |     def __call__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return self._params[0] | ||||||
|  |  | ||||||
|  |     def fit(self, **kwargs): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ComposedSinSFunc(StaticFunc): | ||||||
|  |     """The composed sin function that outputs: | ||||||
|  |     f(x) = a * sin( b*x ) + c | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, params, xstr="x"): | ||||||
|  |         super(ComposedSinSFunc, self).__init__(3, params, xstr) | ||||||
|  |  | ||||||
|  |     def __call__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         a = self._params[0] | ||||||
|  |         b = self._params[1] | ||||||
|  |         c = self._params[2] | ||||||
|  |         return a * math.sin(b * x) + c | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             x=self.xstr, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ComposedCosSFunc(StaticFunc): | ||||||
|  |     """The composed sin function that outputs: | ||||||
|  |     f(x) = a * cos( b*x ) + c | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, params, xstr="x"): | ||||||
|  |         super(ComposedCosSFunc, self).__init__(3, params, xstr) | ||||||
|  |  | ||||||
|  |     def __call__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         a = self._params[0] | ||||||
|  |         b = self._params[1] | ||||||
|  |         c = self._params[2] | ||||||
|  |         return a * math.cos(b * x) + c | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             x=self.xstr, | ||||||
|  |         ) | ||||||
| @@ -1,13 +1,13 @@ | |||||||
| import math | import math | ||||||
| from .synthetic_utils import TimeStamp | from .synthetic_utils import TimeStamp | ||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
| from .math_core import LinearFunc | from .math_core import LinearSFunc | ||||||
| from .math_core import DynamicLinearFunc | from .math_core import LinearDFunc | ||||||
| from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc | from .math_core import QuadraticDFunc, SinQuadraticDFunc | ||||||
| from .math_core import ( | from .math_core import ( | ||||||
|     ConstantFunc, |     ConstantFunc, | ||||||
|     ComposedSinFunc as SinFunc, |     ComposedSinSFunc as SinFunc, | ||||||
|     ComposedCosFunc as CosFunc, |     ComposedCosSFunc as CosFunc, | ||||||
| ) | ) | ||||||
| from .math_core import GaussianDGenerator | from .math_core import GaussianDGenerator | ||||||
|  |  | ||||||
| @@ -17,7 +17,7 @@ __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | |||||||
|  |  | ||||||
| def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"): | def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"): | ||||||
|     max_time = math.pi * 10 |     max_time = math.pi * 10 | ||||||
|     if version == "v1": |     if version.lower() == "v1": | ||||||
|         mean_generator = ConstantFunc(0) |         mean_generator = ConstantFunc(0) | ||||||
|         std_generator = ConstantFunc(1) |         std_generator = ConstantFunc(1) | ||||||
|         data_generator = GaussianDGenerator( |         data_generator = GaussianDGenerator( | ||||||
| @@ -26,7 +26,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | |||||||
|         time_generator = TimeStamp( |         time_generator = TimeStamp( | ||||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode |             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||||
|         ) |         ) | ||||||
|         oracle_map = DynamicLinearFunc( |         oracle_map = LinearDFunc( | ||||||
|             params={ |             params={ | ||||||
|                 0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),  # 2 sin(t) + 2.2 |                 0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),  # 2 sin(t) + 2.2 | ||||||
|                 1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),  # 1.5 sin(0.6t) + 1.8 |                 1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),  # 1.5 sin(0.6t) + 1.8 | ||||||
| @@ -35,7 +35,8 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | |||||||
|         dynamic_env = SyntheticDEnv( |         dynamic_env = SyntheticDEnv( | ||||||
|             data_generator, oracle_map, time_generator, num_per_task |             data_generator, oracle_map, time_generator, num_per_task | ||||||
|         ) |         ) | ||||||
|     elif version == "v2": |         dynamic_env.set_regression() | ||||||
|  |     elif version.lower() == "v2": | ||||||
|         mean_generator = ConstantFunc(0) |         mean_generator = ConstantFunc(0) | ||||||
|         std_generator = ConstantFunc(1) |         std_generator = ConstantFunc(1) | ||||||
|         data_generator = GaussianDGenerator( |         data_generator = GaussianDGenerator( | ||||||
| @@ -44,16 +45,17 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | |||||||
|         time_generator = TimeStamp( |         time_generator = TimeStamp( | ||||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode |             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||||
|         ) |         ) | ||||||
|         oracle_map = DynamicQuadraticFunc( |         oracle_map = QuadraticDFunc( | ||||||
|             params={ |             params={ | ||||||
|                 0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t |                 0: LinearSFunc(params={0: 0.1, 1: 0}),  # 0.1 * t | ||||||
|                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) |                 1: ConstantFunc(0), | ||||||
|                 2: ConstantFunc(0), |                 2: CosFunc(params={0: 4.0, 1: 10, 2: 0}),  # 4 * cos(10 * t) | ||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|         dynamic_env = SyntheticDEnv( |         dynamic_env = SyntheticDEnv( | ||||||
|             data_generator, oracle_map, time_generator, num_per_task |             data_generator, oracle_map, time_generator, num_per_task | ||||||
|         ) |         ) | ||||||
|  |         dynamic_env.set_regression() | ||||||
|     elif version.lower() == "v3": |     elif version.lower() == "v3": | ||||||
|         mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0})  # sin(t) |         mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0})  # sin(t) | ||||||
|         std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1})  # 0.5 cos(t) + 1 |         std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1})  # 0.5 cos(t) + 1 | ||||||
| @@ -63,7 +65,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | |||||||
|         time_generator = TimeStamp( |         time_generator = TimeStamp( | ||||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode |             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||||
|         ) |         ) | ||||||
|         oracle_map = DynamicSinQuadraticFunc( |         oracle_map = SinQuadraticDFunc( | ||||||
|             params={ |             params={ | ||||||
|                 0: CosFunc(params={0: 0.5, 1: 1, 2: 1}),  # 0.5 cos(t) + 1 |                 0: CosFunc(params={0: 0.5, 1: 1, 2: 1}),  # 0.5 cos(t) + 1 | ||||||
|                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) |                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) | ||||||
| @@ -73,6 +75,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | |||||||
|         dynamic_env = SyntheticDEnv( |         dynamic_env = SyntheticDEnv( | ||||||
|             data_generator, oracle_map, time_generator, num_per_task |             data_generator, oracle_map, time_generator, num_per_task | ||||||
|         ) |         ) | ||||||
|  |         dynamic_env.set_regression() | ||||||
|  |     elif version.lower() == "v4": | ||||||
|  |         dynamic_env.set_classification(2) | ||||||
|     else: |     else: | ||||||
|         raise ValueError("Unknown version: {:}".format(version)) |         raise ValueError("Unknown version: {:}".format(version)) | ||||||
|     return dynamic_env |     return dynamic_env | ||||||
|   | |||||||
| @@ -49,6 +49,10 @@ class SyntheticDEnv(data.Dataset): | |||||||
|         self._meta_info["task"] = "classification" |         self._meta_info["task"] = "classification" | ||||||
|         self._meta_info["num_classes"] = int(num_classes) |         self._meta_info["num_classes"] = int(num_classes) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def oracle_map(self): | ||||||
|  |         return self._oracle_map | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def meta_info(self): |     def meta_info(self): | ||||||
|         return self._meta_info |         return self._meta_info | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user