Add one more synthetic env
This commit is contained in:
		| @@ -5,7 +5,7 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||
| from .SearchDatasetWrap import SearchDataset | ||||
|  | ||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc | ||||
| from .math_adv_funcs import ConstantFunc | ||||
| from .math_adv_funcs import ComposedSinFunc | ||||
|  | ||||
|   | ||||
| @@ -59,9 +59,13 @@ class ComposedSinFunc(FitFunc): | ||||
|         max_amplitude = kwargs.get("max_amplitude", 4) | ||||
|         phase_shift = kwargs.get("phase_shift", 0.0) | ||||
|         # create parameters | ||||
|         if kwargs.get("amplitude_scale", None) is None: | ||||
|             amplitude_scale = QuadraticFunc( | ||||
|                 [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] | ||||
|             ) | ||||
|         else: | ||||
|             amplitude_scale = kwargs.get("amplitude_scale") | ||||
|         if kwargs.get("period_phase_shift", None) is None: | ||||
|             fitting_data = [] | ||||
|             temp_max_scalar = 2 ** (num_sin_phase - 1) | ||||
|             for i in range(num_sin_phase): | ||||
| @@ -71,6 +75,8 @@ class ComposedSinFunc(FitFunc): | ||||
|                     inter_value = value + (next_value - value) * _phase | ||||
|                     fitting_data.append((inter_value, math.pi * (2 * i + _phase))) | ||||
|             period_phase_shift = QuarticFunc(fitting_data) | ||||
|         else: | ||||
|             period_phase_shift = kwargs.get("period_phase_shift") | ||||
|         self.set( | ||||
|             dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift) | ||||
|         ) | ||||
|   | ||||
| @@ -37,6 +37,33 @@ class DynamicFunc(FitFunc): | ||||
|         return noise_y | ||||
|  | ||||
|  | ||||
| class DynamicLinearFunc(DynamicFunc): | ||||
|     """The dynamic linear function that outputs f(x) = a * x + b. | ||||
|     The a and b is a function of timestamp. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params=None): | ||||
|         super(DynamicLinearFunc, self).__init__(3, params) | ||||
|  | ||||
|     def __call__(self, x, timestamp=None): | ||||
|         self.check_valid() | ||||
|         if timestamp is None: | ||||
|             timestamp = self._timestamp | ||||
|         a = self._params[0](timestamp) | ||||
|         b = self._params[1](timestamp) | ||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||
|         a, b = convert_fn(a), convert_fn(b) | ||||
|         return a * x + b | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * x + {b}, timestamp={timestamp})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             timestamp=self._timestamp, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class DynamicQuadraticFunc(DynamicFunc): | ||||
|     """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. | ||||
|     The a, b, and c is a function of timestamp. | ||||
|   | ||||
| @@ -3,11 +3,20 @@ | ||||
| ##################################################### | ||||
| import copy | ||||
|  | ||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | ||||
| from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc | ||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||
| from .synthetic_env import SyntheticDEnv | ||||
|  | ||||
|  | ||||
| def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"): | ||||
|     if indicator == "v1": | ||||
|         return create_example_v1(timestamp_config, num_per_task) | ||||
|     elif indicator == "v2": | ||||
|         return create_example_v2(timestamp_config, num_per_task) | ||||
|     else: | ||||
|         raise ValueError("Unkonwn indicator: {:}".format(indicator)) | ||||
|  | ||||
|  | ||||
| def create_example_v1( | ||||
|     timestamp_config=None, | ||||
|     num_per_task=5000, | ||||
| @@ -35,3 +44,29 @@ def create_example_v1( | ||||
|  | ||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     return dynamic_env, function | ||||
|  | ||||
|  | ||||
| def create_example_v2( | ||||
|     timestamp_config=None, | ||||
|     num_per_task=5000, | ||||
| ): | ||||
|     mean_generator = ConstantFunc(0) | ||||
|     std_generator = ConstantFunc(1) | ||||
|  | ||||
|     dynamic_env = SyntheticDEnv( | ||||
|         [mean_generator], | ||||
|         [[std_generator]], | ||||
|         num_per_task=num_per_task, | ||||
|         timestamp_config=timestamp_config, | ||||
|     ) | ||||
|  | ||||
|     function = DynamicLinearFunc() | ||||
|     function_param = dict() | ||||
|     function_param[0] = ComposedSinFunc( | ||||
|         amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) | ||||
|     ) | ||||
|     function_param[1] = ConstantFunc(constant=0.9) | ||||
|     function.set(function_param) | ||||
|  | ||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||
|     return dynamic_env, function | ||||
|   | ||||
| @@ -15,6 +15,7 @@ if str(lib_dir) not in sys.path: | ||||
|  | ||||
| from datasets import QuadraticFunc | ||||
| from datasets import ConstantFunc | ||||
| from datasets import DynamicLinearFunc | ||||
| from datasets import DynamicQuadraticFunc | ||||
| from datasets import ComposedSinFunc | ||||
|  | ||||
| @@ -50,3 +51,20 @@ class TestDynamicFunc(unittest.TestCase): | ||||
|  | ||||
|         function.set_timestamp(1) | ||||
|         print(function(2)) | ||||
|  | ||||
|     def test_simple_linear(self): | ||||
|         timestamps = 30 | ||||
|         function = DynamicLinearFunc() | ||||
|         function_param = dict() | ||||
|         function_param[0] = ComposedSinFunc( | ||||
|             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||
|         ) | ||||
|         function_param[1] = ConstantFunc(constant=0.9) | ||||
|         function.set(function_param) | ||||
|         print(function) | ||||
|  | ||||
|         with self.assertRaises(TypeError) as context: | ||||
|             function(0) | ||||
|  | ||||
|         function.set_timestamp(1) | ||||
|         print(function(2)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user