Finalize example vis codes
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||||
| ############################################################################ | ############################################################################ | ||||||
| # CUDA_VISIBLE_DEVICES=0 python exps/LFNA/vis-synthetic.py                 # | # python exps/LFNA/vis-synthetic.py                                        # | ||||||
| ############################################################################ | ############################################################################ | ||||||
| import os, sys, copy, random | import os, sys, copy, random | ||||||
| import torch | import torch | ||||||
| @@ -83,7 +83,7 @@ def find_max(cur, others): | |||||||
| def compare_cl(save_dir): | def compare_cl(save_dir): | ||||||
|     save_dir = Path(str(save_dir)) |     save_dir = Path(str(save_dir)) | ||||||
|     save_dir.mkdir(parents=True, exist_ok=True) |     save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|     dynamic_env, function = create_example_v1( |     dynamic_env, cl_function = create_example_v1( | ||||||
|         # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), |         # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), | ||||||
|         timestamp_config=dict(num=200), |         timestamp_config=dict(num=200), | ||||||
|         num_per_task=1000, |         num_per_task=1000, | ||||||
| @@ -91,7 +91,6 @@ def compare_cl(save_dir): | |||||||
|  |  | ||||||
|     models = dict() |     models = dict() | ||||||
|  |  | ||||||
|     cl_function = copy.deepcopy(function) |  | ||||||
|     cl_function.set_timestamp(0) |     cl_function.set_timestamp(0) | ||||||
|     cl_xaxis_min = None |     cl_xaxis_min = None | ||||||
|     cl_xaxis_max = None |     cl_xaxis_max = None | ||||||
| @@ -99,23 +98,15 @@ def compare_cl(save_dir): | |||||||
|     all_data = OrderedDict() |     all_data = OrderedDict() | ||||||
|  |  | ||||||
|     for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): |     for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|         xaxis_all = dataset[:, 0].numpy() |         xaxis_all = dataset[0][:, 0].numpy() | ||||||
|  |         yaxis_all = dataset[1][:, 0].numpy() | ||||||
|         current_data = dict() |         current_data = dict() | ||||||
|  |  | ||||||
|         function.set_timestamp(timestamp) |  | ||||||
|         yaxis_all = function.noise_call(xaxis_all) |  | ||||||
|         current_data["lfna_xaxis_all"] = xaxis_all |         current_data["lfna_xaxis_all"] = xaxis_all | ||||||
|         current_data["lfna_yaxis_all"] = yaxis_all |         current_data["lfna_yaxis_all"] = yaxis_all | ||||||
|  |  | ||||||
|         # compute cl-min |         # compute cl-min | ||||||
|         cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std()) |         cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std()) | ||||||
|         cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) |         cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) | ||||||
|         """ |  | ||||||
|         cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05) |  | ||||||
|         cl_yaxis_all = cl_function.noise_call(cl_xaxis_all) |  | ||||||
|         current_data["cl_xaxis_all"] = cl_xaxis_all |  | ||||||
|         current_data["cl_yaxis_all"] = cl_yaxis_all |  | ||||||
|         """ |  | ||||||
|         all_data[timestamp] = current_data |         all_data[timestamp] = current_data | ||||||
|  |  | ||||||
|     global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1) |     global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1) | ||||||
| @@ -170,10 +161,12 @@ def compare_cl(save_dir): | |||||||
|             xdir=save_dir |             xdir=save_dir | ||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
|     video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(base_cmd, xdir=save_dir) |     video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format( | ||||||
|  |         base_cmd, xdir=save_dir | ||||||
|  |     ) | ||||||
|     print(video_cmd + "\n") |     print(video_cmd + "\n") | ||||||
|     os.system(video_cmd) |     os.system(video_cmd) | ||||||
|     # os.system("{:} {xdir}/vis.webm".format(base_cmd, xdir=save_dir)) |     os.system("{:} -pix_fmt yuv420p {xdir}/vis.webm".format(base_cmd, xdir=save_dir)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|   | |||||||
| @@ -5,7 +5,8 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | |||||||
| from .SearchDatasetWrap import SearchDataset | from .SearchDatasetWrap import SearchDataset | ||||||
|  |  | ||||||
| from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc | ||||||
| from .math_adv_funcs import DynamicQuadraticFunc, ConstantFunc | from .math_dynamic_funcs import DynamicQuadraticFunc | ||||||
|  | from .math_adv_funcs import ConstantFunc | ||||||
| from .math_adv_funcs import ComposedSinFunc | from .math_adv_funcs import ComposedSinFunc | ||||||
|  |  | ||||||
| from .synthetic_utils import TimeStamp | from .synthetic_utils import TimeStamp | ||||||
|   | |||||||
| @@ -14,41 +14,6 @@ from .math_base_funcs import QuadraticFunc | |||||||
| from .math_base_funcs import QuarticFunc | from .math_base_funcs import QuarticFunc | ||||||
|  |  | ||||||
|  |  | ||||||
| class DynamicQuadraticFunc(FitFunc): |  | ||||||
|     """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. |  | ||||||
|     The a, b, and c is a function of timestamp. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     def __init__(self, list_of_points=None): |  | ||||||
|         super(DynamicQuadraticFunc, self).__init__(3, list_of_points) |  | ||||||
|         self._timestamp = None |  | ||||||
|  |  | ||||||
|     def __call__(self, x, timestamp=None): |  | ||||||
|         self.check_valid() |  | ||||||
|         if timestamp is None: |  | ||||||
|             timestamp = self._timestamp |  | ||||||
|         a = self._params[0](timestamp) |  | ||||||
|         b = self._params[1](timestamp) |  | ||||||
|         c = self._params[2](timestamp) |  | ||||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x |  | ||||||
|         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) |  | ||||||
|         return a * x * x + b * x + c |  | ||||||
|  |  | ||||||
|     def _getitem(self, x, weights): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     def set_timestamp(self, timestamp): |  | ||||||
|         self._timestamp = timestamp |  | ||||||
|  |  | ||||||
|     def __repr__(self): |  | ||||||
|         return "{name}({a} * x^2 + {b} * x + {c})".format( |  | ||||||
|             name=self.__class__.__name__, |  | ||||||
|             a=self._params[0], |  | ||||||
|             b=self._params[1], |  | ||||||
|             c=self._params[2], |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ConstantFunc(FitFunc): | class ConstantFunc(FitFunc): | ||||||
|     """The constant function: f(x) = c.""" |     """The constant function: f(x) = c.""" | ||||||
|  |  | ||||||
|   | |||||||
| @@ -13,20 +13,20 @@ import torch.utils.data as data | |||||||
| class FitFunc(abc.ABC): | class FitFunc(abc.ABC): | ||||||
|     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" |     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | ||||||
|  |  | ||||||
|     def __init__(self, freedom: int, list_of_points=None, _params=None): |     def __init__(self, freedom: int, list_of_points=None, params=None): | ||||||
|         self._params = dict() |         self._params = dict() | ||||||
|         for i in range(freedom): |         for i in range(freedom): | ||||||
|             self._params[i] = None |             self._params[i] = None | ||||||
|         self._freedom = freedom |         self._freedom = freedom | ||||||
|         if list_of_points is not None and _params is not None: |         if list_of_points is not None and params is not None: | ||||||
|             raise ValueError("list_of_points and _params can not be set simultaneously") |             raise ValueError("list_of_points and params can not be set simultaneously") | ||||||
|         if list_of_points is not None: |         if list_of_points is not None: | ||||||
|             self.fit(list_of_points=list_of_points) |             self.fit(list_of_points=list_of_points) | ||||||
|         if _params is not None: |         if params is not None: | ||||||
|             self.set(_params) |             self.set(params) | ||||||
|  |  | ||||||
|     def set(self, _params): |     def set(self, params): | ||||||
|         self._params = copy.deepcopy(_params) |         self._params = copy.deepcopy(params) | ||||||
|  |  | ||||||
|     def check_valid(self): |     def check_valid(self): | ||||||
|         for key, value in self._params.items(): |         for key, value in self._params.items(): | ||||||
|   | |||||||
							
								
								
									
										66
									
								
								lib/datasets/math_dynamic_funcs.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								lib/datasets/math_dynamic_funcs.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | import math | ||||||
|  | import abc | ||||||
|  | import copy | ||||||
|  | import numpy as np | ||||||
|  | from typing import Optional | ||||||
|  | import torch | ||||||
|  | import torch.utils.data as data | ||||||
|  |  | ||||||
|  | from .math_base_funcs import FitFunc | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DynamicFunc(FitFunc): | ||||||
|  |     """The dynamic quadratic function, where each param is a function.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, freedom: int, params=None): | ||||||
|  |         super(DynamicFunc, self).__init__(freedom, None, params) | ||||||
|  |         self._timestamp = None | ||||||
|  |  | ||||||
|  |     def __call__(self, x, timestamp=None): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def set_timestamp(self, timestamp): | ||||||
|  |         self._timestamp = timestamp | ||||||
|  |  | ||||||
|  |     def noise_call(self, x, timestamp=None, std=0.1): | ||||||
|  |         clean_y = self.__call__(x, timestamp) | ||||||
|  |         if isinstance(clean_y, np.ndarray): | ||||||
|  |             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||||
|  |         else: | ||||||
|  |             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||||
|  |         return noise_y | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DynamicQuadraticFunc(DynamicFunc): | ||||||
|  |     """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. | ||||||
|  |     The a, b, and c is a function of timestamp. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, params=None): | ||||||
|  |         super(DynamicQuadraticFunc, self).__init__(3, params) | ||||||
|  |  | ||||||
|  |     def __call__(self, x, timestamp=None): | ||||||
|  |         self.check_valid() | ||||||
|  |         if timestamp is None: | ||||||
|  |             timestamp = self._timestamp | ||||||
|  |         a = self._params[0](timestamp) | ||||||
|  |         b = self._params[1](timestamp) | ||||||
|  |         c = self._params[2](timestamp) | ||||||
|  |         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||||
|  |         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||||
|  |         return a * x * x + b * x + c | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({a} * x^2 + {b} * x + {c}, timestamp={timestamp})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             timestamp=self._timestamp, | ||||||
|  |         ) | ||||||
| @@ -41,6 +41,11 @@ class SyntheticDEnv(data.Dataset): | |||||||
|         self._mean_functors = mean_functors |         self._mean_functors = mean_functors | ||||||
|         self._cov_functors = cov_functors |         self._cov_functors = cov_functors | ||||||
|  |  | ||||||
|  |         self._oracle_map = None | ||||||
|  |  | ||||||
|  |     def set_oracle_map(self, functor): | ||||||
|  |         self._oracle_map = functor | ||||||
|  |  | ||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
|         self._iter_num = 0 |         self._iter_num = 0 | ||||||
|         return self |         return self | ||||||
| @@ -63,7 +68,11 @@ class SyntheticDEnv(data.Dataset): | |||||||
|         dataset = np.random.multivariate_normal( |         dataset = np.random.multivariate_normal( | ||||||
|             mean_list, cov_matrix, size=self._num_per_task |             mean_list, cov_matrix, size=self._num_per_task | ||||||
|         ) |         ) | ||||||
|  |         if self._oracle_map is None: | ||||||
|             return timestamp, torch.Tensor(dataset) |             return timestamp, torch.Tensor(dataset) | ||||||
|  |         else: | ||||||
|  |             targets = self._oracle_map.noise_call(dataset, timestamp) | ||||||
|  |             return timestamp, (torch.Tensor(dataset), torch.Tensor(targets)) | ||||||
|  |  | ||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self._timestamp_generator) |         return len(self._timestamp_generator) | ||||||
|   | |||||||
| @@ -1,8 +1,9 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
|  | import copy | ||||||
|  |  | ||||||
| from .math_adv_funcs import DynamicQuadraticFunc | from .math_dynamic_funcs import DynamicQuadraticFunc | ||||||
| from .math_adv_funcs import ConstantFunc, ComposedSinFunc | from .math_adv_funcs import ConstantFunc, ComposedSinFunc | ||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
|  |  | ||||||
| @@ -11,7 +12,6 @@ def create_example_v1( | |||||||
|     timestamp_config=None, |     timestamp_config=None, | ||||||
|     num_per_task=5000, |     num_per_task=5000, | ||||||
| ): | ): | ||||||
|     # timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0), |  | ||||||
|     mean_generator = ComposedSinFunc() |     mean_generator = ComposedSinFunc() | ||||||
|     std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) |     std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) | ||||||
|  |  | ||||||
| @@ -32,4 +32,6 @@ def create_example_v1( | |||||||
|         num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 |         num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 | ||||||
|     ) |     ) | ||||||
|     function.set(function_param) |     function.set(function_param) | ||||||
|  |  | ||||||
|  |     dynamic_env.set_oracle_map(copy.deepcopy(function)) | ||||||
|     return dynamic_env, function |     return dynamic_env, function | ||||||
|   | |||||||
| @@ -6,3 +6,4 @@ black ./lib/datasets | |||||||
| black ./lib/xlayers | black ./lib/xlayers | ||||||
| black ./exps/LFNA | black ./exps/LFNA | ||||||
| black ./exps/trading | black ./exps/trading | ||||||
|  | black ./lib/procedures | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user