Add SuperSimpleNorm and update synthetic env
This commit is contained in:
		| @@ -33,6 +33,14 @@ class FitFunc(abc.ABC): | ||||
|     def __call__(self, x): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def noise_call(self, x, std=0.1): | ||||
|         clean_y = self.__call__(x) | ||||
|         if isinstance(clean_y, np.ndarray): | ||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||
|         else: | ||||
|             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||
|         return noise_y | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def _getitem(self, x): | ||||
|         raise NotImplementedError | ||||
|   | ||||
							
								
								
									
										63
									
								
								lib/utils/temp_sync.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								lib/utils/temp_sync.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| # To be deleted. | ||||
| import copy | ||||
| import torch | ||||
|  | ||||
| from xlayers.super_core import SuperSequential, SuperMLPv1 | ||||
| from xlayers.super_core import SuperSimpleNorm | ||||
| from xlayers.super_core import SuperLinear | ||||
|  | ||||
|  | ||||
| def optimize_fn(xs, ys, device="cpu", max_iter=2000, max_lr=0.1): | ||||
|     xs = torch.FloatTensor(xs).view(-1, 1).to(device) | ||||
|     ys = torch.FloatTensor(ys).view(-1, 1).to(device) | ||||
|  | ||||
|     model = SuperSequential( | ||||
|         SuperSimpleNorm(xs.mean().item(), xs.std().item()), | ||||
|         SuperLinear(1, 200), | ||||
|         torch.nn.LeakyReLU(), | ||||
|         SuperLinear(200, 100), | ||||
|         torch.nn.LeakyReLU(), | ||||
|         SuperLinear(100, 1), | ||||
|     ).to(device) | ||||
|     model.train() | ||||
|     optimizer = torch.optim.Adam( | ||||
|         model.parameters(), lr=max_lr, amsgrad=True | ||||
|     ) | ||||
|     loss_func = torch.nn.MSELoss() | ||||
|     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||
|         optimizer, | ||||
|         milestones=[ | ||||
|             int(max_iter * 0.25), | ||||
|             int(max_iter * 0.5), | ||||
|             int(max_iter * 0.75), | ||||
|         ], | ||||
|         gamma=0.3, | ||||
|     ) | ||||
|  | ||||
|     best_loss, best_param = None, None | ||||
|     for _iter in range(max_iter): | ||||
|         preds = model(xs) | ||||
|  | ||||
|         optimizer.zero_grad() | ||||
|         loss = loss_func(preds, ys) | ||||
|         loss.backward() | ||||
|         optimizer.step() | ||||
|         lr_scheduler.step() | ||||
|  | ||||
|         if best_loss is None or best_loss > loss.item(): | ||||
|             best_loss = loss.item() | ||||
|             best_param = copy.deepcopy(model.state_dict()) | ||||
|          | ||||
|         # print('loss={:}, best-loss={:}'.format(loss.item(), best_loss)) | ||||
|     model.load_state_dict(best_param) | ||||
|     return model, loss_func, best_loss | ||||
|  | ||||
|  | ||||
| def evaluate_fn(model, xs, ys, loss_fn, device="cpu"): | ||||
|     with torch.no_grad(): | ||||
|         inputs = torch.FloatTensor(xs).view(-1, 1).to(device) | ||||
|         ys = torch.FloatTensor(ys).view(-1, 1).to(device) | ||||
|         preds = model(inputs) | ||||
|         loss = loss_fn(preds, ys) | ||||
|         preds = preds.view(-1).cpu().numpy() | ||||
|     return preds, loss.item() | ||||
| @@ -91,6 +91,8 @@ class SuperSequential(SuperModule): | ||||
|     def abstract_search_space(self): | ||||
|         root_node = spaces.VirtualNode(id(self)) | ||||
|         for index, module in enumerate(self): | ||||
|             if not isinstance(module, SuperModule): | ||||
|                 continue | ||||
|             space = module.abstract_search_space | ||||
|             if not spaces.is_determined(space): | ||||
|                 root_node.append(str(index), space) | ||||
| @@ -98,9 +100,9 @@ class SuperSequential(SuperModule): | ||||
|  | ||||
|     def apply_candidate(self, abstract_child: spaces.VirtualNode): | ||||
|         super(SuperSequential, self).apply_candidate(abstract_child) | ||||
|         for index in range(len(self)): | ||||
|         for index, module in enumerate(self): | ||||
|             if str(index) in abstract_child: | ||||
|                 self.__getitem__(index).apply_candidate(abstract_child[str(index)]) | ||||
|                 module.apply_candidate(abstract_child[str(index)]) | ||||
|  | ||||
|     def forward_candidate(self, input): | ||||
|         return self.forward_raw(input) | ||||
|   | ||||
| @@ -9,6 +9,7 @@ from .super_module import SuperModule | ||||
| from .super_container import SuperSequential | ||||
| from .super_linear import SuperLinear | ||||
| from .super_linear import SuperMLPv1, SuperMLPv2 | ||||
| from .super_norm import SuperSimpleNorm | ||||
| from .super_norm import SuperLayerNorm1D | ||||
| from .super_attention import SuperAttention | ||||
| from .super_transformer import SuperTransformerEncoderLayer | ||||
|   | ||||
| @@ -3,6 +3,7 @@ | ||||
| ##################################################### | ||||
|  | ||||
| import abc | ||||
| import warnings | ||||
| from typing import Optional, Union, Callable | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| @@ -45,6 +46,17 @@ class SuperModule(abc.ABC, nn.Module): | ||||
|  | ||||
|         self.apply(_reset_super_run) | ||||
|  | ||||
|     def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: | ||||
|         if not isinstance(module, SuperModule): | ||||
|             warnings.warn( | ||||
|                 "Add {:} module, which is not SuperModule, into {:}".format( | ||||
|                     name, self.__class__.__name__ | ||||
|                 ) | ||||
|                 + "\n" | ||||
|                 + "It may cause some functions invalid." | ||||
|             ) | ||||
|         super(SuperModule, self).add_module(name, module) | ||||
|  | ||||
|     def apply_verbose(self, verbose): | ||||
|         def _reset_verbose(m): | ||||
|             if isinstance(m, SuperModule): | ||||
|   | ||||
| @@ -82,3 +82,43 @@ class SuperLayerNorm1D(SuperModule): | ||||
|                 elementwise_affine=self._elementwise_affine, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SuperSimpleNorm(SuperModule): | ||||
|     """Super simple normalization.""" | ||||
|  | ||||
|     def __init__(self, mean, std, inplace=False) -> None: | ||||
|         super(SuperSimpleNorm, self).__init__() | ||||
|         self._mean = mean | ||||
|         self._std = std | ||||
|         self._inplace = inplace | ||||
|  | ||||
|     @property | ||||
|     def abstract_search_space(self): | ||||
|         return spaces.VirtualNode(id(self)) | ||||
|  | ||||
|     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         # check inputs -> | ||||
|         return self.forward_raw(input) | ||||
|  | ||||
|     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         if not self._inplace: | ||||
|             tensor = input.clone() | ||||
|         else: | ||||
|             tensor = input | ||||
|         mean = torch.as_tensor(self._mean, dtype=tensor.dtype, device=tensor.device) | ||||
|         std = torch.as_tensor(self._std, dtype=tensor.dtype, device=tensor.device) | ||||
|         if (std == 0).any(): | ||||
|             raise ValueError( | ||||
|                 "std evaluated to zero after conversion to {}, leading to division by zero.".format( | ||||
|                     dtype | ||||
|                 ) | ||||
|             ) | ||||
|         while mean.ndim < tensor.ndim: | ||||
|             mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0) | ||||
|         return tensor.sub_(mean).div_(std) | ||||
|  | ||||
|     def extra_repr(self) -> str: | ||||
|         return "mean={mean}, std={mean}, inplace={inplace}".format( | ||||
|             mean=self._mean, std=self._std, inplace=self._inplace | ||||
|         ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user