Update super-activation layers
@@ -25,6 +25,7 @@ from xlayers import super_core
 
 
 from lfna_utils import lfna_setup, train_model, TimeData
+from lfna_models import HyperNet
 
 
 class LFNAmlp:
@@ -77,17 +78,40 @@ def main(args):
             nkey = "{:}-{:}".format(i, xkey)
             assert nkey in env_info, "{:} no in {:}".format(nkey, list(env_info.keys()))
     train_time_bar = total_time // 2
-    network = get_model(dict(model_type="simple_mlp"), **model_kwargs)
 
     criterion = torch.nn.MSELoss()
-    logger.log("There are {:} weights.".format(network.get_w_container().numel()))
+    logger.log("There are {:} weights.".format(model.get_w_container().numel()))
 
     adaptor = LFNAmlp(args.meta_seq, (200, 200), "leaky_relu", criterion)
 
     # pre-train the model
-    init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"])
-    init_loss = train_model(network, init_dataset, args.init_lr, args.epochs)
+    dataset = init_dataset = TimeData(0, env_info["0-x"], env_info["0-y"])
+
+    shape_container = model.get_w_container().to_shape_container()
+    hypernet = HyperNet(shape_container, 16)
+
+    optimizer = torch.optim.Adam(hypernet.parameters(), lr=args.init_lr, amsgrad=True)
+
+    best_loss, best_param = None, None
+    for _iepoch in range(args.epochs):
+        container = hypernet(None)
+
+        preds = model.forward_with_container(dataset.x, container)
+        optimizer.zero_grad()
+        loss = criterion(preds, dataset.y)
+        loss.backward()
+        optimizer.step()
+        # save best
+        if best_loss is None or best_loss > loss.item():
+            best_loss = loss.item()
+            best_param = copy.deepcopy(model.state_dict())
+    print("hyper-net : best={:.4f}".format(best_loss))
+
+    init_loss = train_model(model, init_dataset, args.init_lr, args.epochs)
     logger.log("The pre-training loss is {:.4f}".format(init_loss))
+    import pdb
+
+    pdb.set_trace()
 
     all_past_containers = []
     ground_truth_path = (
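Note on the new pre-training step: instead of fitting the MLP's own weights, this hunk optimizes a hypernetwork whose output is a weight container; the MLP is only evaluated through model.forward_with_container(...) with the generated weights, so gradients flow into the hypernetwork alone. Below is a minimal self-contained sketch of the same pattern in plain PyTorch; the toy data, the mlp_forward helper, and the layer sizes are illustrative assumptions, not the repo's HyperNet / forward_with_container API.

# Minimal sketch (assumed shapes, not the repo's API): a tiny hypernetwork
# emits the weights of a one-hidden-layer MLP, and only the hypernetwork,
# never the MLP weights themselves, receives gradient updates.
import torch

torch.manual_seed(0)
in_dim, hid_dim, out_dim, embed_dim = 1, 8, 1, 16

# One learnable embedding per generated layer (cf. _super_layer_embed).
layer_embed = torch.nn.Parameter(0.02 * torch.randn(2, embed_dim))
# Shared generator: embedding -> flat weight vector, sized for the largest
# layer; smaller layers only use a prefix of their generated row.
max_numel = max(hid_dim * (in_dim + 1), out_dim * (hid_dim + 1))
generator = torch.nn.Sequential(
    torch.nn.Linear(embed_dim, embed_dim * 4),
    torch.nn.Sigmoid(),
    torch.nn.Linear(embed_dim * 4, max_numel),
)

def mlp_forward(x, rows):
    # Slice each generated row into (weight, bias) and apply the layer.
    w1 = rows[0][: hid_dim * in_dim].view(hid_dim, in_dim)
    b1 = rows[0][hid_dim * in_dim : hid_dim * (in_dim + 1)]
    w2 = rows[1][: out_dim * hid_dim].view(out_dim, hid_dim)
    b2 = rows[1][out_dim * hid_dim : out_dim * (hid_dim + 1)]
    h = torch.nn.functional.leaky_relu(torch.nn.functional.linear(x, w1, b1))
    return torch.nn.functional.linear(h, w2, b2)

x = torch.linspace(-1, 1, 64).unsqueeze(-1)
y = torch.sin(3 * x)  # toy regression target
optimizer = torch.optim.Adam(
    [layer_embed] + list(generator.parameters()), lr=0.01, amsgrad=True
)
best_loss = None
for _iepoch in range(200):
    rows = generator(layer_embed)  # (2, max_numel): one row per layer
    preds = mlp_forward(x, rows)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(preds, y)
    loss.backward()
    optimizer.step()
    # track the best loss, mirroring the "save best" logic in the diff
    if best_loss is None or loss.item() < best_loss:
        best_loss = loss.item()
print("hyper-net : best={:.4f}".format(best_loss))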
exps/LFNA/backup/lfna_models.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
+#####################################################
+import copy
+import torch
+
+from xlayers import super_core
+from xlayers import trunc_normal_
+from models.xcore import get_model
+
+
+class HyperNet(super_core.SuperModule):
+    def __init__(self, shape_container, input_embeding, return_container=True):
+        super(HyperNet, self).__init__()
+        self._shape_container = shape_container
+        self._num_layers = len(shape_container)
+        self._numel_per_layer = []
+        for ilayer in range(self._num_layers):
+            self._numel_per_layer.append(shape_container[ilayer].numel())
+
+        self.register_parameter(
+            "_super_layer_embed",
+            torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)),
+        )
+        trunc_normal_(self._super_layer_embed, std=0.02)
+
+        model_kwargs = dict(
+            input_dim=input_embeding,
+            output_dim=max(self._numel_per_layer),
+            hidden_dim=input_embeding * 4,
+            act_cls="sigmoid",
+            norm_cls="identity",
+        )
+        self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs)
+        self._return_container = return_container
+        print("generator: {:}".format(self._generator))
+
+    def forward_raw(self, input):
+        weights = self._generator(self._super_layer_embed)
+        if self._return_container:
+            weights = torch.split(weights, 1)
+            return self._shape_container.translate(weights)
+        else:
+            return weights
+
+    def forward_candidate(self, input):
+        raise NotImplementedError
+
+    def extra_repr(self) -> str:
+        return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
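For reference, forward_raw generates one flat weight row per layer (the generator's output is sized for the largest layer), splits the rows with torch.split(weights, 1), and asks the shape container to turn each row back into a layer-shaped tensor. The sketch below shows that split-and-reshape step under the assumption that the container simply records one torch.Size per layer and uses a prefix of each over-sized row; the actual xlayers container API may differ.

# Illustrative sketch of the split-and-reshape step in forward_raw, assuming
# the shape container records one torch.Size per layer (an assumption, not
# the real xlayers implementation).
import torch

shapes = [torch.Size([8, 1]), torch.Size([8]), torch.Size([1, 8]), torch.Size([1])]
num_layers = len(shapes)
max_numel = max(s.numel() for s in shapes)

weights = torch.randn(num_layers, max_numel)  # stand-in for self._generator(...)
rows = torch.split(weights, 1)                # num_layers tensors, each (1, max_numel)

def translate(rows, shapes):
    # Take a prefix of each row and reshape it to the layer's shape;
    # trailing entries of over-sized rows are simply unused.
    out = []
    for row, shape in zip(rows, shapes):
        out.append(row.flatten()[: shape.numel()].view(shape))
    return out

for tensor in translate(rows, shapes):
    print(tensor.shape)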