autodl-projects/xautodl/datasets/math_base_funcs.py

63 lines
1.8 KiB
Python
Raw Normal View History

2021-04-13 19:04:46 +02:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
2021-04-22 13:12:21 +02:00
import math
2021-04-22 14:31:20 +02:00
import abc
2021-04-22 17:32:26 +02:00
import copy
2021-04-13 19:04:46 +02:00
import numpy as np
2021-05-27 09:44:01 +02:00
class MathFunc(abc.ABC):
    """The math function -- a virtual class defining some APIs.

    A function holds ``freedom`` parameters stored in ``self._params`` under
    the integer keys ``0 .. freedom - 1`` (each ``None`` until assigned), plus
    a display name for its input variable (``xstr``).

    Concrete subclasses must implement ``__call__`` and ``noise_call``.
    """

    def __init__(self, freedom: int, params=None, xstr="x"):
        """Create the function with ``freedom`` (initially unset) parameters.

        Args:
          freedom: number of parameters; they are keyed ``0 .. freedom - 1``.
          params: optional container indexable by ``0 .. freedom - 1`` (e.g. a
            list or an int-keyed dict); when given, it is passed to ``set``.
          xstr: name of the input variable; stored as ``str(xstr)``.
        """
        # initialize every parameter slot as empty (None)
        self._params = {index: None for index in range(freedom)}
        self._freedom = freedom
        if params is not None:
            self.set(params)
        self._xstr = str(xstr)
        # While this flag is True, check_valid() performs no checks.
        self._skip_check = True

    def set(self, params):
        """Store deep copies of ``params[0] .. params[freedom - 1]``.

        Deep-copying means later mutation of the caller's objects does not
        leak into this function. Any lookup error raised by ``params[key]``
        (e.g. ``IndexError``/``KeyError`` for a too-short container)
        propagates unchanged.
        """
        for key in range(self._freedom):
            self._params[key] = copy.deepcopy(params[key])

    def check_valid(self):
        """Raise ``ValueError`` if any parameter is still ``None``.

        Does nothing while ``self._skip_check`` is True, which is the value
        assigned by ``__init__``.
        """
        if self._skip_check:
            return
        for key in range(self._freedom):
            if self._params[key] is None:
                raise ValueError("The {:} is None".format(key))

    @property
    def xstr(self):
        """The display name of the input variable (always a ``str``)."""
        return self._xstr

    def reset_xstr(self, xstr):
        """Replace the input-variable name with ``str(xstr)``."""
        self._xstr = str(xstr)

    def output_shape(self, input_shape):
        """Return the output shape for ``input_shape``; identity by default."""
        return input_shape

    @abc.abstractmethod
    def __call__(self, x):
        """Evaluate the function at ``x``; must be implemented by subclasses."""
        raise NotImplementedError

    @abc.abstractmethod
    def noise_call(self, x, std):
        """Evaluate at ``x`` and add element-wise Gaussian noise.

        Declared abstract, so concrete subclasses must override it, but this
        body is a usable default that an override can reach via ``super()``.
        The noise has standard deviation ``std`` and is drawn from NumPy's
        global random state (``np.random.normal``).

        Raises:
          ValueError: if ``self(x)`` does not return a ``np.ndarray``.
        """
        clean_y = self.__call__(x)
        if not isinstance(clean_y, np.ndarray):
            raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
        return clean_y + np.random.normal(scale=std, size=clean_y.shape)

    def __repr__(self):
        """Return ``ClassName(freedom=N)`` for this instance."""
        # Must read the instance attribute: a bare ``freedom`` is undefined
        # here and would raise NameError.
        return "{name}(freedom={freedom})".format(
            name=self.__class__.__name__, freedom=self._freedom
        )