Update synthetic codes
This commit is contained in:
		| @@ -103,11 +103,32 @@ class FitFunc(abc.ABC): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class LinearFunc(FitFunc): | ||||
|     """The linear function that outputs f(x) = a * x + b.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None, params=None): | ||||
|         super(LinearFunc, self).__init__(2, list_of_points, params) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x + self._params[1] | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x + weights[1] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x + {b})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuadraticFunc(FitFunc): | ||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(QuadraticFunc, self).__init__(3, list_of_points) | ||||
|     def __init__(self, list_of_points=None, params=None): | ||||
|         super(QuadraticFunc, self).__init__(3, list_of_points, params) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         self.check_valid() | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 # | ||||
| ##################################################### | ||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .math_dynamic_funcs import DynamicLinearFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc | ||||
|   | ||||
| @@ -3,9 +3,10 @@ | ||||
| ##################################################### | ||||
| from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import SyntheticDEnv | ||||
| from .math_dynamic_funcs import DynamicLinearFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||
| from .math_core import LinearFunc | ||||
| from .math_core import DynamicLinearFunc | ||||
| from .math_core import DynamicQuadraticFunc | ||||
| from .math_core import ConstantFunc, ComposedSinFunc | ||||
|  | ||||
|  | ||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||
| @@ -32,7 +33,8 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio | ||||
|         function = DynamicLinearFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(10) | ||||
|             amplitude_scale=ConstantFunc(1.0), | ||||
|             period_phase_shift=LinearFunc(params={0: 10, 1: 0}), | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|     elif version == "v2": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user