Add one more synthetic env
This commit is contained in:
		| @@ -5,7 +5,7 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | |||||||
| from .SearchDatasetWrap import SearchDataset | from .SearchDatasetWrap import SearchDataset | ||||||
|  |  | ||||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | ||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | from .math_dynamic_funcs import DynamicQuadraticFunc, DynamicLinearFunc | ||||||
| from .math_adv_funcs import ConstantFunc | from .math_adv_funcs import ConstantFunc | ||||||
| from .math_adv_funcs import ComposedSinFunc | from .math_adv_funcs import ComposedSinFunc | ||||||
|  |  | ||||||
|   | |||||||
| @@ -59,18 +59,24 @@ class ComposedSinFunc(FitFunc): | |||||||
|         max_amplitude = kwargs.get("max_amplitude", 4) |         max_amplitude = kwargs.get("max_amplitude", 4) | ||||||
|         phase_shift = kwargs.get("phase_shift", 0.0) |         phase_shift = kwargs.get("phase_shift", 0.0) | ||||||
|         # create parameters |         # create parameters | ||||||
|         amplitude_scale = QuadraticFunc( |         if kwargs.get("amplitude_scale", None) is None: | ||||||
|             [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] |             amplitude_scale = QuadraticFunc( | ||||||
|         ) |                 [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] | ||||||
|         fitting_data = [] |             ) | ||||||
|         temp_max_scalar = 2 ** (num_sin_phase - 1) |         else: | ||||||
|         for i in range(num_sin_phase): |             amplitude_scale = kwargs.get("amplitude_scale") | ||||||
|             value = (2 ** i) / temp_max_scalar |         if kwargs.get("period_phase_shift", None) is None: | ||||||
|             next_value = (2 ** (i + 1)) / temp_max_scalar |             fitting_data = [] | ||||||
|             for _phase in (0, 0.25, 0.5, 0.75): |             temp_max_scalar = 2 ** (num_sin_phase - 1) | ||||||
|                 inter_value = value + (next_value - value) * _phase |             for i in range(num_sin_phase): | ||||||
|                 fitting_data.append((inter_value, math.pi * (2 * i + _phase))) |                 value = (2 ** i) / temp_max_scalar | ||||||
|         period_phase_shift = QuarticFunc(fitting_data) |                 next_value = (2 ** (i + 1)) / temp_max_scalar | ||||||
|  |                 for _phase in (0, 0.25, 0.5, 0.75): | ||||||
|  |                     inter_value = value + (next_value - value) * _phase | ||||||
|  |                     fitting_data.append((inter_value, math.pi * (2 * i + _phase))) | ||||||
|  |             period_phase_shift = QuarticFunc(fitting_data) | ||||||
|  |         else: | ||||||
|  |             period_phase_shift = kwargs.get("period_phase_shift") | ||||||
|         self.set( |         self.set( | ||||||
|             dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift) |             dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift) | ||||||
|         ) |         ) | ||||||
|   | |||||||
| @@ -37,6 +37,33 @@ class DynamicFunc(FitFunc): | |||||||
|         return noise_y |         return noise_y | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DynamicLinearFunc(DynamicFunc): | ||||||
|  |     """The dynamic linear function that outputs f(x) = a * x + b. | ||||||
|  |     The a and b is a function of timestamp. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, params=None): | ||||||
|  |         super(DynamicLinearFunc, self).__init__(3, params) | ||||||
|  |  | ||||||
|  |     def __call__(self, x, timestamp=None): | ||||||
|  |         self.check_valid() | ||||||
|  |         if timestamp is None: | ||||||
|  |             timestamp = self._timestamp | ||||||
|  |         a = self._params[0](timestamp) | ||||||
|  |         b = self._params[1](timestamp) | ||||||
|  |         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||||
|  |         a, b = convert_fn(a), convert_fn(b) | ||||||
|  |         return a * x + b | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * x + {b}, timestamp={timestamp})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             timestamp=self._timestamp, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DynamicQuadraticFunc(DynamicFunc): | class DynamicQuadraticFunc(DynamicFunc): | ||||||
|     """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. |     """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. | ||||||
|     The a, b, and c is a function of timestamp. |     The a, b, and c is a function of timestamp. | ||||||
|   | |||||||
| @@ -3,11 +3,20 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| import copy | import copy | ||||||
|  |  | ||||||
| from .math_dynamic_funcs import DynamicQuadraticFunc | from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc | ||||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"): | ||||||
|  |     if indicator == "v1": | ||||||
|  |         return create_example_v1(timestamp_config, num_per_task) | ||||||
|  |     elif indicator == "v2": | ||||||
|  |         return create_example_v2(timestamp_config, num_per_task) | ||||||
|  |     else: | ||||||
|  |         raise ValueError("Unkonwn indicator: {:}".format(indicator)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_example_v1( | def create_example_v1( | ||||||
|     timestamp_config=None, |     timestamp_config=None, | ||||||
|     num_per_task=5000, |     num_per_task=5000, | ||||||
| @@ -35,3 +44,29 @@ def create_example_v1( | |||||||
|  |  | ||||||
|     dynamic_env.set_oracle_map(copy.deepcopy(function)) |     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||||
|     return dynamic_env, function |     return dynamic_env, function | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def create_example_v2( | ||||||
|  |     timestamp_config=None, | ||||||
|  |     num_per_task=5000, | ||||||
|  | ): | ||||||
|  |     mean_generator = ConstantFunc(0) | ||||||
|  |     std_generator = ConstantFunc(1) | ||||||
|  |  | ||||||
|  |     dynamic_env = SyntheticDEnv( | ||||||
|  |         [mean_generator], | ||||||
|  |         [[std_generator]], | ||||||
|  |         num_per_task=num_per_task, | ||||||
|  |         timestamp_config=timestamp_config, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     function = DynamicLinearFunc() | ||||||
|  |     function_param = dict() | ||||||
|  |     function_param[0] = ComposedSinFunc( | ||||||
|  |         amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0) | ||||||
|  |     ) | ||||||
|  |     function_param[1] = ConstantFunc(constant=0.9) | ||||||
|  |     function.set(function_param) | ||||||
|  |  | ||||||
|  |     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||||
|  |     return dynamic_env, function | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ if str(lib_dir) not in sys.path: | |||||||
|  |  | ||||||
| from datasets import QuadraticFunc | from datasets import QuadraticFunc | ||||||
| from datasets import ConstantFunc | from datasets import ConstantFunc | ||||||
|  | from datasets import DynamicLinearFunc | ||||||
| from datasets import DynamicQuadraticFunc | from datasets import DynamicQuadraticFunc | ||||||
| from datasets import ComposedSinFunc | from datasets import ComposedSinFunc | ||||||
|  |  | ||||||
| @@ -50,3 +51,20 @@ class TestDynamicFunc(unittest.TestCase): | |||||||
|  |  | ||||||
|         function.set_timestamp(1) |         function.set_timestamp(1) | ||||||
|         print(function(2)) |         print(function(2)) | ||||||
|  |  | ||||||
|  |     def test_simple_linear(self): | ||||||
|  |         timestamps = 30 | ||||||
|  |         function = DynamicLinearFunc() | ||||||
|  |         function_param = dict() | ||||||
|  |         function_param[0] = ComposedSinFunc( | ||||||
|  |             num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 | ||||||
|  |         ) | ||||||
|  |         function_param[1] = ConstantFunc(constant=0.9) | ||||||
|  |         function.set(function_param) | ||||||
|  |         print(function) | ||||||
|  |  | ||||||
|  |         with self.assertRaises(TypeError) as context: | ||||||
|  |             function(0) | ||||||
|  |  | ||||||
|  |         function.set_timestamp(1) | ||||||
|  |         print(function(2)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user