From e1818694a4ee82f584d332319b78caf9f6ec822a Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Mon, 26 Apr 2021 05:16:38 -0700 Subject: [PATCH] Reformulate Math Functions --- .github/workflows/basic_test.yml | 1 + .latent-data/NATS-Bench | 2 +- exps/LFNA/vis-synthetic.py | 14 ++-- lib/datasets/__init__.py | 6 +- lib/datasets/math_adv_funcs.py | 121 ++++++++++++++++++++++++++++++ lib/datasets/math_base_funcs.py | 59 ++++----------- lib/datasets/synthetic_env.py | 53 ++++++------- lib/datasets/synthetic_example.py | 23 +++--- lib/datasets/synthetic_utils.py | 101 +++---------------------- tests/test_math_adv.py | 52 +++++++++++++ tests/test_math_base.py | 41 ++++++++++ tests/test_super_container.py | 2 +- tests/test_synthetic_env.py | 8 +- tests/test_synthetic_utils.py | 79 +++---------------- 14 files changed, 308 insertions(+), 254 deletions(-) create mode 100644 lib/datasets/math_adv_funcs.py create mode 100644 tests/test_math_adv.py create mode 100644 tests/test_math_base.py diff --git a/.github/workflows/basic_test.yml b/.github/workflows/basic_test.yml index 9180687..0919eab 100644 --- a/.github/workflows/basic_test.yml +++ b/.github/workflows/basic_test.yml @@ -56,5 +56,6 @@ jobs: python -m pip install parameterized python -m pip install torch torchvision python --version + python -m pytest ./tests/test_math*.py -s python -m pytest ./tests/test_synthetic*.py -s shell: bash diff --git a/.latent-data/NATS-Bench b/.latent-data/NATS-Bench index 47de7e1..8756c33 160000 --- a/.latent-data/NATS-Bench +++ b/.latent-data/NATS-Bench @@ -1 +1 @@ -Subproject commit 47de7e1508536512ece82e0add082e0547cc7996 +Subproject commit 8756c33d85b8c9d4031ded28dcbb50750bc886be diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index b9a5f18..7eefc0a 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -31,10 +31,10 @@ from datasets.synthetic_example import create_example_v1 from utils.temp_sync import optimize_fn, evaluate_fn -def draw_multi_fig(save_dir, timestamp, scatter_list, fig_title=None): +def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): save_path = save_dir / "{:04d}".format(timestamp) # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path)) - dpi, width, height = 40, 2000, 1300 + dpi, width, height = 40, wh[0], wh[1] figsize = width / float(dpi), height / float(dpi) LabelSize, LegendFontsize, font_gap = 80, 80, 5 @@ -61,8 +61,7 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, fig_title=None): tick.label.set_rotation(10) for tick in cur_ax.yaxis.get_major_ticks(): tick.label.set_fontsize(LabelSize - font_gap) - - plt.legend(loc=1, fontsize=LegendFontsize) + plt.legend(loc=1, fontsize=LegendFontsize) fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") plt.close("all") @@ -115,18 +114,19 @@ def compare_cl(save_dir): "color": "r", "s": 10, "xlim": (-6, 6 + timestamp * 0.2), - "ylim": (-200, 40), + "ylim": (-40, 40), "alpha": 0.99, "label": "Continual Learning", } ) draw_multi_fig( - save_dir, timestamp, scatter_list, "Timestamp={:03d}".format(timestamp) + save_dir, timestamp, scatter_list, + wh=(2000, 1300), fig_title="Timestamp={:03d}".format(timestamp) ) print("Save all figures into {:}".format(save_dir)) save_dir = save_dir.resolve() - cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1500:1000 -vb 5000k {xdir}/vis.mp4".format( + cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=2000:1300 -vb 5000k {xdir}/vis.mp4".format( xdir=save_dir ) os.system(cmd) diff --git a/lib/datasets/__init__.py b/lib/datasets/__init__.py index 3b5800b..62e79cc 100644 --- a/lib/datasets/__init__.py +++ b/lib/datasets/__init__.py @@ -5,6 +5,8 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders from .SearchDatasetWrap import SearchDataset from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc -from .math_base_funcs import DynamicQuadraticFunc -from .synthetic_utils import SinGenerator, ConstantGenerator +from .math_adv_funcs import DynamicQuadraticFunc, ConstantFunc +from .math_adv_funcs import ComposedSinFunc + +from .synthetic_utils import TimeStamp from .synthetic_env import SyntheticDEnv diff --git a/lib/datasets/math_adv_funcs.py b/lib/datasets/math_adv_funcs.py new file mode 100644 index 0000000..4315258 --- /dev/null +++ b/lib/datasets/math_adv_funcs.py @@ -0,0 +1,121 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +import math +import abc +import copy +import numpy as np +from typing import Optional +import torch +import torch.utils.data as data + +from .math_base_funcs import FitFunc +from .math_base_funcs import QuadraticFunc +from .math_base_funcs import QuarticFunc + + +class DynamicQuadraticFunc(FitFunc): + """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c. + The a, b, and c is a function of timestamp. + """ + + def __init__(self, list_of_points=None): + super(DynamicQuadraticFunc, self).__init__(3, list_of_points) + self._timestamp = None + + def __call__(self, x, timestamp=None): + self.check_valid() + if timestamp is None: + timestamp = self._timestamp + a = self._params[0](timestamp) + b = self._params[1](timestamp) + c = self._params[2](timestamp) + convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x + a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) + return a * x * x + b * x + c + + def _getitem(self, x, weights): + raise NotImplementedError + + def set_timestamp(self, timestamp): + self._timestamp = timestamp + + def __repr__(self): + return "{name}({a} * x^2 + {b} * x + {c})".format( + name=self.__class__.__name__, + a=self._params[0], + b=self._params[1], + c=self._params[2], + ) + + +class ConstantFunc(FitFunc): + """The constant function: f(x) = c.""" + + def __init__(self, constant=None): + param = dict() + param[0] = constant + super(ConstantFunc, self).__init__(0, None, param) + + def __call__(self, x): + self.check_valid() + return self._params[0] + + def fit(self, **kwargs): + raise NotImplementedError + + def _getitem(self, x, weights): + raise NotImplementedError + + def __repr__(self): + return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) + + +class ComposedSinFunc(FitFunc): + """The composed sin function that outputs: + f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) + - the amplitude scale is a quadratic function of x + - the period-phase-shift is another quadratic function of x + """ + + def __init__(self, **kwargs): + super(ComposedSinFunc, self).__init__(0, None) + self.fit(**kwargs) + + def __call__(self, x): + self.check_valid() + scale = self._params["amplitude_scale"](x) + period_phase = self._params["period_phase_shift"](x) + return scale * math.sin(period_phase) + + def fit(self, **kwargs): + num_sin_phase = kwargs.get("num_sin_phase", 7) + min_amplitude = kwargs.get("min_amplitude", 1) + max_amplitude = kwargs.get("max_amplitude", 4) + phase_shift = kwargs.get("phase_shift", 0.0) + # create parameters + amplitude_scale = QuadraticFunc( + [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] + ) + fitting_data = [] + temp_max_scalar = 2 ** (num_sin_phase - 1) + for i in range(num_sin_phase): + value = (2 ** i) / temp_max_scalar + next_value = (2 ** (i + 1)) / temp_max_scalar + for _phase in (0, 0.25, 0.5, 0.75): + inter_value = value + (next_value - value) * _phase + fitting_data.append((inter_value, math.pi * (2 * i + _phase))) + period_phase_shift = QuarticFunc(fitting_data) + self.set( + dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift) + ) + + def _getitem(self, x, weights): + raise NotImplementedError + + def __repr__(self): + return "{name}({amplitude_scale} * sin({period_phase_shift}))".format( + name=self.__class__.__name__, + amplitude_scale=self._params["amplitude_scale"], + period_phase_shift=self._params["period_phase_shift"], + ) diff --git a/lib/datasets/math_base_funcs.py b/lib/datasets/math_base_funcs.py index 177fe3f..cab66a2 100644 --- a/lib/datasets/math_base_funcs.py +++ b/lib/datasets/math_base_funcs.py @@ -13,13 +13,17 @@ import torch.utils.data as data class FitFunc(abc.ABC): """The fit function that outputs f(x) = a * x^2 + b * x + c.""" - def __init__(self, freedom: int, list_of_points=None): + def __init__(self, freedom: int, list_of_points=None, _params=None): self._params = dict() for i in range(freedom): self._params[i] = None self._freedom = freedom + if list_of_points is not None and _params is not None: + raise ValueError("list_of_points and _params can not be set simultaneously") if list_of_points is not None: - self.fit(list_of_points) + self.fit(list_of_points=list_of_points) + if _params is not None: + self.set(_params) def set(self, _params): self._params = copy.deepcopy(_params) @@ -45,13 +49,13 @@ class FitFunc(abc.ABC): def _getitem(self, x): raise NotImplementedError - def fit( - self, - list_of_points, - max_iter=900, - lr_max=1.0, - verbose=False, - ): + def fit(self, **kwargs): + list_of_points = kwargs["list_of_points"] + max_iter, lr_max, verbose = ( + kwargs.get("max_iter", 900), + kwargs.get("lr_max", 1.0), + kwargs.get("verbose", False), + ) with torch.no_grad(): data = torch.Tensor(list_of_points).type(torch.float32) assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format( @@ -113,7 +117,7 @@ class QuadraticFunc(FitFunc): return weights[0] * x * x + weights[1] * x + weights[2] def __repr__(self): - return "{name}(y = {a} * x^2 + {b} * x + {c})".format( + return "{name}({a} * x^2 + {b} * x + {c})".format( name=self.__class__.__name__, a=self._params[0], b=self._params[1], @@ -140,7 +144,7 @@ class CubicFunc(FitFunc): return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] def __repr__(self): - return "{name}(y = {a} * x^3 + {b} * x^2 + {c} * x + {d})".format( + return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format( name=self.__class__.__name__, a=self._params[0], b=self._params[1], @@ -175,7 +179,7 @@ class QuarticFunc(FitFunc): ) def __repr__(self): - return "{name}(y = {a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( + return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( name=self.__class__.__name__, a=self._params[0], b=self._params[1], @@ -183,34 +187,3 @@ class QuarticFunc(FitFunc): d=self._params[3], e=self._params[3], ) - - -class DynamicQuadraticFunc(FitFunc): - """The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.""" - - def __init__(self, list_of_points=None): - super(DynamicQuadraticFunc, self).__init__(3, list_of_points) - self._timestamp = None - - def __call__(self, x): - self.check_valid() - a = self._params[0][self._timestamp] - b = self._params[1][self._timestamp] - c = self._params[2][self._timestamp] - convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x - a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) - return a * x * x + b * x + c - - def _getitem(self, x, weights): - raise NotImplementedError - - def set_timestamp(self, timestamp): - self._timestamp = timestamp - - def __repr__(self): - return "{name}(y = {a} * x^2 + {b} * x + {c})".format( - name=self.__class__.__name__, - a=self._params[0], - b=self._params[1], - c=self._params[2], - ) diff --git a/lib/datasets/synthetic_env.py b/lib/datasets/synthetic_env.py index db396e4..9d64f3b 100644 --- a/lib/datasets/synthetic_env.py +++ b/lib/datasets/synthetic_env.py @@ -4,45 +4,42 @@ import math import abc import numpy as np -from typing import List, Optional +from typing import List, Optional, Dict import torch import torch.utils.data as data -from .synthetic_utils import UnifiedSplit +from .synthetic_utils import TimeStamp -class SyntheticDEnv(UnifiedSplit, data.Dataset): +class SyntheticDEnv(data.Dataset): """The synethtic dynamic environment.""" def __init__( self, - mean_generators: List[data.Dataset], - cov_generators: List[List[data.Dataset]], + mean_functors: List[data.Dataset], + cov_functors: List[List[data.Dataset]], num_per_task: int = 5000, + time_stamp_config: Optional[Dict] = None, mode: Optional[str] = None, ): - self._ndim = len(mean_generators) + self._ndim = len(mean_functors) assert self._ndim == len( - cov_generators - ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_generators)) - for cov_generator in cov_generators: + cov_functors + ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors)) + for cov_functor in cov_functors: assert self._ndim == len( - cov_generator - ), "length does not match {:} vs. {:}".format( - self._ndim, len(cov_generator) - ) + cov_functor + ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor)) self._num_per_task = num_per_task - self._total_num = len(mean_generators[0]) - for mean_generator in mean_generators: - assert self._total_num == len(mean_generator) - for cov_generator in cov_generators: - for cov_g in cov_generator: - assert self._total_num == len(cov_g) + if time_stamp_config is None: + time_stamp_config = dict(mode=mode) + else: + time_stamp_config["mode"] = mode - self._mean_generators = mean_generators - self._cov_generators = cov_generators + self._timestamp_generator = TimeStamp(**time_stamp_config) - UnifiedSplit.__init__(self, self._total_num, mode) + self._mean_functors = mean_functors + self._cov_functors = cov_functors def __iter__(self): self._iter_num = 0 @@ -56,11 +53,11 @@ class SyntheticDEnv(UnifiedSplit, data.Dataset): def __getitem__(self, index): assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) - index = self._indexes[index] - mean_list = [generator[index][-1] for generator in self._mean_generators] + index, timestamp = self._timestamp_generator[index] + mean_list = [functor(timestamp) for functor in self._mean_functors] cov_matrix = [ - [cov_gen[index][-1] for cov_gen in cov_generator] - for cov_generator in self._cov_generators + [cov_gen(timestamp) for cov_gen in cov_functor] + for cov_functor in self._cov_functors ] dataset = np.random.multivariate_normal( @@ -69,13 +66,13 @@ class SyntheticDEnv(UnifiedSplit, data.Dataset): return index, torch.Tensor(dataset) def __len__(self): - return len(self._indexes) + return len(self._timestamp_generator) def __repr__(self): return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task})".format( name=self.__class__.__name__, cur_num=len(self), - total=self._total_num, + total=len(self._timestamp_generator), ndim=self._ndim, num_per_task=self._num_per_task, ) diff --git a/lib/datasets/synthetic_example.py b/lib/datasets/synthetic_example.py index 55e0fa4..0fd780c 100644 --- a/lib/datasets/synthetic_example.py +++ b/lib/datasets/synthetic_example.py @@ -3,25 +3,30 @@ ##################################################### from .math_base_funcs import DynamicQuadraticFunc -from .synthetic_utils import ConstantGenerator, SinGenerator +from .math_adv_funcs import ConstantFunc, ComposedSinFunc from .synthetic_env import SyntheticDEnv def create_example_v1(timestamps=50, num_per_task=5000): - mean_generator = SinGenerator(num=timestamps) - std_generator = SinGenerator(num=timestamps, min_amplitude=0.5, max_amplitude=0.5) + mean_generator = ComposedSinFunc() + std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) std_generator.set_transform(lambda x: x + 1) + dynamic_env = SyntheticDEnv( - [mean_generator], [[std_generator]], num_per_task=num_per_task + [mean_generator], + [[std_generator]], + num_per_task=num_per_task, + time_stamp_config=dict(num=timestamps), ) + function = DynamicQuadraticFunc() function_param = dict() - function_param[0] = SinGenerator( - num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 + function_param[0] = ComposedSinFunc( + num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 ) - function_param[1] = ConstantGenerator(constant=0.9) - function_param[2] = SinGenerator( - num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 + function_param[1] = ConstantFunc(constant=0.9) + function_param[2] = ComposedSinFunc( + num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 ) function.set(function_param) return dynamic_env, function diff --git a/lib/datasets/synthetic_utils.py b/lib/datasets/synthetic_utils.py index ce286bc..ced9119 100644 --- a/lib/datasets/synthetic_utils.py +++ b/lib/datasets/synthetic_utils.py @@ -8,8 +8,6 @@ from typing import Optional import torch import torch.utils.data as data -from .math_base_funcs import QuadraticFunc, QuarticFunc - class UnifiedSplit: """A class to unify the split strategy.""" @@ -39,102 +37,20 @@ class UnifiedSplit: return self._mode -class SinGenerator(UnifiedSplit, data.Dataset): - """The synethtic generator for the dynamically changing environment. - - - x in [0, 1] - - y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) - - where - - the amplitude scale is a quadratic function of x - - the period-phase-shift is another quadratic function of x - - """ +class TimeStamp(UnifiedSplit, data.Dataset): + """The timestamp dataset.""" def __init__( self, + min_timestamp: float = 0.0, + max_timestamp: float = 1.0, num: int = 100, - num_sin_phase: int = 7, - min_amplitude: float = 1, - max_amplitude: float = 4, - phase_shift: float = 0, mode: Optional[str] = None, ): - self._amplitude_scale = QuadraticFunc( - [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] - ) - - self._num_sin_phase = num_sin_phase - self._interval = 1.0 / (float(num) - 1) + self._min_timestamp = min_timestamp + self._max_timestamp = max_timestamp + self._interval = (max_timestamp - min_timestamp) / (float(num) - 1) self._total_num = num - - fitting_data = [] - temp_max_scalar = 2 ** (num_sin_phase - 1) - for i in range(num_sin_phase): - value = (2 ** i) / temp_max_scalar - next_value = (2 ** (i + 1)) / temp_max_scalar - for _phase in (0, 0.25, 0.5, 0.75): - inter_value = value + (next_value - value) * _phase - fitting_data.append((inter_value, math.pi * (2 * i + _phase))) - self._period_phase_shift = QuarticFunc(fitting_data) - UnifiedSplit.__init__(self, self._total_num, mode) - self._transform = None - - def __iter__(self): - self._iter_num = 0 - return self - - def __next__(self): - if self._iter_num >= len(self): - raise StopIteration - self._iter_num += 1 - return self.__getitem__(self._iter_num - 1) - - def set_transform(self, transform): - self._transform = transform - - def transform(self, x): - if self._transform is None: - return x - else: - return self._transform(x) - - def __getitem__(self, index): - assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) - index = self._indexes[index] - position = self._interval * index - value = self._amplitude_scale(position) * math.sin( - self._period_phase_shift(position) - ) - return index, position, self.transform(value) - - def __len__(self): - return len(self._indexes) - - def __repr__(self): - return ( - "{name}({cur_num:}/{total} elements,\n" - "amplitude={amplitude},\n" - "period_phase_shift={period_phase_shift})".format( - name=self.__class__.__name__, - cur_num=len(self), - total=self._total_num, - amplitude=self._amplitude_scale, - period_phase_shift=self._period_phase_shift, - ) - ) - - -class ConstantGenerator(UnifiedSplit, data.Dataset): - """The constant generator.""" - - def __init__( - self, - num: int = 100, - constant: float = 0.1, - mode: Optional[str] = None, - ): - self._total_num = num - self._constant = constant UnifiedSplit.__init__(self, self._total_num, mode) def __iter__(self): @@ -150,7 +66,8 @@ class ConstantGenerator(UnifiedSplit, data.Dataset): def __getitem__(self, index): assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) index = self._indexes[index] - return index, index, self._constant + timestamp = self._min_timestamp + self._interval * index + return index, timestamp def __len__(self): return len(self._indexes) diff --git a/tests/test_math_adv.py b/tests/test_math_adv.py new file mode 100644 index 0000000..d9ac1d0 --- /dev/null +++ b/tests/test_math_adv.py @@ -0,0 +1,52 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +# pytest tests/test_math_adv.py -s # +##################################################### +import sys, random +import unittest +import pytest +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +print("library path: {:}".format(lib_dir)) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) + +from datasets import QuadraticFunc +from datasets import ConstantFunc +from datasets import DynamicQuadraticFunc +from datasets import ComposedSinFunc + + +class TestConstantFunc(unittest.TestCase): + """Test the constant function.""" + + def test_simple(self): + function = ConstantFunc(0.1) + for i in range(100): + assert function(i) == 0.1 + + +class TestDynamicFunc(unittest.TestCase): + """Test DynamicQuadraticFunc.""" + + def test_simple(self): + timestamps = 30 + function = DynamicQuadraticFunc() + function_param = dict() + function_param[0] = ComposedSinFunc( + num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 + ) + function_param[1] = ConstantFunc(constant=0.9) + function_param[2] = ComposedSinFunc( + num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 + ) + function.set(function_param) + print(function) + + with self.assertRaises(TypeError) as context: + function(0) + + function.set_timestamp(1) + print(function(2)) diff --git a/tests/test_math_base.py b/tests/test_math_base.py new file mode 100644 index 0000000..5512fd5 --- /dev/null +++ b/tests/test_math_base.py @@ -0,0 +1,41 @@ +##################################################### +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # +##################################################### +# pytest tests/test_math_base.py -s # +##################################################### +import sys, random +import unittest +import pytest +from pathlib import Path + +lib_dir = (Path(__file__).parent / ".." / "lib").resolve() +print("library path: {:}".format(lib_dir)) +if str(lib_dir) not in sys.path: + sys.path.insert(0, str(lib_dir)) + +from datasets import QuadraticFunc + + +class TestQuadraticFunc(unittest.TestCase): + """Test the quadratic function.""" + + def test_simple(self): + function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) + print(function) + for x in (0, 0.5, 1): + print("f({:})={:}".format(x, function(x))) + thresh = 0.2 + self.assertTrue(abs(function(0) - 1) < thresh) + self.assertTrue(abs(function(0.5) - 4) < thresh) + self.assertTrue(abs(function(1) - 1) < thresh) + + def test_none(self): + function = QuadraticFunc() + function.fit( + list_of_points=[[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False + ) + print(function) + thresh = 0.15 + self.assertTrue(abs(function(0) - 1) < thresh) + self.assertTrue(abs(function(0.5) - 4) < thresh) + self.assertTrue(abs(function(1) - 1) < thresh) diff --git a/tests/test_super_container.py b/tests/test_super_container.py index affa107..8b8aa70 100644 --- a/tests/test_super_container.py +++ b/tests/test_super_container.py @@ -79,7 +79,7 @@ def test_super_sequential_v1(): super_core.SuperSimpleNorm(1, 1), torch.nn.ReLU(), super_core.SuperLinear(10, 10), - super_core.SuperReLU() + super_core.SuperReLU(), ) inputs = torch.rand(10, 10) print(model) diff --git a/tests/test_synthetic_env.py b/tests/test_synthetic_env.py index 3db40b0..ac1fe0b 100644 --- a/tests/test_synthetic_env.py +++ b/tests/test_synthetic_env.py @@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from datasets import ConstantGenerator, SinGenerator +from datasets import ConstantFunc, ComposedSinFunc from datasets import SyntheticDEnv @@ -21,10 +21,10 @@ class TestSynethicEnv(unittest.TestCase): """Test the synethtic environment.""" def test_simple(self): - mean_generator = SinGenerator() - std_generator = ConstantGenerator(constant=0.5) + mean_generator = ComposedSinFunc(constant=0.1) + std_generator = ConstantFunc(constant=0.5) - dataset = SyntheticDEnv([mean_generator], [[std_generator]]) + dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000) print(dataset) for timestamp, tau in dataset: assert tau.shape == (5000, 1) diff --git a/tests/test_synthetic_utils.py b/tests/test_synthetic_utils.py index 17474da..2f95884 100644 --- a/tests/test_synthetic_utils.py +++ b/tests/test_synthetic_utils.py @@ -13,74 +13,19 @@ print("library path: {:}".format(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from datasets import QuadraticFunc -from datasets import ConstantGenerator, SinGenerator -from datasets import DynamicQuadraticFunc +from datasets import TimeStamp -class TestQuadraticFunc(unittest.TestCase): - """Test the quadratic function.""" +class TestTimeStamp(unittest.TestCase): + """Test the timestamp generator.""" def test_simple(self): - function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) - print(function) - for x in (0, 0.5, 1): - print("f({:})={:}".format(x, function(x))) - thresh = 0.2 - self.assertTrue(abs(function(0) - 1) < thresh) - self.assertTrue(abs(function(0.5) - 4) < thresh) - self.assertTrue(abs(function(1) - 1) < thresh) - - def test_none(self): - function = QuadraticFunc() - function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False) - print(function) - thresh = 0.15 - self.assertTrue(abs(function(0) - 1) < thresh) - self.assertTrue(abs(function(0.5) - 4) < thresh) - self.assertTrue(abs(function(1) - 1) < thresh) - - -class TestConstantGenerator(unittest.TestCase): - """Test the constant data generator.""" - - def test_simple(self): - dataset = ConstantGenerator() - for i, (idx, t, x) in enumerate(dataset): - assert i == idx, "First loop: {:} vs {:}".format(i, idx) - assert x == 0.1 - - -class TestSinGenerator(unittest.TestCase): - """Test the synethtic data generator.""" - - def test_simple(self): - dataset = SinGenerator() - for i, (idx, t, x) in enumerate(dataset): - assert i == idx, "First loop: {:} vs {:}".format(i, idx) - for i, (idx, t, x) in enumerate(dataset): - assert i == idx, "Second loop: {:} vs {:}".format(i, idx) - - -class TestDynamicFunc(unittest.TestCase): - """Test DynamicQuadraticFunc.""" - - def test_simple(self): - timestamps = 30 - function = DynamicQuadraticFunc() - function_param = dict() - function_param[0] = SinGenerator( - num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 - ) - function_param[1] = ConstantGenerator(constant=0.9) - function_param[2] = SinGenerator( - num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 - ) - function.set(function_param) - print(function) - - with self.assertRaises(TypeError) as context: - function(0) - - function.set_timestamp(1) - print(function(2)) + for mode in (None, "train", "valid", "test"): + generator = TimeStamp(0, 1) + print(generator) + for idx, (i, xtime) in enumerate(generator): + self.assertTrue(i == idx) + if idx == 0: + self.assertTrue(xtime == 0) + if idx + 1 == len(generator): + self.assertTrue(abs(xtime - 1) < 1e-8)