Update synthetic codes
This commit is contained in:
		| @@ -103,11 +103,32 @@ class FitFunc(abc.ABC): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LinearFunc(FitFunc): | ||||||
|  |     """The linear function that outputs f(x) = a * x + b.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, list_of_points=None, params=None): | ||||||
|  |         super(LinearFunc, self).__init__(2, list_of_points, params) | ||||||
|  |  | ||||||
|  |     def __call__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return self._params[0] * x + self._params[1] | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         return weights[0] * x + weights[1] | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * x + {b})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuadraticFunc(FitFunc): | class QuadraticFunc(FitFunc): | ||||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" |     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||||
|  |  | ||||||
|     def __init__(self, list_of_points=None): |     def __init__(self, list_of_points=None, params=None): | ||||||
|         super(QuadraticFunc, self).__init__(3, list_of_points) |         super(QuadraticFunc, self).__init__(3, list_of_points, params) | ||||||
|  |  | ||||||
|     def __call__(self, x): |     def __call__(self, x): | ||||||
|         self.check_valid() |         self.check_valid() | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||||
| ##################################################### | ##################################################### | ||||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc | ||||||
| from .math_dynamic_funcs import DynamicLinearFunc | from .math_dynamic_funcs import DynamicLinearFunc | ||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | from .math_dynamic_funcs import DynamicQuadraticFunc | ||||||
| from .math_adv_funcs import ConstantFunc | from .math_adv_funcs import ConstantFunc | ||||||
|   | |||||||
| @@ -3,9 +3,10 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| from .synthetic_utils import TimeStamp | from .synthetic_utils import TimeStamp | ||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
| from .math_dynamic_funcs import DynamicLinearFunc | from .math_core import LinearFunc | ||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | from .math_core import DynamicLinearFunc | ||||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | from .math_core import DynamicQuadraticFunc | ||||||
|  | from .math_core import ConstantFunc, ComposedSinFunc | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||||
| @@ -32,7 +33,8 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio | |||||||
|         function = DynamicLinearFunc() |         function = DynamicLinearFunc() | ||||||
|         function_param = dict() |         function_param = dict() | ||||||
|         function_param[0] = ComposedSinFunc( |         function_param[0] = ComposedSinFunc( | ||||||
|             amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(10) |             amplitude_scale=ConstantFunc(1.0), | ||||||
|  |             period_phase_shift=LinearFunc(params={0: 10, 1: 0}), | ||||||
|         ) |         ) | ||||||
|         function_param[1] = ConstantFunc(constant=0.9) |         function_param[1] = ConstantFunc(constant=0.9) | ||||||
|     elif version == "v2": |     elif version == "v2": | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user