| 
									
										
										
										
											2021-04-14 01:04:46 +08:00
										 |  |  | ##################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | 
					
						
							|  |  |  | ##################################################### | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  | import math | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  | import abc | 
					
						
							| 
									
										
										
										
											2021-04-22 23:32:26 +08:00
										 |  |  | import copy | 
					
						
							| 
									
										
										
										
											2021-04-14 01:04:46 +08:00
										 |  |  | import numpy as np | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  | import torch | 
					
						
							| 
									
										
										
										
											2021-04-14 01:04:46 +08:00
										 |  |  | import torch.utils.data as data | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  | class FitFunc(abc.ABC): | 
					
						
							|  |  |  |     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-27 20:09:37 +08:00
										 |  |  |     def __init__(self, freedom: int, list_of_points=None, params=None): | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |         self._params = dict() | 
					
						
							|  |  |  |         for i in range(freedom): | 
					
						
							|  |  |  |             self._params[i] = None | 
					
						
							|  |  |  |         self._freedom = freedom | 
					
						
							| 
									
										
										
										
											2021-04-27 20:09:37 +08:00
										 |  |  |         if list_of_points is not None and params is not None: | 
					
						
							|  |  |  |             raise ValueError("list_of_points and params can not be set simultaneously") | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |         if list_of_points is not None: | 
					
						
							| 
									
										
										
										
											2021-04-26 05:16:38 -07:00
										 |  |  |             self.fit(list_of_points=list_of_points) | 
					
						
							| 
									
										
										
										
											2021-04-27 20:09:37 +08:00
										 |  |  |         if params is not None: | 
					
						
							|  |  |  |             self.set(params) | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-27 20:09:37 +08:00
										 |  |  |     def set(self, params): | 
					
						
							|  |  |  |         self._params = copy.deepcopy(params) | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def check_valid(self): | 
					
						
							|  |  |  |         for key, value in self._params.items(): | 
					
						
							|  |  |  |             if value is None: | 
					
						
							|  |  |  |                 raise ValueError("The {:} is None".format(key)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |     @abc.abstractmethod | 
					
						
							| 
									
										
										
										
											2021-04-22 23:32:26 +08:00
										 |  |  |     def __call__(self, x): | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |         raise NotImplementedError | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-23 02:12:11 -07:00
										 |  |  |     def noise_call(self, x, std=0.1): | 
					
						
							|  |  |  |         clean_y = self.__call__(x) | 
					
						
							|  |  |  |         if isinstance(clean_y, np.ndarray): | 
					
						
							|  |  |  |             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | 
					
						
							|  |  |  |         return noise_y | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |     @abc.abstractmethod | 
					
						
							|  |  |  |     def _getitem(self, x): | 
					
						
							|  |  |  |         raise NotImplementedError | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-26 05:16:38 -07:00
										 |  |  |     def fit(self, **kwargs): | 
					
						
							|  |  |  |         list_of_points = kwargs["list_of_points"] | 
					
						
							|  |  |  |         max_iter, lr_max, verbose = ( | 
					
						
							|  |  |  |             kwargs.get("max_iter", 900), | 
					
						
							|  |  |  |             kwargs.get("lr_max", 1.0), | 
					
						
							|  |  |  |             kwargs.get("verbose", False), | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |         with torch.no_grad(): | 
					
						
							|  |  |  |             data = torch.Tensor(list_of_points).type(torch.float32) | 
					
						
							|  |  |  |             assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format( | 
					
						
							|  |  |  |                 data.shape | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             x, y = data[:, 0], data[:, 1] | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |         weights = torch.nn.Parameter(torch.Tensor(self._freedom)) | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |         torch.nn.init.normal_(weights, mean=0.0, std=1.0) | 
					
						
							|  |  |  |         optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | 
					
						
							|  |  |  |             optimizer, | 
					
						
							|  |  |  |             milestones=[ | 
					
						
							|  |  |  |                 int(max_iter * 0.25), | 
					
						
							|  |  |  |                 int(max_iter * 0.5), | 
					
						
							|  |  |  |                 int(max_iter * 0.75), | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |             gamma=0.1, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |         if verbose: | 
					
						
							|  |  |  |             print("The optimizer: {:}".format(optimizer)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         best_loss = None | 
					
						
							|  |  |  |         for _iter in range(max_iter): | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |             y_hat = self._getitem(x, weights) | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |             loss = torch.mean(torch.abs(y - y_hat)) | 
					
						
							|  |  |  |             optimizer.zero_grad() | 
					
						
							|  |  |  |             loss.backward() | 
					
						
							|  |  |  |             optimizer.step() | 
					
						
							|  |  |  |             lr_scheduler.step() | 
					
						
							|  |  |  |             if verbose: | 
					
						
							|  |  |  |                 print( | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |                     "In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format( | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |                         _iter, max_iter, loss.item() | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             # Update the params | 
					
						
							|  |  |  |             if best_loss is None or best_loss > loss.item(): | 
					
						
							|  |  |  |                 best_loss = loss.item() | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |                 for i in range(self._freedom): | 
					
						
							|  |  |  |                     self._params[i] = weights[i].item() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __repr__(self): | 
					
						
							|  |  |  |         return "{name}(freedom={freedom})".format( | 
					
						
							|  |  |  |             name=self.__class__.__name__, freedom=freedom | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class QuadraticFunc(FitFunc): | 
					
						
							|  |  |  |     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, list_of_points=None): | 
					
						
							|  |  |  |         super(QuadraticFunc, self).__init__(3, list_of_points) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-22 23:32:26 +08:00
										 |  |  |     def __call__(self, x): | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |         self.check_valid() | 
					
						
							|  |  |  |         return self._params[0] * x * x + self._params[1] * x + self._params[2] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _getitem(self, x, weights): | 
					
						
							|  |  |  |         return weights[0] * x * x + weights[1] * x + weights[2] | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def __repr__(self): | 
					
						
							| 
									
										
										
										
											2021-04-26 05:16:38 -07:00
										 |  |  |         return "{name}({a} * x^2 + {b} * x + {c})".format( | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |             name=self.__class__.__name__, | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |             a=self._params[0], | 
					
						
							|  |  |  |             b=self._params[1], | 
					
						
							|  |  |  |             c=self._params[2], | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CubicFunc(FitFunc): | 
					
						
							|  |  |  |     """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, list_of_points=None): | 
					
						
							|  |  |  |         super(CubicFunc, self).__init__(4, list_of_points) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-22 23:32:26 +08:00
										 |  |  |     def __call__(self, x): | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |         self.check_valid() | 
					
						
							|  |  |  |         return ( | 
					
						
							|  |  |  |             self._params[0] * x ** 3 | 
					
						
							|  |  |  |             + self._params[1] * x ** 2 | 
					
						
							|  |  |  |             + self._params[2] * x | 
					
						
							|  |  |  |             + self._params[3] | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _getitem(self, x, weights): | 
					
						
							|  |  |  |         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __repr__(self): | 
					
						
							| 
									
										
										
										
											2021-04-26 05:16:38 -07:00
										 |  |  |         return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format( | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |             name=self.__class__.__name__, | 
					
						
							|  |  |  |             a=self._params[0], | 
					
						
							|  |  |  |             b=self._params[1], | 
					
						
							|  |  |  |             c=self._params[2], | 
					
						
							|  |  |  |             d=self._params[3], | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class QuarticFunc(FitFunc): | 
					
						
							|  |  |  |     """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, list_of_points=None): | 
					
						
							|  |  |  |         super(QuarticFunc, self).__init__(5, list_of_points) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-22 23:32:26 +08:00
										 |  |  |     def __call__(self, x): | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |         self.check_valid() | 
					
						
							|  |  |  |         return ( | 
					
						
							|  |  |  |             self._params[0] * x ** 4 | 
					
						
							|  |  |  |             + self._params[1] * x ** 3 | 
					
						
							|  |  |  |             + self._params[2] * x ** 2 | 
					
						
							|  |  |  |             + self._params[3] * x | 
					
						
							|  |  |  |             + self._params[4] | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _getitem(self, x, weights): | 
					
						
							|  |  |  |         return ( | 
					
						
							|  |  |  |             weights[0] * x ** 4 | 
					
						
							|  |  |  |             + weights[1] * x ** 3 | 
					
						
							|  |  |  |             + weights[2] * x ** 2 | 
					
						
							|  |  |  |             + weights[3] * x | 
					
						
							|  |  |  |             + weights[4] | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __repr__(self): | 
					
						
							| 
									
										
										
										
											2021-04-26 05:16:38 -07:00
										 |  |  |         return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( | 
					
						
							| 
									
										
										
										
											2021-04-22 20:31:20 +08:00
										 |  |  |             name=self.__class__.__name__, | 
					
						
							|  |  |  |             a=self._params[0], | 
					
						
							|  |  |  |             b=self._params[1], | 
					
						
							|  |  |  |             c=self._params[2], | 
					
						
							|  |  |  |             d=self._params[3], | 
					
						
							|  |  |  |             e=self._params[3], | 
					
						
							| 
									
										
										
										
											2021-04-22 19:12:21 +08:00
										 |  |  |         ) |