diff --git a/exps/GeMOSA/meta_model_ablation.py b/exps/GeMOSA/meta_model_ablation.py index c904b75..f2e856a 100644 --- a/exps/GeMOSA/meta_model_ablation.py +++ b/exps/GeMOSA/meta_model_ablation.py @@ -237,7 +237,7 @@ class MetaModel_TraditionalAtt(super_core.SuperModule): for iepoch in range(epochs): optimizer.zero_grad() time_embed = self.gen_time_embed(timestamp.view(1)) - match_loss = criterion(new_param, time_embed) + match_loss = F.l1_loss(new_param, time_embed) [container] = self.gen_model(new_param.view(1, -1)) y_hat = base_model.forward_with_container(x, container)