87 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			87 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #####################################################
 | |
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
 | |
| #####################################################
 | |
| import math
 | |
| import abc
 | |
| import copy
 | |
| import numpy as np
 | |
| from typing import Optional
 | |
| import torch
 | |
| import torch.utils.data as data
 | |
| 
 | |
| from .math_base_funcs import FitFunc
 | |
| from .math_base_funcs import QuadraticFunc
 | |
| from .math_base_funcs import QuarticFunc
 | |
| 
 | |
| 
 | |
class ConstantFunc(FitFunc):
    """A function that ignores its input and always yields one constant.

    f(x) = c, where c is the `constant` given at construction time.
    """

    def __init__(self, constant=None):
        # The constant is stored as the only parameter, under key 0.
        # NOTE(review): the leading `0, None` are forwarded to FitFunc as-is
        # (presumably freedom / fitting points) -- confirm against FitFunc.
        super().__init__(0, None, {0: constant})

    def __call__(self, x):
        """Return the stored constant; `x` is accepted but not used."""
        # presumably raises if the constant was never provided -- confirm in FitFunc
        self.check_valid()
        constant = self._params[0]
        return constant

    def fit(self, **kwargs):
        """Fitting is not supported for a constant function."""
        raise NotImplementedError

    def _getitem(self, x, weights):
        """Not supported for a constant function."""
        raise NotImplementedError

    def __repr__(self):
        template = "{name}({a})"
        return template.format(name=self.__class__.__name__, a=self._params[0])
 | |
| 
 | |
| 
 | |
class ComposedSinFunc(FitFunc):
    """The composed sin function that outputs:
      f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
    - the amplitude scale is a quadratic function of x, fitted to the points
      (0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)
    - the period-phase-shift is a quartic function of x, fitted to points
      whose phase grows by 2*pi each time x doubles (plus `phase_shift`)
    """

    def __init__(self, **kwargs):
        # No parameters are passed to FitFunc directly; `fit` builds and
        # installs both sub-functions from the keyword options.
        super(ComposedSinFunc, self).__init__(0, None)
        self.fit(**kwargs)

    @staticmethod
    def _sin(value):
        """Element-wise sine that keeps the container type of `value`.

        torch tensors use `torch.sin` (so autograd is preserved), numpy arrays
        use `np.sin`, and anything else (Python / numpy scalars) uses
        `math.sin`, exactly as before.
        """
        if torch.is_tensor(value):
            return torch.sin(value)
        if isinstance(value, np.ndarray):
            return np.sin(value)
        return math.sin(value)

    def __call__(self, x):
        """Evaluate f(x); `x` may be a scalar, numpy array, or torch tensor.

        NOTE(review): array/tensor support also relies on the quadratic and
        quartic sub-functions accepting such inputs -- confirm in FitFunc.
        """
        self.check_valid()
        scale = self._params["amplitude_scale"](x)
        period_phase = self._params["period_phase_shift"](x)
        return scale * self._sin(period_phase)

    def fit(self, **kwargs):
        """(Re)build the amplitude and phase sub-functions.

        Keyword Args:
            num_sin_phase: number of x-doubling segments used to fit the phase
                function (default 7); must be >= 1.
            min_amplitude: amplitude target at x=0 and x=1 (default 1).
            max_amplitude: amplitude target at x=0.5 (default 4).
            phase_shift: constant (radians) added to every fitted phase
                (default 0.0, which reproduces the unshifted function).

        Raises:
            ValueError: if `num_sin_phase` < 1 (no fitting points would exist).
        """
        num_sin_phase = kwargs.get("num_sin_phase", 7)
        min_amplitude = kwargs.get("min_amplitude", 1)
        max_amplitude = kwargs.get("max_amplitude", 4)
        phase_shift = kwargs.get("phase_shift", 0.0)
        if num_sin_phase < 1:
            raise ValueError(
                "num_sin_phase must be >= 1, got {:}".format(num_sin_phase)
            )
        # create parameters
        amplitude_scale = QuadraticFunc(
            [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)]
        )
        fitting_data = []
        # Segment i spans x in [2**i, 2**(i+1)] / 2**(num_sin_phase - 1), so
        # the last segment starts at x = 1; segment starts sit at phase 2*pi*i.
        temp_max_scalar = 2 ** (num_sin_phase - 1)
        for i in range(num_sin_phase):
            value = (2 ** i) / temp_max_scalar
            next_value = (2 ** (i + 1)) / temp_max_scalar
            for _phase in (0, 0.25, 0.5, 0.75):
                inter_value = value + (next_value - value) * _phase
                # NOTE(review): intermediate points advance by only
                # pi * _phase (at most 0.75*pi) while segment starts are 2*pi
                # apart -- confirm whether 2 * _phase was intended.
                fitting_data.append(
                    (inter_value, math.pi * (2 * i + _phase) + phase_shift)
                )
        period_phase_shift = QuarticFunc(fitting_data)
        self.set(
            dict(amplitude_scale=amplitude_scale, period_phase_shift=period_phase_shift)
        )

    def _getitem(self, x, weights):
        raise NotImplementedError

    def __repr__(self):
        return "{name}({amplitude_scale} * sin({period_phase_shift}))".format(
            name=self.__class__.__name__,
            amplitude_scale=self._params["amplitude_scale"],
            period_phase_shift=self._params["period_phase_shift"],
        )
 |