Fix bugs
This commit is contained in:
		| @@ -46,9 +46,9 @@ class MetaModelV1(super_core.SuperModule): | ||||
|         ) | ||||
|         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) | ||||
|         # register a time difference buffer | ||||
|         time_interval = [-i * self._interval for i in range(self._seq_length)] | ||||
|         time_interval.reverse() | ||||
|         self.register_buffer("_time_interval", torch.Tensor(time_interval)) | ||||
|         # time_interval = [-i * self._interval for i in range(self._seq_length)] | ||||
|         # time_interval.reverse() | ||||
|         # self.register_buffer("_time_interval", torch.Tensor(time_interval)) | ||||
|         self._time_embed_dim = time_dim | ||||
|         self._append_meta_embed = dict(fixed=None, learnt=None) | ||||
|         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||
| @@ -161,7 +161,8 @@ class MetaModelV1(super_core.SuperModule): | ||||
|  | ||||
|     def _obtain_time_embed(self, timestamps): | ||||
|         # timestamps is a batch of sequence of timestamps | ||||
|         batch, seq = timestamps.shape | ||||
|         # batch, seq = timestamps.shape | ||||
|         timestamps = timestamps.view(-1, 1) | ||||
|         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||
|         timestamp_qk_att_embed = self._tscalar_embed( | ||||
| @@ -185,9 +186,9 @@ class MetaModelV1(super_core.SuperModule): | ||||
|  | ||||
|     def forward_raw(self, timestamps, time_embeds, tembed_only=False): | ||||
|         if time_embeds is None: | ||||
|             time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) | ||||
|             B, S = time_seq.shape | ||||
|             time_embeds = self._obtain_time_embed(time_seq) | ||||
|             # time_seq = timestamps.view(-1, 1) + self._time_interval.view(1, -1) | ||||
|             [B] = timestamps.shape | ||||
|             time_embeds = self._obtain_time_embed(timestamps) | ||||
|         else:  # use the hyper-net only | ||||
|             time_seq = None | ||||
|             B, _ = time_embeds.shape | ||||
|   | ||||
		Reference in New Issue
	
	Block a user