diff --git a/exps/GeMOSA/lfna_meta_model.py b/exps/GeMOSA/lfna_meta_model.py index 9c69c83..c36e88b 100644 --- a/exps/GeMOSA/lfna_meta_model.py +++ b/exps/GeMOSA/lfna_meta_model.py @@ -46,9 +46,9 @@ class MetaModelV1(super_core.SuperModule): ) self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) # register a time difference buffer - time_interval = [-i * self._interval for i in range(self._seq_length)] - time_interval.reverse() - self.register_buffer("_time_interval", torch.Tensor(time_interval)) + # time_interval = [-i * self._interval for i in range(self._seq_length)] + # time_interval.reverse() + # self.register_buffer("_time_interval", torch.Tensor(time_interval)) self._time_embed_dim = time_dim self._append_meta_embed = dict(fixed=None, learnt=None) self._append_meta_timestamps = dict(fixed=None, learnt=None) @@ -161,7 +161,8 @@ class MetaModelV1(super_core.SuperModule): def _obtain_time_embed(self, timestamps): # timestamps is a batch of sequence of timestamps - batch, seq = timestamps.shape + # batch, seq = timestamps.shape + timestamps = timestamps.view(-1, 1) meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed timestamp_v_embed = meta_embeds.unsqueeze(dim=0) timestamp_qk_att_embed = self._tscalar_embed( @@ -185,9 +186,9 @@ class MetaModelV1(super_core.SuperModule): def forward_raw(self, timestamps, time_embeds, tembed_only=False): if time_embeds is None: - time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) - B, S = time_seq.shape - time_embeds = self._obtain_time_embed(time_seq) + # time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) + [B] = timestamps.shape + time_embeds = self._obtain_time_embed(timestamps) else: # use the hyper-net only time_seq = None B, _ = time_embeds.shape