Update synthetic environment
This commit is contained in:
		
							
								
								
									
										2
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -54,7 +54,7 @@ jobs: | |||||||
|         run: | |         run: | | ||||||
|           python -m pip install pytest numpy |           python -m pip install pytest numpy | ||||||
|           python -m pip install parameterized |           python -m pip install parameterized | ||||||
|           python -m pip install torch |           python -m pip install torch torchvision | ||||||
|           python --version |           python --version | ||||||
|           python -m pytest ./tests/test_synthetic.py -s |           python -m pytest ./tests/test_synthetic.py -s | ||||||
|         shell: bash |         shell: bash | ||||||
|   | |||||||
 Submodule .latent-data/NATS-Bench updated: 33bfb2eb13...f955e2ba13
									
								
							| @@ -4,5 +4,5 @@ | |||||||
| from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||||
| from .SearchDatasetWrap import SearchDataset | from .SearchDatasetWrap import SearchDataset | ||||||
|  |  | ||||||
| from .synthetic_adaptive_environment import QuadraticFunction | from .synthetic_adaptive_environment import QuadraticFunc, CubicFunc, QuarticFunc | ||||||
| from .synthetic_adaptive_environment import SynAdaptiveEnv | from .synthetic_adaptive_environment import SynAdaptiveEnv | ||||||
|   | |||||||
| @@ -2,38 +2,43 @@ | |||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
| ##################################################### | ##################################################### | ||||||
| import math | import math | ||||||
|  | import abc | ||||||
| import numpy as np | import numpy as np | ||||||
| from typing import Optional | from typing import Optional | ||||||
| import torch | import torch | ||||||
| import torch.utils.data as data | import torch.utils.data as data | ||||||
|  |  | ||||||
|  |  | ||||||
| class QuadraticFunction: | class FitFunc(abc.ABC): | ||||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" |     """The fit function that outputs f(x) = a * x^2 + b * x + c.""" | ||||||
|  |  | ||||||
|     def __init__(self, list_of_points=None): |     def __init__(self, freedom: int, list_of_points=None): | ||||||
|         self._params = dict(a=None, b=None, c=None) |         self._params = dict() | ||||||
|  |         for i in range(freedom): | ||||||
|  |             self._params[i] = None | ||||||
|  |         self._freedom = freedom | ||||||
|         if list_of_points is not None: |         if list_of_points is not None: | ||||||
|             self.fit(list_of_points) |             self.fit(list_of_points) | ||||||
|  |  | ||||||
|     def set(self, a, b, c): |     def set(self, _params): | ||||||
|         self._params["a"] = a |         self._params = copy.deepcopy(_params) | ||||||
|         self._params["b"] = b |  | ||||||
|         self._params["c"] = c |  | ||||||
|  |  | ||||||
|     def check_valid(self): |     def check_valid(self): | ||||||
|         for key, value in self._params.items(): |         for key, value in self._params.items(): | ||||||
|             if value is None: |             if value is None: | ||||||
|                 raise ValueError("The {:} is None".format(key)) |                 raise ValueError("The {:} is None".format(key)) | ||||||
|  |  | ||||||
|  |     @abc.abstractmethod | ||||||
|     def __getitem__(self, x): |     def __getitem__(self, x): | ||||||
|         self.check_valid() |         raise NotImplementedError | ||||||
|         return self._params["a"] * x * x + self._params["b"] * x + self._params["c"] |  | ||||||
|  |     @abc.abstractmethod | ||||||
|  |     def _getitem(self, x): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def fit( |     def fit( | ||||||
|         self, |         self, | ||||||
|         list_of_points, |         list_of_points, | ||||||
|         transf=lambda x: x, |  | ||||||
|         max_iter=900, |         max_iter=900, | ||||||
|         lr_max=1.0, |         lr_max=1.0, | ||||||
|         verbose=False, |         verbose=False, | ||||||
| @@ -44,16 +49,24 @@ class QuadraticFunction: | |||||||
|                 data.shape |                 data.shape | ||||||
|             ) |             ) | ||||||
|             x, y = data[:, 0], data[:, 1] |             x, y = data[:, 0], data[:, 1] | ||||||
|         weights = torch.nn.Parameter(torch.Tensor(3)) |         weights = torch.nn.Parameter(torch.Tensor(self._freedom)) | ||||||
|         torch.nn.init.normal_(weights, mean=0.0, std=1.0) |         torch.nn.init.normal_(weights, mean=0.0, std=1.0) | ||||||
|         optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) |         optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) | ||||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(max_iter*0.25), int(max_iter*0.5), int(max_iter*0.75)], gamma=0.1) |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|  |             optimizer, | ||||||
|  |             milestones=[ | ||||||
|  |                 int(max_iter * 0.25), | ||||||
|  |                 int(max_iter * 0.5), | ||||||
|  |                 int(max_iter * 0.75), | ||||||
|  |             ], | ||||||
|  |             gamma=0.1, | ||||||
|  |         ) | ||||||
|         if verbose: |         if verbose: | ||||||
|             print("The optimizer: {:}".format(optimizer)) |             print("The optimizer: {:}".format(optimizer)) | ||||||
|  |  | ||||||
|         best_loss = None |         best_loss = None | ||||||
|         for _iter in range(max_iter): |         for _iter in range(max_iter): | ||||||
|             y_hat = transf(weights[0] * x * x + weights[1] * x + weights[2]) |             y_hat = self._getitem(x, weights) | ||||||
|             loss = torch.mean(torch.abs(y - y_hat)) |             loss = torch.mean(torch.abs(y - y_hat)) | ||||||
|             optimizer.zero_grad() |             optimizer.zero_grad() | ||||||
|             loss.backward() |             loss.backward() | ||||||
| @@ -61,23 +74,105 @@ class QuadraticFunction: | |||||||
|             lr_scheduler.step() |             lr_scheduler.step() | ||||||
|             if verbose: |             if verbose: | ||||||
|                 print( |                 print( | ||||||
|                     "In QuadraticFunction's fit, loss at the {:02d}/{:02d}-th iter is {:}".format( |                     "In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format( | ||||||
|                         _iter, max_iter, loss.item() |                         _iter, max_iter, loss.item() | ||||||
|                     ) |                     ) | ||||||
|                 ) |                 ) | ||||||
|             # Update the params |             # Update the params | ||||||
|             if best_loss is None or best_loss > loss.item(): |             if best_loss is None or best_loss > loss.item(): | ||||||
|                 best_loss = loss.item() |                 best_loss = loss.item() | ||||||
|                 self._params["a"] = weights[0].item() |                 for i in range(self._freedom): | ||||||
|                 self._params["b"] = weights[1].item() |                     self._params[i] = weights[i].item() | ||||||
|                 self._params["c"] = weights[2].item() |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}(freedom={freedom})".format( | ||||||
|  |             name=self.__class__.__name__, freedom=freedom | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class QuadraticFunc(FitFunc): | ||||||
|  |     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, list_of_points=None): | ||||||
|  |         super(QuadraticFunc, self).__init__(3, list_of_points) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return self._params[0] * x * x + self._params[1] * x + self._params[2] | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         return weights[0] * x * x + weights[1] * x + weights[2] | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}(y = {a} * x^2 + {b} * x + {c})".format( |         return "{name}(y = {a} * x^2 + {b} * x + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params["a"], |             a=self._params[0], | ||||||
|             b=self._params["b"], |             b=self._params[1], | ||||||
|             c=self._params["c"], |             c=self._params[2], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CubicFunc(FitFunc): | ||||||
|  |     """The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, list_of_points=None): | ||||||
|  |         super(CubicFunc, self).__init__(4, list_of_points) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return ( | ||||||
|  |             self._params[0] * x ** 3 | ||||||
|  |             + self._params[1] * x ** 2 | ||||||
|  |             + self._params[2] * x | ||||||
|  |             + self._params[3] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}(y = {a} * x^3 + {b} * x^2 + {c} * x + {d})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             d=self._params[3], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class QuarticFunc(FitFunc): | ||||||
|  |     """The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, list_of_points=None): | ||||||
|  |         super(QuarticFunc, self).__init__(5, list_of_points) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return ( | ||||||
|  |             self._params[0] * x ** 4 | ||||||
|  |             + self._params[1] * x ** 3 | ||||||
|  |             + self._params[2] * x ** 2 | ||||||
|  |             + self._params[3] * x | ||||||
|  |             + self._params[4] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def _getitem(self, x, weights): | ||||||
|  |         return ( | ||||||
|  |             weights[0] * x ** 4 | ||||||
|  |             + weights[1] * x ** 3 | ||||||
|  |             + weights[2] * x ** 2 | ||||||
|  |             + weights[3] * x | ||||||
|  |             + weights[4] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}(y = {a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             d=self._params[3], | ||||||
|  |             e=self._params[3], | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -95,28 +190,29 @@ class SynAdaptiveEnv(data.Dataset): | |||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         num: int = 100, |         num: int = 100, | ||||||
|         num_sin_phase: int = 4, |         num_sin_phase: int = 7, | ||||||
|         min_amplitude: float = 1, |         min_amplitude: float = 1, | ||||||
|         max_amplitude: float = 4, |         max_amplitude: float = 4, | ||||||
|         phase_shift: float = 0, |         phase_shift: float = 0, | ||||||
|         mode: Optional[str] = None, |         mode: Optional[str] = None, | ||||||
|     ): |     ): | ||||||
|         self._amplitude_scale = QuadraticFunction( |         self._amplitude_scale = QuadraticFunc( | ||||||
|             [(0, min_amplitude), (0.5, max_amplitude), (0, min_amplitude)] |             [(0, min_amplitude), (0.5, max_amplitude), (1, min_amplitude)] | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         self._num_sin_phase = num_sin_phase |         self._num_sin_phase = num_sin_phase | ||||||
|         self._interval = 1.0 / (float(num) - 1) |         self._interval = 1.0 / (float(num) - 1) | ||||||
|         self._total_num = num |         self._total_num = num | ||||||
|  |  | ||||||
|         self._period_phase_shift = QuadraticFunction() |  | ||||||
|  |  | ||||||
|         fitting_data = [] |         fitting_data = [] | ||||||
|         temp_max_scalar = 2 ** num_sin_phase |         temp_max_scalar = 2 ** (num_sin_phase - 1) | ||||||
|         for i in range(num_sin_phase): |         for i in range(num_sin_phase): | ||||||
|             value = (2 ** i) / temp_max_scalar |             value = (2 ** i) / temp_max_scalar | ||||||
|             fitting_data.append((value, math.sin(value))) |             next_value = (2 ** (i + 1)) / temp_max_scalar | ||||||
|         self._period_phase_shift.fit(fitting_data, transf=lambda x: torch.sin(x)) |             for _phase in (0, 0.25, 0.5, 0.75): | ||||||
|  |                 inter_value = value + (next_value - value) * _phase | ||||||
|  |                 fitting_data.append((inter_value, math.pi * (2 * i + _phase))) | ||||||
|  |         self._period_phase_shift = QuarticFunc(fitting_data) | ||||||
|  |  | ||||||
|         # Training Set 60% |         # Training Set 60% | ||||||
|         num_of_train = int(self._total_num * 0.6) |         num_of_train = int(self._total_num * 0.6) | ||||||
| @@ -135,11 +231,6 @@ class SynAdaptiveEnv(data.Dataset): | |||||||
|             self._indexes = all_indexes[num_of_train + num_of_valid :] |             self._indexes = all_indexes[num_of_train + num_of_valid :] | ||||||
|         else: |         else: | ||||||
|             raise ValueError("Unkonwn mode of {:}".format(mode)) |             raise ValueError("Unkonwn mode of {:}".format(mode)) | ||||||
|         # transformation function |  | ||||||
|         self._transform = None |  | ||||||
|  |  | ||||||
|     def set_transform(self, fn): |  | ||||||
|         self._transform = fn |  | ||||||
|  |  | ||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
|         self._iter_num = 0 |         self._iter_num = 0 | ||||||
| @@ -164,6 +255,14 @@ class SynAdaptiveEnv(data.Dataset): | |||||||
|         return len(self._indexes) |         return len(self._indexes) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({cur_num:}/{total} elements)".format( |         return ( | ||||||
|             name=self.__class__.__name__, cur_num=self._total_num, total=len(self) |             "{name}({cur_num:}/{total} elements,\n" | ||||||
|  |             "amplitude={amplitude},\n" | ||||||
|  |             "period_phase_shift={period_phase_shift})".format( | ||||||
|  |                 name=self.__class__.__name__, | ||||||
|  |                 cur_num=self._total_num, | ||||||
|  |                 total=len(self), | ||||||
|  |                 amplitude=self._amplitude_scale, | ||||||
|  |                 period_phase_shift=self._period_phase_shift, | ||||||
|  |             ) | ||||||
|         ) |         ) | ||||||
|   | |||||||
							
								
								
									
										121
									
								
								notebooks/TOT/synthetic-adaptive-env.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								notebooks/TOT/synthetic-adaptive-env.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										263
									
								
								notebooks/TOT/synthetic-env.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										263
									
								
								notebooks/TOT/synthetic-env.ipynb
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -13,15 +13,15 @@ print("library path: {:}".format(lib_dir)) | |||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from datasets import QuadraticFunction | from datasets import QuadraticFunc | ||||||
| from datasets import SynAdaptiveEnv | from datasets import SynAdaptiveEnv | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestQuadraticFunction(unittest.TestCase): | class TestQuadraticFunc(unittest.TestCase): | ||||||
|     """Test the quadratic function.""" |     """Test the quadratic function.""" | ||||||
|  |  | ||||||
|     def test_simple(self): |     def test_simple(self): | ||||||
|         function = QuadraticFunction([[0, 1], [0.5, 4], [1, 1]]) |         function = QuadraticFunc([[0, 1], [0.5, 4], [1, 1]]) | ||||||
|         print(function) |         print(function) | ||||||
|         for x in (0, 0.5, 1): |         for x in (0, 0.5, 1): | ||||||
|             print("f({:})={:}".format(x, function[x])) |             print("f({:})={:}".format(x, function[x])) | ||||||
| @@ -31,7 +31,7 @@ class TestQuadraticFunction(unittest.TestCase): | |||||||
|         self.assertTrue(abs(function[1] - 1) < thresh) |         self.assertTrue(abs(function[1] - 1) < thresh) | ||||||
|  |  | ||||||
|     def test_none(self): |     def test_none(self): | ||||||
|         function = QuadraticFunction() |         function = QuadraticFunc() | ||||||
|         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True) |         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True) | ||||||
|         print(function) |         print(function) | ||||||
|         thresh = 0.2 |         thresh = 0.2 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user