Re-org GeMOSA codes
This commit is contained in:
		| @@ -1,117 +0,0 @@ | |||||||
| ##################################################### |  | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # |  | ||||||
| ##################################################### |  | ||||||
| import copy |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| import torch.nn.functional as F |  | ||||||
|  |  | ||||||
| from xlayers import super_core |  | ||||||
| from xlayers import trunc_normal_ |  | ||||||
| from models.xcore import get_model |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class HyperNet(super_core.SuperModule): |  | ||||||
|     """The hyper-network.""" |  | ||||||
|  |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         shape_container, |  | ||||||
|         layer_embeding, |  | ||||||
|         task_embedding, |  | ||||||
|         num_tasks, |  | ||||||
|         return_container=True, |  | ||||||
|     ): |  | ||||||
|         super(HyperNet, self).__init__() |  | ||||||
|         self._shape_container = shape_container |  | ||||||
|         self._num_layers = len(shape_container) |  | ||||||
|         self._numel_per_layer = [] |  | ||||||
|         for ilayer in range(self._num_layers): |  | ||||||
|             self._numel_per_layer.append(shape_container[ilayer].numel()) |  | ||||||
|  |  | ||||||
|         self.register_parameter( |  | ||||||
|             "_super_layer_embed", |  | ||||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)), |  | ||||||
|         ) |  | ||||||
|         self.register_parameter( |  | ||||||
|             "_super_task_embed", |  | ||||||
|             torch.nn.Parameter(torch.Tensor(num_tasks, task_embedding)), |  | ||||||
|         ) |  | ||||||
|         trunc_normal_(self._super_layer_embed, std=0.02) |  | ||||||
|         trunc_normal_(self._super_task_embed, std=0.02) |  | ||||||
|  |  | ||||||
|         model_kwargs = dict( |  | ||||||
|             config=dict(model_type="dual_norm_mlp"), |  | ||||||
|             input_dim=layer_embeding + task_embedding, |  | ||||||
|             output_dim=max(self._numel_per_layer), |  | ||||||
|             hidden_dims=[(layer_embeding + task_embedding) * 2] * 3, |  | ||||||
|             act_cls="gelu", |  | ||||||
|             norm_cls="layer_norm_1d", |  | ||||||
|             dropout=0.2, |  | ||||||
|         ) |  | ||||||
|         self._generator = get_model(**model_kwargs) |  | ||||||
|         self._return_container = return_container |  | ||||||
|         print("generator: {:}".format(self._generator)) |  | ||||||
|  |  | ||||||
|     def forward_raw(self, task_embed_id): |  | ||||||
|         layer_embed = self._super_layer_embed |  | ||||||
|         task_embed = ( |  | ||||||
|             self._super_task_embed[task_embed_id] |  | ||||||
|             .view(1, -1) |  | ||||||
|             .expand(self._num_layers, -1) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         joint_embed = torch.cat((task_embed, layer_embed), dim=-1) |  | ||||||
|         weights = self._generator(joint_embed) |  | ||||||
|         if self._return_container: |  | ||||||
|             weights = torch.split(weights, 1) |  | ||||||
|             return self._shape_container.translate(weights) |  | ||||||
|         else: |  | ||||||
|             return weights |  | ||||||
|  |  | ||||||
|     def forward_candidate(self, input): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def extra_repr(self) -> str: |  | ||||||
|         return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class HyperNet_VX(super_core.SuperModule): |  | ||||||
|     def __init__(self, shape_container, input_embeding, return_container=True): |  | ||||||
|         super(HyperNet_VX, self).__init__() |  | ||||||
|         self._shape_container = shape_container |  | ||||||
|         self._num_layers = len(shape_container) |  | ||||||
|         self._numel_per_layer = [] |  | ||||||
|         for ilayer in range(self._num_layers): |  | ||||||
|             self._numel_per_layer.append(shape_container[ilayer].numel()) |  | ||||||
|  |  | ||||||
|         self.register_parameter( |  | ||||||
|             "_super_layer_embed", |  | ||||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)), |  | ||||||
|         ) |  | ||||||
|         trunc_normal_(self._super_layer_embed, std=0.02) |  | ||||||
|  |  | ||||||
|         model_kwargs = dict( |  | ||||||
|             input_dim=input_embeding, |  | ||||||
|             output_dim=max(self._numel_per_layer), |  | ||||||
|             hidden_dim=input_embeding * 4, |  | ||||||
|             act_cls="sigmoid", |  | ||||||
|             norm_cls="identity", |  | ||||||
|         ) |  | ||||||
|         self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs) |  | ||||||
|         self._return_container = return_container |  | ||||||
|         print("generator: {:}".format(self._generator)) |  | ||||||
|  |  | ||||||
|     def forward_raw(self, input): |  | ||||||
|         weights = self._generator(self._super_layer_embed) |  | ||||||
|         if self._return_container: |  | ||||||
|             weights = torch.split(weights, 1) |  | ||||||
|             return self._shape_container.translate(weights) |  | ||||||
|         else: |  | ||||||
|             return weights |  | ||||||
|  |  | ||||||
|     def forward_candidate(self, input): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def extra_repr(self) -> str: |  | ||||||
|         return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape)) |  | ||||||
| @@ -35,7 +35,7 @@ from xautodl.models.xcore import get_model | |||||||
| from xautodl.xlayers import super_core, trunc_normal_ | from xautodl.xlayers import super_core, trunc_normal_ | ||||||
|  |  | ||||||
| from lfna_utils import lfna_setup, train_model, TimeData | from lfna_utils import lfna_setup, train_model, TimeData | ||||||
| from lfna_meta_model import MetaModelV1 | from meta_model import MetaModelV1 | ||||||
|  |  | ||||||
|  |  | ||||||
| def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): | def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): | ||||||
| @@ -106,7 +106,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger): | |||||||
|         ) |         ) | ||||||
|         optimizer.zero_grad() |         optimizer.zero_grad() | ||||||
|  |  | ||||||
|         generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True) |         generated_time_embeds = gen_time_embed(meta_model.meta_timestamps) | ||||||
|  |  | ||||||
|         batch_indexes = random.choices(total_indexes, k=args.meta_batch) |         batch_indexes = random.choices(total_indexes, k=args.meta_batch) | ||||||
|  |  | ||||||
| @@ -219,11 +219,11 @@ def main(args): | |||||||
|     w_containers, loss_meter = online_evaluate( |     w_containers, loss_meter = online_evaluate( | ||||||
|         all_env, meta_model, base_model, criterion, args, logger, True |         all_env, meta_model, base_model, criterion, args, logger, True | ||||||
|     ) |     ) | ||||||
|     logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter)) |     logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter)) | ||||||
|  |  | ||||||
|     save_checkpoint( |     save_checkpoint( | ||||||
|         {"w_containers": w_containers}, |         {"all_w_containers": w_containers}, | ||||||
|         logger.path(None) / "final-ckp.pth", |         logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed), | ||||||
|         logger, |         logger, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -154,8 +154,9 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|                     (self._append_meta_embed["fixed"], meta_embed), dim=0 |                     (self._append_meta_embed["fixed"], meta_embed), dim=0 | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|     def _obtain_time_embed(self, timestamps): |     def gen_time_embed(self, timestamps): | ||||||
|         # timestamps is a batch of sequence of timestamps |         # timestamps is a batch of timestamps | ||||||
|  |         [B] = timestamps.shape | ||||||
|         # batch, seq = timestamps.shape |         # batch, seq = timestamps.shape | ||||||
|         timestamps = timestamps.view(-1, 1) |         timestamps = timestamps.view(-1, 1) | ||||||
|         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed |         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||||
| @@ -179,15 +180,8 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|         ) |         ) | ||||||
|         return timestamp_embeds[:, -1, :] |         return timestamp_embeds[:, -1, :] | ||||||
| 
 | 
 | ||||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): |     def gen_model(self, time_embeds): | ||||||
|         if time_embeds is None: |  | ||||||
|             [B] = timestamps.shape |  | ||||||
|             time_embeds = self._obtain_time_embed(timestamps) |  | ||||||
|         else:  # use the hyper-net only |  | ||||||
|             time_seq = None |  | ||||||
|         B, _ = time_embeds.shape |         B, _ = time_embeds.shape | ||||||
|         if tembed_only: |  | ||||||
|             return time_embeds |  | ||||||
|         # create joint embed |         # create joint embed | ||||||
|         num_layer, _ = self._super_layer_embed.shape |         num_layer, _ = self._super_layer_embed.shape | ||||||
|         # The shape of `joint_embed` is batch * num-layers * input-dim |         # The shape of `joint_embed` is batch * num-layers * input-dim | ||||||
| @@ -206,6 +200,9 @@ class MetaModelV1(super_core.SuperModule): | |||||||
|             ) |             ) | ||||||
|         return batch_containers, time_embeds |         return batch_containers, time_embeds | ||||||
| 
 | 
 | ||||||
|  |     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||||
|  |         raise NotImplementedError | ||||||
|  | 
 | ||||||
|     def forward_candidate(self, input): |     def forward_candidate(self, input): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| 
 | 
 | ||||||
| @@ -1,8 +1,9 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||||
| ############################################################################ | ############################################################################ | ||||||
| # python exps/GMOA/vis-synthetic.py --env_version v1                       # | # python exps/GeMOSA/vis-synthetic.py --env_version v1                     # | ||||||
| # python exps/GMOA/vis-synthetic.py --env_version v2                       # | # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||||
|  | # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||||
| ############################################################################ | ############################################################################ | ||||||
| import os, sys, copy, random | import os, sys, copy, random | ||||||
| import torch | import torch | ||||||
| @@ -181,7 +182,7 @@ def compare_cl(save_dir): | |||||||
| def visualize_env(save_dir, version): | def visualize_env(save_dir, version): | ||||||
|     save_dir = Path(str(save_dir)) |     save_dir = Path(str(save_dir)) | ||||||
|     for substr in ("pdf", "png"): |     for substr in ("pdf", "png"): | ||||||
|         sub_save_dir = save_dir / substr |         sub_save_dir = save_dir / "{:}-{:}".format(substr, version) | ||||||
|         sub_save_dir.mkdir(parents=True, exist_ok=True) |         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|     dynamic_env = get_synthetic_env(version=version) |     dynamic_env = get_synthetic_env(version=version) | ||||||
| @@ -190,6 +191,8 @@ def visualize_env(save_dir, version): | |||||||
|         allxs.append(allx) |         allxs.append(allx) | ||||||
|         allys.append(ally) |         allys.append(ally) | ||||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) |     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||||
|  |     print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) | ||||||
|  |     print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) | ||||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): |     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|         dpi, width, height = 30, 1800, 1400 |         dpi, width, height = 30, 1800, 1400 | ||||||
|         figsize = width / float(dpi), height / float(dpi) |         figsize = width / float(dpi), height / float(dpi) | ||||||
| @@ -210,14 +213,22 @@ def visualize_env(save_dir, version): | |||||||
|         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) |         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) |         cur_ax.legend(loc=1, fontsize=LegendFontsize) | ||||||
|  |  | ||||||
|         pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx) |         pdf_save_path = ( | ||||||
|  |             save_dir | ||||||
|  |             / "pdf-{:}".format(version) | ||||||
|  |             / "v{:}-{:05d}.pdf".format(version, idx) | ||||||
|  |         ) | ||||||
|         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") |         fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") | ||||||
|         png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx) |         png_save_path = ( | ||||||
|  |             save_dir | ||||||
|  |             / "png-{:}".format(version) | ||||||
|  |             / "v{:}-{:05d}.png".format(version, idx) | ||||||
|  |         ) | ||||||
|         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") |         fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") | ||||||
|         plt.close("all") |         plt.close("all") | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( |     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format( | ||||||
|         xdir=save_dir / "png", version=version |         xdir=save_dir / "png-{:}".format(version), version=version | ||||||
|     ) |     ) | ||||||
|     print(base_cmd) |     print(base_cmd) | ||||||
|     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) |     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)) | ||||||
| @@ -367,7 +378,7 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") |     visualize_env(os.path.join(args.save_dir, "vis-env"), args.env_version) | ||||||
|     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") |     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") | ||||||
|     compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) |     # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) | ||||||
|     # compare_cl(os.path.join(args.save_dir, "compare-cl")) |     # compare_cl(os.path.join(args.save_dir, "compare-cl")) | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ | |||||||
| from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc | from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc | ||||||
| from .math_dynamic_funcs import DynamicLinearFunc | from .math_dynamic_funcs import DynamicLinearFunc | ||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | from .math_dynamic_funcs import DynamicQuadraticFunc | ||||||
|  | from .math_dynamic_funcs import DynamicSinQuadraticFunc | ||||||
| from .math_adv_funcs import ConstantFunc | from .math_adv_funcs import ConstantFunc | ||||||
| from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc | from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc | ||||||
| from .math_dynamic_generator import GaussianDGenerator | from .math_dynamic_generator import GaussianDGenerator | ||||||
|   | |||||||
| @@ -5,9 +5,6 @@ import math | |||||||
| import abc | import abc | ||||||
| import copy | import copy | ||||||
| import numpy as np | import numpy as np | ||||||
| from typing import Optional |  | ||||||
| import torch |  | ||||||
| import torch.utils.data as data |  | ||||||
|  |  | ||||||
| from .math_base_funcs import FitFunc | from .math_base_funcs import FitFunc | ||||||
|  |  | ||||||
| @@ -68,10 +65,11 @@ class DynamicQuadraticFunc(DynamicFunc): | |||||||
|     def __init__(self, params=None): |     def __init__(self, params=None): | ||||||
|         super(DynamicQuadraticFunc, self).__init__(3, params) |         super(DynamicQuadraticFunc, self).__init__(3, params) | ||||||
|  |  | ||||||
|     def __call__(self, x, timestamp=None): |     def __call__( | ||||||
|  |         self, | ||||||
|  |         x, | ||||||
|  |     ): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
|         if timestamp is None: |  | ||||||
|             timestamp = self._timestamp |  | ||||||
|         a = self._params[0](timestamp) |         a = self._params[0](timestamp) | ||||||
|         b = self._params[1](timestamp) |         b = self._params[1](timestamp) | ||||||
|         c = self._params[2](timestamp) |         c = self._params[2](timestamp) | ||||||
| @@ -80,10 +78,38 @@ class DynamicQuadraticFunc(DynamicFunc): | |||||||
|         return a * x * x + b * x + c |         return a * x * x + b * x + c | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * x^2 + {b} * x + {c}, timestamp={timestamp})".format( |         return "{name}({a} * x^2 + {b} * x + {c})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DynamicSinQuadraticFunc(DynamicFunc): | ||||||
|  |     """The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c). | ||||||
|  |     The a, b, and c is a function of timestamp. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, params=None): | ||||||
|  |         super(DynamicSinQuadraticFunc, self).__init__(3, params) | ||||||
|  |  | ||||||
|  |     def __call__( | ||||||
|  |         self, | ||||||
|  |         x, | ||||||
|  |     ): | ||||||
|  |         self.check_valid() | ||||||
|  |         a = self._params[0](timestamp) | ||||||
|  |         b = self._params[1](timestamp) | ||||||
|  |         c = self._params[2](timestamp) | ||||||
|  |         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||||
|  |         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||||
|  |         return math.sin(a * x * x + b * x + c) | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * x^2 + {b} * x + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
|             c=self._params[2], |             c=self._params[2], | ||||||
|             timestamp=self._timestamp, |  | ||||||
|         ) |         ) | ||||||
|   | |||||||
| @@ -3,7 +3,7 @@ from .synthetic_utils import TimeStamp | |||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
| from .math_core import LinearFunc | from .math_core import LinearFunc | ||||||
| from .math_core import DynamicLinearFunc | from .math_core import DynamicLinearFunc | ||||||
| from .math_core import DynamicQuadraticFunc | from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc | ||||||
| from .math_core import ( | from .math_core import ( | ||||||
|     ConstantFunc, |     ConstantFunc, | ||||||
|     ComposedSinFunc as SinFunc, |     ComposedSinFunc as SinFunc, | ||||||
| @@ -63,9 +63,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | |||||||
|         time_generator = TimeStamp( |         time_generator = TimeStamp( | ||||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode |             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||||
|         ) |         ) | ||||||
|         oracle_map = DynamicQuadraticFunc( |         oracle_map = DynamicSinQuadraticFunc( | ||||||
|             params={ |             params={ | ||||||
|                 0: LinearFunc(params={0: 0.1, 1: 0}),  # 0.1 * t |                 0: CosFunc(params={0: 0.5, 1: 1, 2: 1}),  # 0.5 cos(t) + 1 | ||||||
|                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) |                 1: SinFunc(params={0: 1, 1: 1, 2: 0}),  # sin(t) | ||||||
|                 2: ConstantFunc(0), |                 2: ConstantFunc(0), | ||||||
|             } |             } | ||||||
|   | |||||||
| @@ -1,6 +1,3 @@ | |||||||
| import math |  | ||||||
| import random |  | ||||||
| from typing import List, Optional, Dict |  | ||||||
| import torch | import torch | ||||||
| import torch.utils.data as data | import torch.utils.data as data | ||||||
|  |  | ||||||
| @@ -43,6 +40,18 @@ class SyntheticDEnv(data.Dataset): | |||||||
|         self._oracle_map = oracle_map |         self._oracle_map = oracle_map | ||||||
|         self._num_per_task = num_per_task |         self._num_per_task = num_per_task | ||||||
|         self._noise = noise |         self._noise = noise | ||||||
|  |         self._meta_info = dict() | ||||||
|  |  | ||||||
|  |     def set_regression(self): | ||||||
|  |         self._meta_info["task"] = "regression" | ||||||
|  |  | ||||||
|  |     def set_classification(self, num_classes): | ||||||
|  |         self._meta_info["task"] = "classification" | ||||||
|  |         self._meta_info["num_classes"] = int(num_classes) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def meta_info(self): | ||||||
|  |         return self._meta_info | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def min_timestamp(self): |     def min_timestamp(self): | ||||||
| @@ -60,13 +69,6 @@ class SyntheticDEnv(data.Dataset): | |||||||
|     def mode(self): |     def mode(self): | ||||||
|         return self._time_generator.mode |         return self._time_generator.mode | ||||||
|  |  | ||||||
|     def random_timestamp(self, min_timestamp=None, max_timestamp=None): |  | ||||||
|         if min_timestamp is None: |  | ||||||
|             min_timestamp = self.min_timestamp |  | ||||||
|         if max_timestamp is None: |  | ||||||
|             max_timestamp = self.max_timestamp |  | ||||||
|         return random.random() * (max_timestamp - min_timestamp) + min_timestamp |  | ||||||
|  |  | ||||||
|     def get_timestamp(self, index): |     def get_timestamp(self, index): | ||||||
|         if index is None: |         if index is None: | ||||||
|             timestamps = [] |             timestamps = [] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user