Add super/norm layers in xcore
This commit is contained in:
		
							
								
								
									
										212
									
								
								exps/LFNA/lfna-v1.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										212
									
								
								exps/LFNA/lfna-v1.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,212 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
|  | ##################################################### | ||||||
|  | # python exps/LFNA/lfna-v1.py | ||||||
|  | ##################################################### | ||||||
|  | import sys, time, copy, torch, random, argparse | ||||||
|  | from tqdm import tqdm | ||||||
|  | from copy import deepcopy | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  | from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||||
|  | from log_utils import time_string | ||||||
|  | from log_utils import AverageMeter, convert_secs2time | ||||||
|  |  | ||||||
|  | from utils import split_str2indexes | ||||||
|  |  | ||||||
|  | from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||||
|  | from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||||
|  | from datasets.synthetic_core import get_synthetic_env | ||||||
|  | from models.xcore import get_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Population: | ||||||
|  |     def __init__(self): | ||||||
|  |         self._time2model = dict() | ||||||
|  |  | ||||||
|  |     def append(self, timestamp, model): | ||||||
|  |         if timestamp in self._time2model: | ||||||
|  |             raise ValueError("This timestamp has been added.") | ||||||
|  |         self._time2model[timestamp] = model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(args): | ||||||
|  |     prepare_seed(args.rand_seed) | ||||||
|  |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|  |     cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() | ||||||
|  |     if cache_path.exists(): | ||||||
|  |         env_info = torch.load(cache_path) | ||||||
|  |     else: | ||||||
|  |         env_info = dict() | ||||||
|  |         dynamic_env = get_synthetic_env() | ||||||
|  |         env_info["total"] = len(dynamic_env) | ||||||
|  |         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||||
|  |             env_info["{:}-timestamp".format(idx)] = timestamp | ||||||
|  |             env_info["{:}-x".format(idx)] = _allx | ||||||
|  |             env_info["{:}-y".format(idx)] = _ally | ||||||
|  |         env_info["dynamic_env"] = dynamic_env | ||||||
|  |         torch.save(env_info, cache_path) | ||||||
|  |  | ||||||
|  |     total_time = env_info["total"] | ||||||
|  |     for i in range(total_time): | ||||||
|  |         for xkey in ("timestamp", "x", "y"): | ||||||
|  |             nkey = "{:}-{:}".format(i, xkey) | ||||||
|  |             assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys())) | ||||||
|  |     train_time_bar = total_time // 2 | ||||||
|  |     base_model = get_model( | ||||||
|  |         dict(model_type="simple_mlp"), | ||||||
|  |         act_cls="leaky_relu", | ||||||
|  |         norm_cls="simple_learn_norm", | ||||||
|  |         mean=0, | ||||||
|  |         std=1, | ||||||
|  |         input_dim=1, | ||||||
|  |         output_dim=1, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     w_container = base_model.named_parameters_buffers() | ||||||
|  |     print("There are {:} weights.".format(w_container.numel())) | ||||||
|  |  | ||||||
|  |     pool = Population() | ||||||
|  |     pool.append(0, w_container) | ||||||
|  |  | ||||||
|  |     # LFNA meta-training | ||||||
|  |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
|  |     for iepoch in range(args.epochs): | ||||||
|  |         import pdb | ||||||
|  |  | ||||||
|  |         pdb.set_trace() | ||||||
|  |         print("-") | ||||||
|  |  | ||||||
|  |     for i, idx in enumerate(to_evaluate_indexes): | ||||||
|  |  | ||||||
|  |         need_time = "Time Left: {:}".format( | ||||||
|  |             convert_secs2time( | ||||||
|  |                 per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         logger.log( | ||||||
|  |             "[{:}]".format(time_string()) | ||||||
|  |             + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx) | ||||||
|  |             + " " | ||||||
|  |             + need_time | ||||||
|  |         ) | ||||||
|  |         # train the same data | ||||||
|  |         assert idx != 0 | ||||||
|  |         historical_x = env_info["{:}-x".format(idx)] | ||||||
|  |         historical_y = env_info["{:}-y".format(idx)] | ||||||
|  |         # build model | ||||||
|  |         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||||
|  |         model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std) | ||||||
|  |         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||||
|  |         # build optimizer | ||||||
|  |         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||||
|  |         criterion = torch.nn.MSELoss() | ||||||
|  |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( | ||||||
|  |             optimizer, | ||||||
|  |             milestones=[ | ||||||
|  |                 int(args.epochs * 0.25), | ||||||
|  |                 int(args.epochs * 0.5), | ||||||
|  |                 int(args.epochs * 0.75), | ||||||
|  |             ], | ||||||
|  |             gamma=0.3, | ||||||
|  |         ) | ||||||
|  |         train_metric = MSEMetric() | ||||||
|  |         best_loss, best_param = None, None | ||||||
|  |         for _iepoch in range(args.epochs): | ||||||
|  |             preds = model(historical_x) | ||||||
|  |             optimizer.zero_grad() | ||||||
|  |             loss = criterion(preds, historical_y) | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  |             lr_scheduler.step() | ||||||
|  |             # save best | ||||||
|  |             if best_loss is None or best_loss > loss.item(): | ||||||
|  |                 best_loss = loss.item() | ||||||
|  |                 best_param = copy.deepcopy(model.state_dict()) | ||||||
|  |         model.load_state_dict(best_param) | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             train_metric(preds, historical_y) | ||||||
|  |         train_results = train_metric.get_info() | ||||||
|  |  | ||||||
|  |         metric = ComposeMetric(MSEMetric(), SaveMetric()) | ||||||
|  |         eval_dataset = torch.utils.data.TensorDataset( | ||||||
|  |             env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)] | ||||||
|  |         ) | ||||||
|  |         eval_loader = torch.utils.data.DataLoader( | ||||||
|  |             eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 | ||||||
|  |         ) | ||||||
|  |         results = basic_eval_fn(eval_loader, model, metric, logger) | ||||||
|  |         log_str = ( | ||||||
|  |             "[{:}]".format(time_string()) | ||||||
|  |             + " [{:04d}/{:04d}]".format(idx, env_info["total"]) | ||||||
|  |             + " train-mse: {:.5f}, eval-mse: {:.5f}".format( | ||||||
|  |                 train_results["mse"], results["mse"] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         logger.log(log_str) | ||||||
|  |  | ||||||
|  |         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||||
|  |             idx, env_info["total"] | ||||||
|  |         ) | ||||||
|  |         save_checkpoint( | ||||||
|  |             { | ||||||
|  |                 "model_state_dict": model.state_dict(), | ||||||
|  |                 "model": model, | ||||||
|  |                 "index": idx, | ||||||
|  |                 "timestamp": env_info["{:}-timestamp".format(idx)], | ||||||
|  |             }, | ||||||
|  |             save_path, | ||||||
|  |             logger, | ||||||
|  |         ) | ||||||
|  |         logger.log("") | ||||||
|  |  | ||||||
|  |         per_timestamp_time.update(time.time() - start_time) | ||||||
|  |         start_time = time.time() | ||||||
|  |  | ||||||
|  |     logger.log("-" * 200 + "\n") | ||||||
|  |     logger.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = argparse.ArgumentParser("Use the data in the past.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--save_dir", | ||||||
|  |         type=str, | ||||||
|  |         default="./outputs/lfna-synthetic/lfna-v1", | ||||||
|  |         help="The checkpoint directory.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--init_lr", | ||||||
|  |         type=float, | ||||||
|  |         default=0.1, | ||||||
|  |         help="The initial learning rate for the optimizer (default is Adam)", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--batch_size", | ||||||
|  |         type=int, | ||||||
|  |         default=512, | ||||||
|  |         help="The batch size", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=1000, | ||||||
|  |         help="The total number of epochs.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--workers", | ||||||
|  |         type=int, | ||||||
|  |         default=4, | ||||||
|  |         help="The number of data loading workers (default: 4)", | ||||||
|  |     ) | ||||||
|  |     # Random Seed | ||||||
|  |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|  |     main(args) | ||||||
| @@ -10,21 +10,26 @@ __all__ = ["get_model"] | |||||||
|  |  | ||||||
|  |  | ||||||
| from xlayers.super_core import SuperSequential | from xlayers.super_core import SuperSequential | ||||||
| from xlayers.super_core import SuperSimpleNorm |  | ||||||
| from xlayers.super_core import SuperLeakyReLU |  | ||||||
| from xlayers.super_core import SuperLinear | from xlayers.super_core import SuperLinear | ||||||
|  | from xlayers.super_core import super_name2norm | ||||||
|  | from xlayers.super_core import super_name2activation | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_model(config: Dict[Text, Any], **kwargs): | def get_model(config: Dict[Text, Any], **kwargs): | ||||||
|     model_type = config.get("model_type", "simple_mlp") |     model_type = config.get("model_type", "simple_mlp") | ||||||
|     if model_type == "simple_mlp": |     if model_type == "simple_mlp": | ||||||
|  |         act_cls = super_name2activation[kwargs["act_cls"]] | ||||||
|  |         norm_cls = super_name2norm[kwargs["norm_cls"]] | ||||||
|  |         mean, std = kwargs.get("mean", None), kwargs.get("std", None) | ||||||
|  |         hidden_dim1 = kwargs.get("hidden_dim1", 200) | ||||||
|  |         hidden_dim2 = kwargs.get("hidden_dim2", 100) | ||||||
|         model = SuperSequential( |         model = SuperSequential( | ||||||
|             SuperSimpleNorm(kwargs["mean"], kwargs["std"]), |             norm_cls(mean=mean, std=std), | ||||||
|             SuperLinear(kwargs["input_dim"], 200), |             SuperLinear(kwargs["input_dim"], hidden_dim1), | ||||||
|             SuperLeakyReLU(), |             act_cls(), | ||||||
|             SuperLinear(200, 100), |             SuperLinear(hidden_dim1, hidden_dim2), | ||||||
|             SuperLeakyReLU(), |             act_cls(), | ||||||
|             SuperLinear(100, kwargs["output_dim"]), |             SuperLinear(hidden_dim2, kwargs["output_dim"]), | ||||||
|         ) |         ) | ||||||
|     else: |     else: | ||||||
|         raise TypeError("Unkonwn model type: {:}".format(model_type)) |         raise TypeError("Unkonwn model type: {:}".format(model_type)) | ||||||
|   | |||||||
| @@ -9,13 +9,27 @@ from .super_module import SuperModule | |||||||
| from .super_container import SuperSequential | from .super_container import SuperSequential | ||||||
| from .super_linear import SuperLinear | from .super_linear import SuperLinear | ||||||
| from .super_linear import SuperMLPv1, SuperMLPv2 | from .super_linear import SuperMLPv1, SuperMLPv2 | ||||||
|  |  | ||||||
| from .super_norm import SuperSimpleNorm | from .super_norm import SuperSimpleNorm | ||||||
| from .super_norm import SuperLayerNorm1D | from .super_norm import SuperLayerNorm1D | ||||||
|  | from .super_norm import SuperSimpleLearnableNorm | ||||||
|  | from .super_norm import SuperIdentity | ||||||
|  |  | ||||||
|  | super_name2norm = { | ||||||
|  |     "simple_norm": SuperSimpleNorm, | ||||||
|  |     "simple_learn_norm": SuperSimpleLearnableNorm, | ||||||
|  |     "layer_norm_1d": SuperLayerNorm1D, | ||||||
|  |     "identity": SuperIdentity, | ||||||
|  | } | ||||||
|  |  | ||||||
| from .super_attention import SuperAttention | from .super_attention import SuperAttention | ||||||
| from .super_transformer import SuperTransformerEncoderLayer | from .super_transformer import SuperTransformerEncoderLayer | ||||||
|  |  | ||||||
| from .super_activations import SuperReLU | from .super_activations import SuperReLU | ||||||
| from .super_activations import SuperLeakyReLU | from .super_activations import SuperLeakyReLU | ||||||
|  |  | ||||||
|  | super_name2activation = {"relu": SuperReLU, "leaky_relu": SuperLeakyReLU} | ||||||
|  |  | ||||||
|  |  | ||||||
| from .super_trade_stem import SuperAlphaEBDv1 | from .super_trade_stem import SuperAlphaEBDv1 | ||||||
| from .super_positional_embedding import SuperPositionalEncoder | from .super_positional_embedding import SuperPositionalEncoder | ||||||
|   | |||||||
| @@ -30,6 +30,45 @@ class SuperRunMode(Enum): | |||||||
|     Default = "fullmodel" |     Default = "fullmodel" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TensorContainer: | ||||||
|  |     """A class to maintain both parameters and buffers for a model.""" | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self._names = [] | ||||||
|  |         self._tensors = [] | ||||||
|  |         self._param_or_buffers = [] | ||||||
|  |         self._name2index = dict() | ||||||
|  |  | ||||||
|  |     def append(self, name, tensor, param_or_buffer): | ||||||
|  |         if not isinstance(tensor, torch.Tensor): | ||||||
|  |             raise TypeError( | ||||||
|  |                 "The input tensor must be torch.Tensor instead of {:}".format( | ||||||
|  |                     type(tensor) | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         self._names.append(name) | ||||||
|  |         self._tensors.append(tensor) | ||||||
|  |         self._param_or_buffers.append(param_or_buffer) | ||||||
|  |         assert name not in self._name2index, "The [{:}] has already been added.".format( | ||||||
|  |             name | ||||||
|  |         ) | ||||||
|  |         self._name2index[name] = len(self._names) - 1 | ||||||
|  |  | ||||||
|  |     def numel(self): | ||||||
|  |         total = 0 | ||||||
|  |         for tensor in self._tensors: | ||||||
|  |             total += tensor.numel() | ||||||
|  |         return total | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self._names) | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({num} tensors)".format( | ||||||
|  |             name=self.__class__.__name__, num=len(self) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class SuperModule(abc.ABC, nn.Module): | class SuperModule(abc.ABC, nn.Module): | ||||||
|     """This class equips the nn.Module class with the ability to apply AutoDL.""" |     """This class equips the nn.Module class with the ability to apply AutoDL.""" | ||||||
|  |  | ||||||
| @@ -71,6 +110,14 @@ class SuperModule(abc.ABC, nn.Module): | |||||||
|             ) |             ) | ||||||
|         self._abstract_child = abstract_child |         self._abstract_child = abstract_child | ||||||
|  |  | ||||||
|  |     def named_parameters_buffers(self): | ||||||
|  |         container = TensorContainer() | ||||||
|  |         for name, param in self.named_parameters(): | ||||||
|  |             container.append(name, param, True) | ||||||
|  |         for name, buf in self.named_buffers(): | ||||||
|  |             container.append(name, buf, False) | ||||||
|  |         return container | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|   | |||||||
| @@ -89,8 +89,8 @@ class SuperSimpleNorm(SuperModule): | |||||||
|  |  | ||||||
|     def __init__(self, mean, std, inplace=False) -> None: |     def __init__(self, mean, std, inplace=False) -> None: | ||||||
|         super(SuperSimpleNorm, self).__init__() |         super(SuperSimpleNorm, self).__init__() | ||||||
|         self._mean = mean |         self.register_buffer("_mean", torch.tensor(mean, dtype=torch.float)) | ||||||
|         self._std = std |         self.register_buffer("_std", torch.tensor(std, dtype=torch.float)) | ||||||
|         self._inplace = inplace |         self._inplace = inplace | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @@ -111,7 +111,7 @@ class SuperSimpleNorm(SuperModule): | |||||||
|         if (std == 0).any(): |         if (std == 0).any(): | ||||||
|             raise ValueError( |             raise ValueError( | ||||||
|                 "std evaluated to zero after conversion to {}, leading to division by zero.".format( |                 "std evaluated to zero after conversion to {}, leading to division by zero.".format( | ||||||
|                     dtype |                     tensor.dtype | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         while mean.ndim < tensor.ndim: |         while mean.ndim < tensor.ndim: | ||||||
| @@ -119,6 +119,75 @@ class SuperSimpleNorm(SuperModule): | |||||||
|         return tensor.sub_(mean).div_(std) |         return tensor.sub_(mean).div_(std) | ||||||
|  |  | ||||||
|     def extra_repr(self) -> str: |     def extra_repr(self) -> str: | ||||||
|         return "mean={mean}, std={mean}, inplace={inplace}".format( |         return "mean={mean}, std={std}, inplace={inplace}".format( | ||||||
|             mean=self._mean, std=self._std, inplace=self._inplace |             mean=self._mean.item(), std=self._std.item(), inplace=self._inplace | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SuperSimpleLearnableNorm(SuperModule): | ||||||
|  |     """Super simple normalization.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, mean=0, std=1, eps=1e-6, inplace=False) -> None: | ||||||
|  |         super(SuperSimpleLearnableNorm, self).__init__() | ||||||
|  |         self.register_parameter( | ||||||
|  |             "_mean", nn.Parameter(torch.tensor(mean, dtype=torch.float)) | ||||||
|  |         ) | ||||||
|  |         self.register_parameter( | ||||||
|  |             "_std", nn.Parameter(torch.tensor(std, dtype=torch.float)) | ||||||
|  |         ) | ||||||
|  |         self._eps = eps | ||||||
|  |         self._inplace = inplace | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def abstract_search_space(self): | ||||||
|  |         return spaces.VirtualNode(id(self)) | ||||||
|  |  | ||||||
|  |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         # check inputs -> | ||||||
|  |         return self.forward_raw(input) | ||||||
|  |  | ||||||
|  |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         if not self._inplace: | ||||||
|  |             tensor = input.clone() | ||||||
|  |         else: | ||||||
|  |             tensor = input | ||||||
|  |         mean, std = ( | ||||||
|  |             self._mean.to(tensor.device), | ||||||
|  |             torch.abs(self._std.to(tensor.device)) + self._eps, | ||||||
|  |         ) | ||||||
|  |         if (std == 0).any(): | ||||||
|  |             raise ValueError("std leads to division by zero.") | ||||||
|  |         while mean.ndim < tensor.ndim: | ||||||
|  |             mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0) | ||||||
|  |         return tensor.sub_(mean).div_(std) | ||||||
|  |  | ||||||
|  |     def extra_repr(self) -> str: | ||||||
|  |         return "mean={mean}, std={std}, inplace={inplace}".format( | ||||||
|  |             mean=self._mean.item(), std=self._std.item(), inplace=self._inplace | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SuperIdentity(SuperModule): | ||||||
|  |     """Super identity mapping layer.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, inplace=False, **kwargs) -> None: | ||||||
|  |         super(SuperIdentity, self).__init__() | ||||||
|  |         self._inplace = inplace | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def abstract_search_space(self): | ||||||
|  |         return spaces.VirtualNode(id(self)) | ||||||
|  |  | ||||||
|  |     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         # check inputs -> | ||||||
|  |         return self.forward_raw(input) | ||||||
|  |  | ||||||
|  |     def forward_raw(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         if not self._inplace: | ||||||
|  |             tensor = input.clone() | ||||||
|  |         else: | ||||||
|  |             tensor = input | ||||||
|  |         return tensor | ||||||
|  |  | ||||||
|  |     def extra_repr(self) -> str: | ||||||
|  |         return "inplace={inplace}".format(inplace=self._inplace) | ||||||
|   | |||||||
| @@ -51,3 +51,35 @@ class TestSuperSimpleNorm(unittest.TestCase): | |||||||
|         output_shape = (20, abstract_child["1"]["_out_features"].value) |         output_shape = (20, abstract_child["1"]["_out_features"].value) | ||||||
|         outputs = model(inputs) |         outputs = model(inputs) | ||||||
|         self.assertEqual(tuple(outputs.shape), output_shape) |         self.assertEqual(tuple(outputs.shape), output_shape) | ||||||
|  |  | ||||||
|  |     def test_super_simple_learn_norm(self): | ||||||
|  |         out_features = spaces.Categorical(12, 24, 36) | ||||||
|  |         bias = spaces.Categorical(True, False) | ||||||
|  |         model = super_core.SuperSequential( | ||||||
|  |             super_core.SuperSimpleLearnableNorm(), | ||||||
|  |             super_core.SuperIdentity(), | ||||||
|  |             super_core.SuperLinear(10, out_features, bias=bias), | ||||||
|  |         ) | ||||||
|  |         print("The simple super module is:\n{:}".format(model)) | ||||||
|  |         model.apply_verbose(True) | ||||||
|  |  | ||||||
|  |         print(model.super_run_type) | ||||||
|  |         self.assertTrue(model[1].bias) | ||||||
|  |  | ||||||
|  |         inputs = torch.rand(20, 10) | ||||||
|  |         print("Input shape: {:}".format(inputs.shape)) | ||||||
|  |         outputs = model(inputs) | ||||||
|  |         self.assertEqual(tuple(outputs.shape), (20, 36)) | ||||||
|  |  | ||||||
|  |         abstract_space = model.abstract_search_space | ||||||
|  |         abstract_space.clean_last() | ||||||
|  |         abstract_child = abstract_space.random() | ||||||
|  |         print("The abstract searc space:\n{:}".format(abstract_space)) | ||||||
|  |         print("The abstract child program:\n{:}".format(abstract_child)) | ||||||
|  |  | ||||||
|  |         model.set_super_run_type(super_core.SuperRunMode.Candidate) | ||||||
|  |         model.apply_candidate(abstract_child) | ||||||
|  |  | ||||||
|  |         output_shape = (20, abstract_child["1"]["_out_features"].value) | ||||||
|  |         outputs = model(inputs) | ||||||
|  |         self.assertEqual(tuple(outputs.shape), output_shape) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user