Update GeMOSA v4
This commit is contained in:
		| @@ -2,9 +2,9 @@ | ||||
| # Learning to Generate Model One Step Ahead         # | ||||
| ##################################################### | ||||
| # python exps/GeMOSA/main.py --env_version v1 --workers 0 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256 | ||||
| # python exps/GeMOSA/main.py --env_version v2 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256 | ||||
| # python exps/GeMOSA/main.py --env_version v3 --device cuda --lr 0.002 --hidden_dim 32 --meta_batch 256 | ||||
| # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||
| # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
|   | ||||
| @@ -3,7 +3,8 @@ | ||||
| ############################################################################ | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v1                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v3                     # | ||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v4                     # | ||||
| ############################################################################ | ||||
| import os, sys, copy, random | ||||
| import torch | ||||
| @@ -31,8 +32,8 @@ from xautodl.procedures.metric_utils import MSEMetric | ||||
|  | ||||
|  | ||||
| def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): | ||||
|     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths, label=label) | ||||
|     cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=1.5, label=None) | ||||
|     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label) | ||||
|     cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None) | ||||
|  | ||||
|  | ||||
| def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): | ||||
| @@ -186,15 +187,23 @@ def visualize_env(save_dir, version): | ||||
|         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dynamic_env = get_synthetic_env(version=version) | ||||
|     print("env: {:}".format(dynamic_env)) | ||||
|     print("oracle_map: {:}".format(dynamic_env.oracle_map)) | ||||
|     allxs, allys = [], [] | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         allxs.append(allx) | ||||
|         allys.append(ally) | ||||
|     if dynamic_env.meta_info['task'] == 'regression': | ||||
|         allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||
|     print("env: {:}".format(dynamic_env)) | ||||
|     print("oracle_map: {:}".format(dynamic_env.oracle_map)) | ||||
|         print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) | ||||
|         print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) | ||||
|     elif dynamic_env.meta_info['task'] == 'classification': | ||||
|         allxs = torch.cat(allxs) | ||||
|         print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item())) | ||||
|         print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item())) | ||||
|     else: | ||||
|         raise ValueError("Unknown task".format(dynamic_env.meta_info['task'])) | ||||
|  | ||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||
|         dpi, width, height = 30, 1800, 1400 | ||||
|         figsize = width / float(dpi), height / float(dpi) | ||||
| @@ -202,8 +211,21 @@ def visualize_env(save_dir, version): | ||||
|         fig = plt.figure(figsize=figsize) | ||||
|  | ||||
|         cur_ax = fig.add_subplot(1, 1, 1) | ||||
|         if dynamic_env.meta_info['task'] == 'regression': | ||||
|             allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy() | ||||
|         plot_scatter(cur_ax, allx, ally, "k", 0.99, 15, "timestamp={:05d}".format(idx)) | ||||
|             plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)) | ||||
|             cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||
|             cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||
|         elif dynamic_env.meta_info['task'] == 'classification': | ||||
|             positive, negative = ally == 1, ally == 0 | ||||
|             # plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx)) | ||||
|             plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive") | ||||
|             plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative") | ||||
|             cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1)) | ||||
|             cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1)) | ||||
|         else: | ||||
|             raise ValueError("Unknown task".format(dynamic_env.meta_info['task'])) | ||||
|  | ||||
|         cur_ax.set_xlabel("X", fontsize=LabelSize) | ||||
|         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) | ||||
|         for tick in cur_ax.xaxis.get_major_ticks(): | ||||
| @@ -211,10 +233,7 @@ def visualize_env(save_dir, version): | ||||
|                 tick.label.set_rotation(10) | ||||
|         for tick in cur_ax.yaxis.get_major_ticks(): | ||||
|                 tick.label.set_fontsize(LabelSize - font_gap) | ||||
|         cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||
|         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize)    | ||||
|  | ||||
|         pdf_save_path = ( | ||||
|             save_dir | ||||
|             / "pdf-{:}".format(version) | ||||
| @@ -237,7 +256,7 @@ def visualize_env(save_dir, version): | ||||
|     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) | ||||
|  | ||||
|  | ||||
| def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | ||||
| def compare_algs(save_dir, version, alg_dir="./outputs/GeMOSA-synthetic"): | ||||
|     save_dir = Path(str(save_dir)) | ||||
|     for substr in ("pdf", "png"): | ||||
|         sub_save_dir = save_dir / substr | ||||
|   | ||||
| @@ -10,5 +10,10 @@ from .math_static_funcs import ( | ||||
|     ComposedSinSFunc, | ||||
|     ComposedCosSFunc, | ||||
| ) | ||||
| from .math_dynamic_funcs import LinearDFunc, QuadraticDFunc, SinQuadraticDFunc | ||||
| from .math_dynamic_generator import GaussianDGenerator | ||||
| from .math_dynamic_funcs import ( | ||||
|     LinearDFunc, | ||||
|     QuadraticDFunc, | ||||
|     SinQuadraticDFunc, | ||||
|     BinaryQuadraticDFunc, | ||||
| ) | ||||
| from .math_dynamic_generator import UniformDGenerator, GaussianDGenerator | ||||
|   | ||||
| @@ -20,7 +20,9 @@ class DynamicFunc(MathFunc): | ||||
|  | ||||
|     def noise_call(self, x, timestamp, std): | ||||
|         clean_y = self.__call__(x, timestamp) | ||||
|         if isinstance(clean_y, np.ndarray): | ||||
|         if std is None: | ||||
|             noise_y = clean_y | ||||
|         elif isinstance(clean_y, np.ndarray): | ||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||
|         else: | ||||
|             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||
| @@ -43,7 +45,7 @@ class LinearDFunc(DynamicFunc): | ||||
|         return a * x + b | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x} + {b})".format( | ||||
|         return "({a} * {x} + {b})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
| @@ -69,7 +71,7 @@ class QuadraticDFunc(DynamicFunc): | ||||
|         return a * x * x + b * x + c | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|         return "({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
| @@ -97,6 +99,39 @@ class SinQuadraticDFunc(DynamicFunc): | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|             name="Sin", | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             x=self.xstr, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class BinaryQuadraticDFunc(DynamicFunc): | ||||
|     """The dynamic quadratic function that outputs f(x) = a * x[0]^2 + b * x[1] + c >= 0. | ||||
|     The a, b, and c is a function of timestamp. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, params=None): | ||||
|         super(BinaryQuadraticDFunc, self).__init__(3, params) | ||||
|  | ||||
|     def __call__(self, x, timestamp): | ||||
|         self.check_valid() | ||||
|         a = self._params[0](timestamp) | ||||
|         b = self._params[1](timestamp) | ||||
|         c = self._params[2](timestamp) | ||||
|         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||
|         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||
|         if isinstance(x, np.ndarray) and x.shape[-1] == 2: | ||||
|             results = a * x[..., 0] * x[..., 0] + b * x[..., 1] + c | ||||
|             return (results >= 0).astype(np.int) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "Either the type {:} or the shape is incorrect.".format(type(x)) | ||||
|             ) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "({a} * {x}[0]^2 + {b} * {x}[1] + {c} >= 0)".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|   | ||||
| @@ -20,6 +20,37 @@ class DynamicGenerator(abc.ABC): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class UniformDGenerator(DynamicGenerator): | ||||
|     """Generate data from the uniform distribution.""" | ||||
|  | ||||
|     def __init__(self, l_functors, r_functors): | ||||
|         super(UniformDGenerator, self).__init__() | ||||
|         self._ndim = assert_list_tuple(l_functors) | ||||
|         assert self._ndim == assert_list_tuple(r_functors) | ||||
|         self._l_functors = l_functors | ||||
|         self._r_functors = r_functors | ||||
|  | ||||
|     @property | ||||
|     def ndim(self): | ||||
|         return self._ndim | ||||
|  | ||||
|     def output_shape(self): | ||||
|         return (self._ndim,) | ||||
|  | ||||
|     def __call__(self, time, num): | ||||
|         l_list = [functor(time) for functor in self._l_functors] | ||||
|         r_list = [functor(time) for functor in self._r_functors] | ||||
|         values = [] | ||||
|         for l, r in zip(l_list, r_list): | ||||
|             values.append(np.random.uniform(low=l, high=r, size=num)) | ||||
|         return np.stack(values, axis=-1) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({ndim} dims)".format( | ||||
|             name=self.__class__.__name__, ndim=self._ndim | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class GaussianDGenerator(DynamicGenerator): | ||||
|     """Generate data from Gaussian distribution.""" | ||||
|  | ||||
|   | ||||
| @@ -47,7 +47,7 @@ class LinearSFunc(StaticFunc): | ||||
|         return weights[0] * x + weights[1] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x} + {b})".format( | ||||
|         return "({a} * {x} + {b})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
| @@ -69,7 +69,7 @@ class QuadraticSFunc(StaticFunc): | ||||
|         return weights[0] * x * x + weights[1] * x + weights[2] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|         return "({a} * {x}^2 + {b} * {x} + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
| @@ -97,7 +97,7 @@ class CubicSFunc(StaticFunc): | ||||
|         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( | ||||
|         return "({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
| @@ -166,7 +166,7 @@ class ConstantFunc(StaticFunc): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) | ||||
|         return "{a}".format(name=self.__class__.__name__, a=self._params[0]) | ||||
|  | ||||
|  | ||||
| class ComposedSinSFunc(StaticFunc): | ||||
| @@ -188,7 +188,7 @@ class ComposedSinSFunc(StaticFunc): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||
|         return "({a} * sin({b} * {x}) + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
| @@ -216,7 +216,7 @@ class ComposedCosSFunc(StaticFunc): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( | ||||
|         return "({a} * sin({b} * {x}) + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|   | ||||
| @@ -3,13 +3,13 @@ from .synthetic_utils import TimeStamp | ||||
| from .synthetic_env import SyntheticDEnv | ||||
| from .math_core import LinearSFunc | ||||
| from .math_core import LinearDFunc | ||||
| from .math_core import QuadraticDFunc, SinQuadraticDFunc | ||||
| from .math_core import QuadraticDFunc, SinQuadraticDFunc, BinaryQuadraticDFunc | ||||
| from .math_core import ( | ||||
|     ConstantFunc, | ||||
|     ComposedSinSFunc as SinFunc, | ||||
|     ComposedCosSFunc as CosFunc, | ||||
| ) | ||||
| from .math_core import GaussianDGenerator | ||||
| from .math_core import UniformDGenerator, GaussianDGenerator | ||||
|  | ||||
|  | ||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||
| @@ -77,8 +77,21 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | ||||
|         ) | ||||
|         dynamic_env.set_regression() | ||||
|     elif version.lower() == "v4": | ||||
|         l_generator = ConstantFunc(-2) | ||||
|         r_generator = ConstantFunc(2) | ||||
|         data_generator = UniformDGenerator([l_generator] * 2, [r_generator] * 2) | ||||
|         time_generator = TimeStamp( | ||||
|             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||
|         ) | ||||
|         oracle_map = BinaryQuadraticDFunc( | ||||
|             params={ | ||||
|                 0: SinFunc(params={0: 1, 1: 3, 2: 0}),  # sin(3 * t) | ||||
|                 1: CosFunc(params={0: 1, 1: 6, 2: 0}),  # cos(6 * t) | ||||
|                 2: ConstantFunc(0), | ||||
|             } | ||||
|         ) | ||||
|         dynamic_env = SyntheticDEnv( | ||||
|             data_generator, oracle_map, time_generator, num_per_task, noise=0.05 | ||||
|             data_generator, oracle_map, time_generator, num_per_task, noise=None | ||||
|         ) | ||||
|         dynamic_env.set_classification(2) | ||||
|     else: | ||||
|   | ||||
| @@ -119,10 +119,15 @@ class SyntheticDEnv(data.Dataset): | ||||
|     def __call__(self, timestamp): | ||||
|         dataset = self._data_generator(timestamp, self._num_per_task) | ||||
|         targets = self._oracle_map.noise_call(dataset, timestamp, self._noise) | ||||
|         return torch.Tensor([timestamp]), ( | ||||
|             torch.Tensor(dataset), | ||||
|             torch.Tensor(targets), | ||||
|         ) | ||||
|         if isinstance(dataset, np.ndarray): | ||||
|             dataset = torch.from_numpy(dataset) | ||||
|         else: | ||||
|             dataset = torch.Tensor(dataset) | ||||
|         if isinstance(targets, np.ndarray): | ||||
|             targets = torch.from_numpy(targets) | ||||
|         else: | ||||
|             targets = torch.Tensor(targets) | ||||
|         return torch.Tensor([timestamp]), (dataset, targets) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._time_generator) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user