Reformulate Math Functions

This commit is contained in:
D-X-Y 2021-04-26 05:16:38 -07:00
parent 1980779053
commit e1818694a4
14 changed files with 308 additions and 254 deletions

View File

@ -56,5 +56,6 @@ jobs:
python -m pip install parameterized python -m pip install parameterized
python -m pip install torch torchvision python -m pip install torch torchvision
python --version python --version
python -m pytest ./tests/test_math*.py -s
python -m pytest ./tests/test_synthetic*.py -s python -m pytest ./tests/test_synthetic*.py -s
shell: bash shell: bash

@ -1 +1 @@
Subproject commit 47de7e1508536512ece82e0add082e0547cc7996 Subproject commit 8756c33d85b8c9d4031ded28dcbb50750bc886be

View File

@ -31,10 +31,10 @@ from datasets.synthetic_example import create_example_v1
from utils.temp_sync import optimize_fn, evaluate_fn from utils.temp_sync import optimize_fn, evaluate_fn
def draw_multi_fig(save_dir, timestamp, scatter_list, fig_title=None): def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
save_path = save_dir / "{:04d}".format(timestamp) save_path = save_dir / "{:04d}".format(timestamp)
# print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path)) # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))
dpi, width, height = 40, 2000, 1300 dpi, width, height = 40, wh[0], wh[1]
figsize = width / float(dpi), height / float(dpi) figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize, font_gap = 80, 80, 5 LabelSize, LegendFontsize, font_gap = 80, 80, 5
@ -61,8 +61,7 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, fig_title=None):
tick.label.set_rotation(10) tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks(): for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap) tick.label.set_fontsize(LabelSize - font_gap)
plt.legend(loc=1, fontsize=LegendFontsize)
plt.legend(loc=1, fontsize=LegendFontsize)
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
plt.close("all") plt.close("all")
@ -115,18 +114,19 @@ def compare_cl(save_dir):
"color": "r", "color": "r",
"s": 10, "s": 10,
"xlim": (-6, 6 + timestamp * 0.2), "xlim": (-6, 6 + timestamp * 0.2),
"ylim": (-200, 40), "ylim": (-40, 40),
"alpha": 0.99, "alpha": 0.99,
"label": "Continual Learning", "label": "Continual Learning",
} }
) )
draw_multi_fig( draw_multi_fig(
save_dir, timestamp, scatter_list, "Timestamp={:03d}".format(timestamp) save_dir, timestamp, scatter_list,
wh=(2000, 1300), fig_title="Timestamp={:03d}".format(timestamp)
) )
print("Save all figures into {:}".format(save_dir)) print("Save all figures into {:}".format(save_dir))
save_dir = save_dir.resolve() save_dir = save_dir.resolve()
cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1500:1000 -vb 5000k {xdir}/vis.mp4".format( cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=2000:1300 -vb 5000k {xdir}/vis.mp4".format(
xdir=save_dir xdir=save_dir
) )
os.system(cmd) os.system(cmd)

View File

@ -5,6 +5,8 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
from .SearchDatasetWrap import SearchDataset from .SearchDatasetWrap import SearchDataset
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
from .math_base_funcs import DynamicQuadraticFunc from .math_adv_funcs import DynamicQuadraticFunc, ConstantFunc
from .synthetic_utils import SinGenerator, ConstantGenerator from .math_adv_funcs import ComposedSinFunc
from .synthetic_utils import TimeStamp
from .synthetic_env import SyntheticDEnv from .synthetic_env import SyntheticDEnv

View File

@ -0,0 +1,121 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import abc
import copy
import numpy as np
from typing import Optional
import torch
import torch.utils.data as data
from .math_base_funcs import FitFunc
from .math_base_funcs import QuadraticFunc
from .math_base_funcs import QuarticFunc
class DynamicQuadraticFunc(FitFunc):
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
The a, b, and c is a function of timestamp.
"""
def __init__(self, list_of_points=None):
super(DynamicQuadraticFunc, self).__init__(3, list_of_points)
self._timestamp = None
def __call__(self, x, timestamp=None):
self.check_valid()
if timestamp is None:
timestamp = self._timestamp
a = self._params[0](timestamp)
b = self._params[1](timestamp)
c = self._params[2](timestamp)
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
return a * x * x + b * x + c
def _getitem(self, x, weights):
raise NotImplementedError
def set_timestamp(self, timestamp):
self._timestamp = timestamp
def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
)
class ConstantFunc(FitFunc):
"""The constant function: f(x) = c."""
def __init__(self, constant=None):
param = dict()
param[0] = constant
super(ConstantFunc, self).__init__(0, None, param)
def __call__(self, x):
self.check_valid()
return self._params[0]
def fit(self, **kwargs):
raise NotImplementedError
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
class ComposedSinFunc(FitFunc):
"""The composed sin function that outputs:
f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
- the amplitude scale is a quadratic function of x
- the period-phase-shift is another quadratic function of x
"""
def __init__(self, **kwargs):
super(ComposedSinFunc, self).__init__(0, None)
self.fit(**kwargs)
def __call__(self, x):
self.check_valid()
scale = self._params["amplitude_scale"](x)
period_phase = self._params["period_phase_shift"](x)
return scale * math.sin(period_phase)
def fit(self, **kwargs):
num_sin_phase = kwargs.get("num_sin_phase", 7)
min_amplitude = kwargs.get("min_amplitude", 1)
max_amplitude = kwargs.get("max_amplitude", 4)
phase_shift = kwargs.get("phase_shift", 0.0)
# create parameters
amplitude_scale = QuadraticFunc(
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
)
fitting_data = []
temp_max_scalar = 2 ** (num_sin_phase - 1)
for i in range(num_sin_phase):
value = (2 ** i) / temp_max_scalar
next_value = (2 ** (i + 1)) / temp_max_scalar
for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
period_phase_shift = QuarticFunc(fitting_data)
self.set(
dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift)
)
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({amplitude_scale} * sin({period_phase_shift}))".format(
name=self.__class__.__name__,
amplitude_scale=self._params["amplitude_scale"],
period_phase_shift=self._params["period_phase_shift"],
)

View File

@ -13,13 +13,17 @@ import torch.utils.data as data
class FitFunc(abc.ABC): class FitFunc(abc.ABC):
"""The fit function that outputs f(x) = a * x^2 + b * x + c.""" """The fit function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, freedom: int, list_of_points=None): def __init__(self, freedom: int, list_of_points=None, _params=None):
self._params = dict() self._params = dict()
for i in range(freedom): for i in range(freedom):
self._params[i] = None self._params[i] = None
self._freedom = freedom self._freedom = freedom
if list_of_points is not None and _params is not None:
raise ValueError("list_of_points and _params can not be set simultaneously")
if list_of_points is not None: if list_of_points is not None:
self.fit(list_of_points) self.fit(list_of_points=list_of_points)
if _params is not None:
self.set(_params)
def set(self, _params): def set(self, _params):
self._params = copy.deepcopy(_params) self._params = copy.deepcopy(_params)
@ -45,13 +49,13 @@ class FitFunc(abc.ABC):
def _getitem(self, x): def _getitem(self, x):
raise NotImplementedError raise NotImplementedError
def fit( def fit(self, **kwargs):
self, list_of_points = kwargs["list_of_points"]
list_of_points, max_iter, lr_max, verbose = (
max_iter=900, kwargs.get("max_iter", 900),
lr_max=1.0, kwargs.get("lr_max", 1.0),
verbose=False, kwargs.get("verbose", False),
): )
with torch.no_grad(): with torch.no_grad():
data = torch.Tensor(list_of_points).type(torch.float32) data = torch.Tensor(list_of_points).type(torch.float32)
assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format( assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format(
@ -113,7 +117,7 @@ class QuadraticFunc(FitFunc):
return weights[0] * x * x + weights[1] * x + weights[2] return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self): def __repr__(self):
return "{name}(y = {a} * x^2 + {b} * x + {c})".format( return "{name}({a} * x^2 + {b} * x + {c})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
a=self._params[0], a=self._params[0],
b=self._params[1], b=self._params[1],
@ -140,7 +144,7 @@ class CubicFunc(FitFunc):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self): def __repr__(self):
return "{name}(y = {a} * x^3 + {b} * x^2 + {c} * x + {d})".format( return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
a=self._params[0], a=self._params[0],
b=self._params[1], b=self._params[1],
@ -175,7 +179,7 @@ class QuarticFunc(FitFunc):
) )
def __repr__(self): def __repr__(self):
return "{name}(y = {a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
a=self._params[0], a=self._params[0],
b=self._params[1], b=self._params[1],
@ -183,34 +187,3 @@ class QuarticFunc(FitFunc):
d=self._params[3], d=self._params[3],
e=self._params[3], e=self._params[3],
) )
class DynamicQuadraticFunc(FitFunc):
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None):
super(DynamicQuadraticFunc, self).__init__(3, list_of_points)
self._timestamp = None
def __call__(self, x):
self.check_valid()
a = self._params[0][self._timestamp]
b = self._params[1][self._timestamp]
c = self._params[2][self._timestamp]
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
return a * x * x + b * x + c
def _getitem(self, x, weights):
raise NotImplementedError
def set_timestamp(self, timestamp):
self._timestamp = timestamp
def __repr__(self):
return "{name}(y = {a} * x^2 + {b} * x + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
)

View File

@ -4,45 +4,42 @@
import math import math
import abc import abc
import numpy as np import numpy as np
from typing import List, Optional from typing import List, Optional, Dict
import torch import torch
import torch.utils.data as data import torch.utils.data as data
from .synthetic_utils import UnifiedSplit from .synthetic_utils import TimeStamp
class SyntheticDEnv(UnifiedSplit, data.Dataset): class SyntheticDEnv(data.Dataset):
"""The synethtic dynamic environment.""" """The synethtic dynamic environment."""
def __init__( def __init__(
self, self,
mean_generators: List[data.Dataset], mean_functors: List[data.Dataset],
cov_generators: List[List[data.Dataset]], cov_functors: List[List[data.Dataset]],
num_per_task: int = 5000, num_per_task: int = 5000,
time_stamp_config: Optional[Dict] = None,
mode: Optional[str] = None, mode: Optional[str] = None,
): ):
self._ndim = len(mean_generators) self._ndim = len(mean_functors)
assert self._ndim == len( assert self._ndim == len(
cov_generators cov_functors
), "length does not match {:} vs. {:}".format(self._ndim, len(cov_generators)) ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors))
for cov_generator in cov_generators: for cov_functor in cov_functors:
assert self._ndim == len( assert self._ndim == len(
cov_generator cov_functor
), "length does not match {:} vs. {:}".format( ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
self._ndim, len(cov_generator)
)
self._num_per_task = num_per_task self._num_per_task = num_per_task
self._total_num = len(mean_generators[0]) if time_stamp_config is None:
for mean_generator in mean_generators: time_stamp_config = dict(mode=mode)
assert self._total_num == len(mean_generator) else:
for cov_generator in cov_generators: time_stamp_config["mode"] = mode
for cov_g in cov_generator:
assert self._total_num == len(cov_g)
self._mean_generators = mean_generators self._timestamp_generator = TimeStamp(**time_stamp_config)
self._cov_generators = cov_generators
UnifiedSplit.__init__(self, self._total_num, mode) self._mean_functors = mean_functors
self._cov_functors = cov_functors
def __iter__(self): def __iter__(self):
self._iter_num = 0 self._iter_num = 0
@ -56,11 +53,11 @@ class SyntheticDEnv(UnifiedSplit, data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index] index, timestamp = self._timestamp_generator[index]
mean_list = [generator[index][-1] for generator in self._mean_generators] mean_list = [functor(timestamp) for functor in self._mean_functors]
cov_matrix = [ cov_matrix = [
[cov_gen[index][-1] for cov_gen in cov_generator] [cov_gen(timestamp) for cov_gen in cov_functor]
for cov_generator in self._cov_generators for cov_functor in self._cov_functors
] ]
dataset = np.random.multivariate_normal( dataset = np.random.multivariate_normal(
@ -69,13 +66,13 @@ class SyntheticDEnv(UnifiedSplit, data.Dataset):
return index, torch.Tensor(dataset) return index, torch.Tensor(dataset)
def __len__(self): def __len__(self):
return len(self._indexes) return len(self._timestamp_generator)
def __repr__(self): def __repr__(self):
return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task})".format( return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task})".format(
name=self.__class__.__name__, name=self.__class__.__name__,
cur_num=len(self), cur_num=len(self),
total=self._total_num, total=len(self._timestamp_generator),
ndim=self._ndim, ndim=self._ndim,
num_per_task=self._num_per_task, num_per_task=self._num_per_task,
) )

View File

@ -3,25 +3,30 @@
##################################################### #####################################################
from .math_base_funcs import DynamicQuadraticFunc from .math_base_funcs import DynamicQuadraticFunc
from .synthetic_utils import ConstantGenerator, SinGenerator from .math_adv_funcs import ConstantFunc, ComposedSinFunc
from .synthetic_env import SyntheticDEnv from .synthetic_env import SyntheticDEnv
def create_example_v1(timestamps=50, num_per_task=5000): def create_example_v1(timestamps=50, num_per_task=5000):
mean_generator = SinGenerator(num=timestamps) mean_generator = ComposedSinFunc()
std_generator = SinGenerator(num=timestamps, min_amplitude=0.5, max_amplitude=0.5) std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
std_generator.set_transform(lambda x: x + 1) std_generator.set_transform(lambda x: x + 1)
dynamic_env = SyntheticDEnv( dynamic_env = SyntheticDEnv(
[mean_generator], [[std_generator]], num_per_task=num_per_task [mean_generator],
[[std_generator]],
num_per_task=num_per_task,
time_stamp_config=dict(num=timestamps),
) )
function = DynamicQuadraticFunc() function = DynamicQuadraticFunc()
function_param = dict() function_param = dict()
function_param[0] = SinGenerator( function_param[0] = ComposedSinFunc(
num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0 num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
) )
function_param[1] = ConstantGenerator(constant=0.9) function_param[1] = ConstantFunc(constant=0.9)
function_param[2] = SinGenerator( function_param[2] = ComposedSinFunc(
num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9 num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
) )
function.set(function_param) function.set(function_param)
return dynamic_env, function return dynamic_env, function

View File

@ -8,8 +8,6 @@ from typing import Optional
import torch import torch
import torch.utils.data as data import torch.utils.data as data
from .math_base_funcs import QuadraticFunc, QuarticFunc
class UnifiedSplit: class UnifiedSplit:
"""A class to unify the split strategy.""" """A class to unify the split strategy."""
@ -39,102 +37,20 @@ class UnifiedSplit:
return self._mode return self._mode
class SinGenerator(UnifiedSplit, data.Dataset): class TimeStamp(UnifiedSplit, data.Dataset):
"""The synethtic generator for the dynamically changing environment. """The timestamp dataset."""
- x in [0, 1]
- y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
- where
- the amplitude scale is a quadratic function of x
- the period-phase-shift is another quadratic function of x
"""
def __init__( def __init__(
self, self,
min_timestamp: float = 0.0,
max_timestamp: float = 1.0,
num: int = 100, num: int = 100,
num_sin_phase: int = 7,
min_amplitude: float = 1,
max_amplitude: float = 4,
phase_shift: float = 0,
mode: Optional[str] = None, mode: Optional[str] = None,
): ):
self._amplitude_scale = QuadraticFunc( self._min_timestamp = min_timestamp
[(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] self._max_timestamp = max_timestamp
) self._interval = (max_timestamp - min_timestamp) / (float(num) - 1)
self._num_sin_phase = num_sin_phase
self._interval = 1.0 / (float(num) - 1)
self._total_num = num self._total_num = num
fitting_data = []
temp_max_scalar = 2 ** (num_sin_phase - 1)
for i in range(num_sin_phase):
value = (2 ** i) / temp_max_scalar
next_value = (2 ** (i + 1)) / temp_max_scalar
for _phase in (0, 0.25, 0.5, 0.75):
inter_value = value + (next_value - value) * _phase
fitting_data.append((inter_value, math.pi * (2 * i + _phase)))
self._period_phase_shift = QuarticFunc(fitting_data)
UnifiedSplit.__init__(self, self._total_num, mode)
self._transform = None
def __iter__(self):
self._iter_num = 0
return self
def __next__(self):
if self._iter_num >= len(self):
raise StopIteration
self._iter_num += 1
return self.__getitem__(self._iter_num - 1)
def set_transform(self, transform):
self._transform = transform
def transform(self, x):
if self._transform is None:
return x
else:
return self._transform(x)
def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index]
position = self._interval * index
value = self._amplitude_scale(position) * math.sin(
self._period_phase_shift(position)
)
return index, position, self.transform(value)
def __len__(self):
return len(self._indexes)
def __repr__(self):
return (
"{name}({cur_num:}/{total} elements,\n"
"amplitude={amplitude},\n"
"period_phase_shift={period_phase_shift})".format(
name=self.__class__.__name__,
cur_num=len(self),
total=self._total_num,
amplitude=self._amplitude_scale,
period_phase_shift=self._period_phase_shift,
)
)
class ConstantGenerator(UnifiedSplit, data.Dataset):
"""The constant generator."""
def __init__(
self,
num: int = 100,
constant: float = 0.1,
mode: Optional[str] = None,
):
self._total_num = num
self._constant = constant
UnifiedSplit.__init__(self, self._total_num, mode) UnifiedSplit.__init__(self, self._total_num, mode)
def __iter__(self): def __iter__(self):
@ -150,7 +66,8 @@ class ConstantGenerator(UnifiedSplit, data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
index = self._indexes[index] index = self._indexes[index]
return index, index, self._constant timestamp = self._min_timestamp + self._interval * index
return index, timestamp
def __len__(self): def __len__(self):
return len(self._indexes) return len(self._indexes)

52
tests/test_math_adv.py Normal file
View File

@ -0,0 +1,52 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest tests/test_math_adv.py -s #
#####################################################
import sys, random
import unittest
import pytest
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from datasets import QuadraticFunc
from datasets import ConstantFunc
from datasets import DynamicQuadraticFunc
from datasets import ComposedSinFunc
class TestConstantFunc(unittest.TestCase):
"""Test the constant function."""
def test_simple(self):
function = ConstantFunc(0.1)
for i in range(100):
assert function(i) == 0.1
class TestDynamicFunc(unittest.TestCase):
"""Test DynamicQuadraticFunc."""
def test_simple(self):
timestamps = 30
function = DynamicQuadraticFunc()
function_param = dict()
function_param[0] = ComposedSinFunc(
num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
)
function_param[1] = ConstantFunc(constant=0.9)
function_param[2] = ComposedSinFunc(
num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
)
function.set(function_param)
print(function)
with self.assertRaises(TypeError) as context:
function(0)
function.set_timestamp(1)
print(function(2))

41
tests/test_math_base.py Normal file
View File

@ -0,0 +1,41 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest tests/test_math_base.py -s #
#####################################################
import sys, random
import unittest
import pytest
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from datasets import QuadraticFunc
class TestQuadraticFunc(unittest.TestCase):
"""Test the quadratic function."""
def test_simple(self):
function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]])
print(function)
for x in (0, 0.5, 1):
print("f({:})={:}".format(x, function(x)))
thresh = 0.2
self.assertTrue(abs(function(0) - 1) < thresh)
self.assertTrue(abs(function(0.5) - 4) < thresh)
self.assertTrue(abs(function(1) - 1) < thresh)
def test_none(self):
function = QuadraticFunc()
function.fit(
list_of_points=[[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False
)
print(function)
thresh = 0.15
self.assertTrue(abs(function(0) - 1) < thresh)
self.assertTrue(abs(function(0.5) - 4) < thresh)
self.assertTrue(abs(function(1) - 1) < thresh)

View File

@ -79,7 +79,7 @@ def test_super_sequential_v1():
super_core.SuperSimpleNorm(1, 1), super_core.SuperSimpleNorm(1, 1),
torch.nn.ReLU(), torch.nn.ReLU(),
super_core.SuperLinear(10, 10), super_core.SuperLinear(10, 10),
super_core.SuperReLU() super_core.SuperReLU(),
) )
inputs = torch.rand(10, 10) inputs = torch.rand(10, 10)
print(model) print(model)

View File

@ -13,7 +13,7 @@ print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path: if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir)) sys.path.insert(0, str(lib_dir))
from datasets import ConstantGenerator, SinGenerator from datasets import ConstantFunc, ComposedSinFunc
from datasets import SyntheticDEnv from datasets import SyntheticDEnv
@ -21,10 +21,10 @@ class TestSynethicEnv(unittest.TestCase):
"""Test the synethtic environment.""" """Test the synethtic environment."""
def test_simple(self): def test_simple(self):
mean_generator = SinGenerator() mean_generator = ComposedSinFunc(constant=0.1)
std_generator = ConstantGenerator(constant=0.5) std_generator = ConstantFunc(constant=0.5)
dataset = SyntheticDEnv([mean_generator], [[std_generator]]) dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000)
print(dataset) print(dataset)
for timestamp, tau in dataset: for timestamp, tau in dataset:
assert tau.shape == (5000, 1) assert tau.shape == (5000, 1)

View File

@ -13,74 +13,19 @@ print("library path: {:}".format(lib_dir))
if str(lib_dir) not in sys.path: if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir)) sys.path.insert(0, str(lib_dir))
from datasets import QuadraticFunc from datasets import TimeStamp
from datasets import ConstantGenerator, SinGenerator
from datasets import DynamicQuadraticFunc
class TestQuadraticFunc(unittest.TestCase): class TestTimeStamp(unittest.TestCase):
"""Test the quadratic function.""" """Test the timestamp generator."""
def test_simple(self): def test_simple(self):
function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) for mode in (None, "train", "valid", "test"):
print(function) generator = TimeStamp(0, 1)
for x in (0, 0.5, 1): print(generator)
print("f({:})={:}".format(x, function(x))) for idx, (i, xtime) in enumerate(generator):
thresh = 0.2 self.assertTrue(i == idx)
self.assertTrue(abs(function(0) - 1) < thresh) if idx == 0:
self.assertTrue(abs(function(0.5) - 4) < thresh) self.assertTrue(xtime == 0)
self.assertTrue(abs(function(1) - 1) < thresh) if idx + 1 == len(generator):
self.assertTrue(abs(xtime - 1) < 1e-8)
def test_none(self):
function = QuadraticFunc()
function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=False)
print(function)
thresh = 0.15
self.assertTrue(abs(function(0) - 1) < thresh)
self.assertTrue(abs(function(0.5) - 4) < thresh)
self.assertTrue(abs(function(1) - 1) < thresh)
class TestConstantGenerator(unittest.TestCase):
"""Test the constant data generator."""
def test_simple(self):
dataset = ConstantGenerator()
for i, (idx, t, x) in enumerate(dataset):
assert i == idx, "First loop: {:} vs {:}".format(i, idx)
assert x == 0.1
class TestSinGenerator(unittest.TestCase):
"""Test the synethtic data generator."""
def test_simple(self):
dataset = SinGenerator()
for i, (idx, t, x) in enumerate(dataset):
assert i == idx, "First loop: {:} vs {:}".format(i, idx)
for i, (idx, t, x) in enumerate(dataset):
assert i == idx, "Second loop: {:} vs {:}".format(i, idx)
class TestDynamicFunc(unittest.TestCase):
"""Test DynamicQuadraticFunc."""
def test_simple(self):
timestamps = 30
function = DynamicQuadraticFunc()
function_param = dict()
function_param[0] = SinGenerator(
num=timestamps, num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
)
function_param[1] = ConstantGenerator(constant=0.9)
function_param[2] = SinGenerator(
num=timestamps, num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
)
function.set(function_param)
print(function)
with self.assertRaises(TypeError) as context:
function(0)
function.set_timestamp(1)
print(function(2))