Update synthetic environment
This commit is contained in:
		
							
								
								
									
										2
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -54,7 +54,7 @@ jobs: | ||||
|         run: | | ||||
|           python -m pip install pytest numpy | ||||
|           python -m pip install parameterized | ||||
|           python -m pip install torch | ||||
|           python -m pip install torch torchvision | ||||
|           python --version | ||||
|           python -m pytest ./tests/test_synthetic.py -s | ||||
|         shell: bash | ||||
|   | ||||
 Submodule .latent-data/NATS-Bench updated: 33bfb2eb13...f955e2ba13
									
								
							| @@ -4,5 +4,5 @@ | ||||
| from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||
| from .SearchDatasetWrap import SearchDataset | ||||
|  | ||||
| from .synthetic_adaptive_environment import QuadraticFunction | ||||
| from .synthetic_adaptive_environment import QuadraticFunc, CubicFunc, QuarticFunc | ||||
| from .synthetic_adaptive_environment import SynAdaptiveEnv | ||||
|   | ||||
| @@ -2,38 +2,43 @@ | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import math | ||||
| import abc | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class QuadraticFunction: | ||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
| class FitFunc(abc.ABC): | ||||
|     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         self._params = dict(a=None, b=None, c=None) | ||||
|     def __init__(self, freedom: int, list_of_points=None): | ||||
|         self._params = dict() | ||||
|         for i in range(freedom): | ||||
|             self._params[i] = None | ||||
|         self._freedom = freedom | ||||
|         if list_of_points is not None: | ||||
|             self.fit(list_of_points) | ||||
|  | ||||
|     def set(self, a, b, c): | ||||
|         self._params["a"] = a | ||||
|         self._params["b"] = b | ||||
|         self._params["c"] = c | ||||
|     def set(self, _params): | ||||
|         self._params = copy.deepcopy(_params) | ||||
|  | ||||
|     def check_valid(self): | ||||
|         for key, value in self._params.items(): | ||||
|             if value is None: | ||||
|                 raise ValueError("The {:} is None".format(key)) | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def __getitem__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params["a"] * x * x + self._params["b"] * x + self._params["c"] | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def _getitem(self, x): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def fit( | ||||
|         self, | ||||
|         list_of_points, | ||||
|         transf=lambda x: x, | ||||
|         max_iter=900, | ||||
|         lr_max=1.0, | ||||
|         verbose=False, | ||||
| @@ -44,16 +49,24 @@ class QuadraticFunction: | ||||
|                 data.shape | ||||
|             ) | ||||
|             x, y = data[:, 0], data[:, 1] | ||||
|         weights = torch.nn.Parameter(torch.Tensor(3)) | ||||
|         weights = torch.nn.Parameter(torch.Tensor(self._freedom)) | ||||
|         torch.nn.init.normal_(weights, mean=0.0, std=1.0) | ||||
|         optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(max_iter*0.25), int(max_iter*0.5), int(max_iter*0.75)], gamma=0.1) | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|             optimizer, | ||||
|             milestones=[ | ||||
|                 int(max_iter * 0.25), | ||||
|                 int(max_iter * 0.5), | ||||
|                 int(max_iter * 0.75), | ||||
|             ], | ||||
|             gamma=0.1, | ||||
|         ) | ||||
|         if verbose: | ||||
|             print("The optimizer: {:}".format(optimizer)) | ||||
|  | ||||
|         best_loss = None | ||||
|         for _iter in range(max_iter): | ||||
|             y_hat = transf(weights[0] * x * x + weights[1] * x + weights[2]) | ||||
|             y_hat = self._getitem(x, weights) | ||||
|             loss = torch.mean(torch.abs(y - y_hat)) | ||||
|             optimizer.zero_grad() | ||||
|             loss.backward() | ||||
| @@ -61,23 +74,105 @@ class QuadraticFunction: | ||||
|             lr_scheduler.step() | ||||
|             if verbose: | ||||
|                 print( | ||||
|                     "In QuadraticFunction's fit, loss at the {:02d}/{:02d}-th iter is {:}".format( | ||||
|                     "In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format( | ||||
|                         _iter, max_iter, loss.item() | ||||
|                     ) | ||||
|                 ) | ||||
|             # Update the params | ||||
|             if best_loss is None or best_loss > loss.item(): | ||||
|                 best_loss = loss.item() | ||||
|                 self._params["a"] = weights[0].item() | ||||
|                 self._params["b"] = weights[1].item() | ||||
|                 self._params["c"] = weights[2].item() | ||||
|                 for i in range(self._freedom): | ||||
|                     self._params[i] = weights[i].item() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(freedom={freedom})".format( | ||||
|             name=self.__class__.__name__, freedom=freedom | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuadraticFunc(FitFunc): | ||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(QuadraticFunc, self).__init__(3, list_of_points) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params[0] * x * x + self._params[1] * x + self._params[2] | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x * x + weights[1] * x + weights[2] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(y = {a} * x^2 + {b} * x + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params["a"], | ||||
|             b=self._params["b"], | ||||
|             c=self._params["c"], | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class CubicFunc(FitFunc): | ||||
|     """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(CubicFunc, self).__init__(4, list_of_points) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 3 | ||||
|             + self._params[1] * x ** 2 | ||||
|             + self._params[2] * x | ||||
|             + self._params[3] | ||||
|         ) | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(y = {a} * x^3 + {b} * x^2 + {c} * x + {d})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             d=self._params[3], | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class QuarticFunc(FitFunc): | ||||
|     """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         super(QuarticFunc, self).__init__(5, list_of_points) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|         self.check_valid() | ||||
|         return ( | ||||
|             self._params[0] * x ** 4 | ||||
|             + self._params[1] * x ** 3 | ||||
|             + self._params[2] * x ** 2 | ||||
|             + self._params[3] * x | ||||
|             + self._params[4] | ||||
|         ) | ||||
|  | ||||
|     def _getitem(self, x, weights): | ||||
|         return ( | ||||
|             weights[0] * x ** 4 | ||||
|             + weights[1] * x ** 3 | ||||
|             + weights[2] * x ** 2 | ||||
|             + weights[3] * x | ||||
|             + weights[4] | ||||
|         ) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(y = {a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params[0], | ||||
|             b=self._params[1], | ||||
|             c=self._params[2], | ||||
|             d=self._params[3], | ||||
|             e=self._params[3], | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @@ -95,28 +190,29 @@ class SynAdaptiveEnv(data.Dataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         num: int = 100, | ||||
|         num_sin_phase: int = 4, | ||||
|         num_sin_phase: int = 7, | ||||
|         min_amplitude: float = 1, | ||||
|         max_amplitude: float = 4, | ||||
|         phase_shift: float = 0, | ||||
|         mode: Optional[str] = None, | ||||
|     ): | ||||
|         self._amplitude_scale = QuadraticFunction( | ||||
|             [(0, min_amplitude), (0.5, max_amplitude), (0, min_amplitude)] | ||||
|         self._amplitude_scale = QuadraticFunc( | ||||
|             [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] | ||||
|         ) | ||||
|  | ||||
|         self._num_sin_phase = num_sin_phase | ||||
|         self._interval = 1.0 / (float(num) - 1) | ||||
|         self._total_num = num | ||||
|  | ||||
|         self._period_phase_shift = QuadraticFunction() | ||||
|  | ||||
|         fitting_data = [] | ||||
|         temp_max_scalar = 2 ** num_sin_phase | ||||
|         temp_max_scalar = 2 ** (num_sin_phase - 1) | ||||
|         for i in range(num_sin_phase): | ||||
|             value = (2 ** i) / temp_max_scalar | ||||
|             fitting_data.append((value, math.sin(value))) | ||||
|         self._period_phase_shift.fit(fitting_data, transf=lambda x: torch.sin(x)) | ||||
|             next_value = (2 ** (i + 1)) / temp_max_scalar | ||||
|             for _phase in (0, 0.25, 0.5, 0.75): | ||||
|                 inter_value = value + (next_value - value) * _phase | ||||
|                 fitting_data.append((inter_value, math.pi * (2 * i + _phase))) | ||||
|         self._period_phase_shift = QuarticFunc(fitting_data) | ||||
|  | ||||
|         # Training Set 60% | ||||
|         num_of_train = int(self._total_num * 0.6) | ||||
| @@ -135,11 +231,6 @@ class SynAdaptiveEnv(data.Dataset): | ||||
|             self._indexes = all_indexes[num_of_train + num_of_valid :] | ||||
|         else: | ||||
|             raise ValueError("Unkonwn mode of {:}".format(mode)) | ||||
|         # transformation function | ||||
|         self._transform = None | ||||
|  | ||||
|     def set_transform(self, fn): | ||||
|         self._transform = fn | ||||
|  | ||||
|     def __iter__(self): | ||||
|         self._iter_num = 0 | ||||
| @@ -164,6 +255,14 @@ class SynAdaptiveEnv(data.Dataset): | ||||
|         return len(self._indexes) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({cur_num:}/{total} elements)".format( | ||||
|             name=self.__class__.__name__, cur_num=self._total_num, total=len(self) | ||||
|         return ( | ||||
|             "{name}({cur_num:}/{total} elements,\n" | ||||
|             "amplitude={amplitude},\n" | ||||
|             "period_phase_shift={period_phase_shift})".format( | ||||
|                 name=self.__class__.__name__, | ||||
|                 cur_num=self._total_num, | ||||
|                 total=len(self), | ||||
|                 amplitude=self._amplitude_scale, | ||||
|                 period_phase_shift=self._period_phase_shift, | ||||
|             ) | ||||
|         ) | ||||
|   | ||||
							
								
								
									
										121
									
								
								notebooks/TOT/synthetic-adaptive-env.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								notebooks/TOT/synthetic-adaptive-env.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										263
									
								
								notebooks/TOT/synthetic-env.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										263
									
								
								notebooks/TOT/synthetic-env.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -13,15 +13,15 @@ print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import QuadraticFunction | ||||
| from datasets import QuadraticFunc | ||||
| from datasets import SynAdaptiveEnv | ||||
|  | ||||
|  | ||||
| class TestQuadraticFunction(unittest.TestCase): | ||||
| class TestQuadraticFunc(unittest.TestCase): | ||||
|     """Test the quadratic function.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         function = QuadraticFunction([[0, 1], [0.5, 4], [1, 1]]) | ||||
|         function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) | ||||
|         print(function) | ||||
|         for x in (0, 0.5, 1): | ||||
|             print("f({:})={:}".format(x, function[x])) | ||||
| @@ -31,7 +31,7 @@ class TestQuadraticFunction(unittest.TestCase): | ||||
|         self.assertTrue(abs(function[1] - 1) < thresh) | ||||
|  | ||||
|     def test_none(self): | ||||
|         function = QuadraticFunction() | ||||
|         function = QuadraticFunc() | ||||
|         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True) | ||||
|         print(function) | ||||
|         thresh = 0.2 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user