Remove unnecessary model in GMOA
@@ -337,11 +337,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--refine_lr",
         type=float,
-        default=0.002,
+        default=0.001,
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
-        "--refine_epochs", type=int, default=100, help="The final refine #epochs."
+        "--refine_epochs", type=int, default=150, help="The final refine #epochs."
     )
     parser.add_argument(
         "--early_stop_thresh",
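For orientation, here is a minimal sketch of how these two flags are typically consumed during the refine stage; the loop body is an assumption for illustration, not the repository's actual refine code:

    import argparse

    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--refine_lr",
        type=float,
        default=0.001,
        help="The learning rate for the optimizer, during refine",
    )
    parser.add_argument(
        "--refine_epochs", type=int, default=150, help="The final refine #epochs."
    )
    args = parser.parse_args([])  # fall back to the defaults shown above

    # Hypothetical refine loop: a single learnable tensor stands in for the model.
    param = torch.nn.Parameter(torch.zeros(8))
    optimizer = torch.optim.Adam([param], lr=args.refine_lr)
    for epoch in range(args.refine_epochs):
        optimizer.zero_grad()
        loss = (param - 1.0).pow(2).mean()  # placeholder objective
        loss.backward()
        optimizer.step()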
@@ -19,7 +19,6 @@ class MetaModelV1(super_core.SuperModule):
         layer_dim,
         time_dim,
         meta_timestamps,
-        mha_depth: int = 2,
         dropout: float = 0.1,
         seq_length: int = 10,
         interval: float = None,
@@ -69,22 +68,6 @@ class MetaModelV1(super_core.SuperModule):
             attn_drop=None,
             proj_drop=dropout,
         )
-        layers = []
-        for ilayer in range(mha_depth):
-            layers.append(
-                super_core.SuperTransformerEncoderLayer(
-                    time_dim * 2,
-                    4,
-                    True,
-                    4,
-                    dropout,
-                    norm_affine=False,
-                    order=super_core.LayerOrder.PostNorm,
-                    use_mask=True,
-                )
-            )
-        layers.append(super_core.SuperLinear(time_dim * 2, time_dim))
-        self._meta_corrector = super_core.SuperSequential(*layers)

         model_kwargs = dict(
             config=dict(model_type="dual_norm_mlp"),
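The deleted _meta_corrector was a small post-norm transformer stack over the concatenated time embeddings. super_core.SuperTransformerEncoderLayer and SuperLinear are project-specific, so the rough plain-PyTorch equivalent below uses standard nn modules as stand-ins (the 4x MLP width and the example dimensions are assumptions):

    import torch.nn as nn

    time_dim, mha_depth, dropout = 32, 2, 0.1  # example values, not from the repo

    # Approximation of the removed _meta_corrector: mha_depth post-norm encoder
    # layers over 2*time_dim features, then a projection back down to time_dim.
    layers = []
    for _ in range(mha_depth):
        layers.append(
            nn.TransformerEncoderLayer(
                d_model=time_dim * 2,
                nhead=4,
                dim_feedforward=time_dim * 2 * 4,
                dropout=dropout,
                batch_first=True,
            )
        )
    layers.append(nn.Linear(time_dim * 2, time_dim))
    meta_corrector = nn.Sequential(*layers)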
@@ -103,13 +86,12 @@ class MetaModelV1(super_core.SuperModule):
             std=0.02,
         )

-    def get_parameters(self, time_embed, meta_corrector, generator):
+    def get_parameters(self, time_embed, attention, generator):
         parameters = []
         if time_embed:
             parameters.append(self._super_meta_embed)
-        if meta_corrector:
+        if attention:
             parameters.extend(list(self._trans_att.parameters()))
-            parameters.extend(list(self._meta_corrector.parameters()))
         if generator:
             parameters.append(self._super_layer_embed)
             parameters.extend(list(self._generator.parameters()))
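A hedged sketch of how the renamed attention flag might be used by a caller; TinyMeta is an invented stand-in so the snippet runs on its own, and only the selection pattern mirrors the diff:

    import torch
    import torch.nn as nn


    class TinyMeta(nn.Module):
        """Stand-in mirroring MetaModelV1.get_parameters after this commit."""

        def __init__(self):
            super().__init__()
            self._super_meta_embed = nn.Parameter(torch.zeros(4, 8))
            self._trans_att = nn.MultiheadAttention(8, 2)
            self._super_layer_embed = nn.Parameter(torch.zeros(3, 8))
            self._generator = nn.Linear(8, 8)

        def get_parameters(self, time_embed, attention, generator):
            parameters = []
            if time_embed:
                parameters.append(self._super_meta_embed)
            if attention:  # the _meta_corrector branch no longer exists
                parameters.extend(list(self._trans_att.parameters()))
            if generator:
                parameters.append(self._super_layer_embed)
                parameters.extend(list(self._generator.parameters()))
            return parameters


    model = TinyMeta()
    # Refine only the time embedding and the attention block:
    optimizer = torch.optim.Adam(model.get_parameters(True, True, False), lr=0.001)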
@@ -199,13 +181,7 @@ class MetaModelV1(super_core.SuperModule):
             timestamp_v_embed,
             mask,
         )
-        relative_timestamps = timestamps - timestamps[:, :1]
-        relative_pos_embeds = self._tscalar_embed(relative_timestamps)
-        init_timestamp_embeds = torch.cat(
-            (timestamp_embeds, relative_pos_embeds), dim=-1
-        )
-        corrected_embeds = self._meta_corrector(init_timestamp_embeds)
-        return corrected_embeds
+        return timestamp_embeds

     def forward_raw(self, timestamps, time_embeds, get_seq_last):
         if time_embeds is None:
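The embedding path now ends at the attention read-out. Below is a schematic before/after of the data flow, with simple stand-ins for the project modules so the sketch runs (tensor shapes and the Linear stand-in are assumptions):

    import torch

    batch, seq, time_dim = 2, 10, 32
    timestamp_embeds = torch.randn(batch, seq, time_dim)  # attention output

    # After this commit: the read-out is returned as-is.
    new_out = timestamp_embeds

    # Before this commit (removed path): concatenate relative-position
    # embeddings and pass through the _meta_corrector stack. A Linear layer
    # stands in for _tscalar_embed + _meta_corrector here.
    timestamps = torch.rand(batch, seq)
    relative_timestamps = timestamps - timestamps[:, :1]
    relative_pos_embeds = relative_timestamps.unsqueeze(-1).expand(-1, -1, time_dim)
    init_timestamp_embeds = torch.cat((timestamp_embeds, relative_pos_embeds), dim=-1)
    corrector = torch.nn.Linear(time_dim * 2, time_dim)
    old_out = corrector(init_timestamp_embeds)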
@@ -264,9 +240,8 @@ class MetaModelV1(super_core.SuperModule):
         x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
         with torch.set_grad_enabled(True):
             new_param = self.create_meta_embed()
-            optimizer = torch.optim.Adam(
-                [new_param], lr=lr, weight_decay=1e-5, amsgrad=True
-            )
+            optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)
             timestamp = torch.Tensor([timestamp]).to(new_param.device)
             self.replace_append_learnt(timestamp, new_param)
             self.train()
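This hunk only reflows the Adam construction onto a single line. For context, here is a self-contained sketch of the surrounding adaptation pattern, with assumed stand-ins for everything except the optimizer call; it also illustrates the best-loss tracking that the final hunk now returns:

    import torch

    lr = 0.001
    # Stand-in for create_meta_embed(): one fresh learnable time embedding.
    new_param = torch.nn.Parameter(torch.zeros(32))

    optimizer = torch.optim.Adam([new_param], lr=lr, weight_decay=1e-5, amsgrad=True)

    best_loss, best_new_param = float("inf"), None
    for step in range(100):
        optimizer.zero_grad()
        meta_loss = (new_param - 1.0).pow(2).mean()  # placeholder meta-objective
        meta_loss.backward()
        optimizer.step()
        if meta_loss.item() < best_loss:
            best_loss = meta_loss.item()
            best_new_param = new_param.detach().clone()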
@@ -297,7 +272,7 @@ class MetaModelV1(super_core.SuperModule):
         with torch.no_grad():
             self.replace_append_learnt(None, None)
             self.append_fixed(timestamp, best_new_param)
-        return True, meta_loss.item()
+        return True, best_loss

     def extra_repr(self) -> str:
         return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(