Update LFNA
@@ -93,6 +93,67 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
     return loss_meter
 
 
+def pretrain(base_model, meta_model, criterion, xenv, args, logger):
+    optimizer = torch.optim.Adam(
+        meta_model.parameters(),
+        lr=args.lr,
+        weight_decay=args.weight_decay,
+        amsgrad=True,
+    )
+
+    meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain")
+    for iepoch in range(args.epochs):
+        total_meta_losses, total_match_losses = [], []
+        for ibatch in range(args.meta_batch):
+            rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
+            timestamps = meta_model.meta_timestamps[
+                rand_index : rand_index + xenv.seq_length
+            ]
+
+            seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
+            [seq_containers], time_embeds = meta_model(
+                torch.unsqueeze(timestamps, dim=0)
+            )
+            # performance loss
+            losses = []
+            seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
+                args.device
+            )
+            for container, inputs, targets in zip(
+                seq_containers, seq_inputs, seq_targets
+            ):
+                predictions = base_model.forward_with_container(inputs, container)
+                loss = criterion(predictions, targets)
+                losses.append(loss)
+            meta_loss = torch.stack(losses).mean()
+            match_loss = criterion(
+                torch.squeeze(time_embeds, dim=0),
+                meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
+            )
+            # batch_loss = meta_loss + match_loss * 0.1
+            # total_losses.append(batch_loss)
+            total_meta_losses.append(meta_loss)
+            total_match_losses.append(match_loss)
+        final_meta_loss = torch.stack(total_meta_losses).mean()
+        final_match_loss = torch.stack(total_match_losses).mean()
+        total_loss = final_meta_loss + final_match_loss
+        optimizer.zero_grad()
+        total_loss.backward()
+        optimizer.step()
+        # save a checkpoint when the score (the negated total loss) improves on the best so far
+        success, best_score = meta_model.save_best(-total_loss.item())
+        logger.log(
+            "{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format(
+                time_string(),
+                iepoch,
+                args.epochs,
+                total_loss.item(),
+                final_meta_loss.item(),
+                final_match_loss.item(),
+            )
+            + ", batch={:}".format(len(total_meta_losses))
+        )
+
+
 def main(args):
     logger, env_info, model_kwargs = lfna_setup(args)
     train_env = get_synthetic_env(mode="train", version=args.env_version)
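For orientation, the new pretrain stage optimizes two terms per sampled window: a performance loss of the base model under the generated weight containers, and a loss matching the predicted time embeddings to the stored super_meta_embed rows of the same window. Below is a minimal, self-contained sketch of that objective; the dummy tensors, the MSE criterion, and the sizes are illustrative assumptions, not values from this commit.

import torch

criterion = torch.nn.MSELoss()
seq_length, dim = 5, 16
# performance term: one loss per timestamp of the sampled window
losses = [criterion(torch.randn(8, 1), torch.randn(8, 1)) for _ in range(seq_length)]
meta_loss = torch.stack(losses).mean()
# matching term: predicted time embeddings regressed onto the stored rows
pred_time_embeds = torch.randn(seq_length, dim, requires_grad=True)
stored_meta_embeds = torch.randn(seq_length, dim)  # stands in for super_meta_embed[window]
match_loss = criterion(pred_time_embeds, stored_meta_embeds)
# equal weighting of the two terms, as above (a 0.1-weighted variant is left commented out there)
total_loss = meta_loss + match_loss
total_loss.backward()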
@@ -148,6 +209,8 @@ def main(args):
     logger.log("The scheduler is\n{:}".format(lr_scheduler))
     logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
 
+    pretrain(base_model, meta_model, criterion, train_env, args, logger)
+
     if logger.path("model").exists():
         ckp_data = torch.load(logger.path("model"))
         base_model.load_state_dict(ckp_data["base_model"])
@@ -345,7 +408,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--lr",
         type=float,
-        default=0.005,
+        default=0.002,
         help="The initial learning rate for the optimizer (default is Adam)",
     )
     parser.add_argument(
@@ -63,7 +63,7 @@ class LFNA_Meta(super_core.SuperModule):
         for ilayer in range(mha_depth):
             layers.append(
                 super_core.SuperTransformerEncoderLayer(
-                    time_embedding,
+                    time_embedding * 2,
                     4,
                     True,
                     4,
@@ -72,7 +72,7 @@ class LFNA_Meta(super_core.SuperModule):
                     order=super_core.LayerOrder.PostNorm,
                 )
             )
-        layers.append(super_core.SuperLinear(time_embedding, time_embedding))
+        layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
         self.meta_corrector = super_core.SuperSequential(*layers)
 
         model_kwargs = dict(
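The width change above follows from the _obtain_time_embed hunk further down: the per-timestamp query embedding is concatenated with the cross-attention output before entering the corrector, so the corrector's input dimension doubles. A small sketch of the shape arithmetic; the embedding size, batch, and sequence length below are assumed example values.

import torch

dim, batch, seq = 16, 2, 5                        # dim stands in for time_embedding
timestamp_q_embed = torch.randn(batch, seq, dim)  # query embedding of the input timestamps
timestamp_embeds = torch.randn(batch, seq, dim)   # attention output over the stored meta timestamps
init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
print(init_timestamp_embeds.shape)                # torch.Size([2, 5, 32]), i.e. 2 * time_embedding features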
@@ -95,6 +95,7 @@ class LFNA_Meta(super_core.SuperModule):
 
     @property
     def meta_timestamps(self):
+        with torch.no_grad():
             meta_timestamps = [self._meta_timestamps]
             for key in ("fixed", "learnt"):
                 if self._append_meta_timestamps[key] is not None:
@@ -125,6 +126,10 @@ class LFNA_Meta(super_core.SuperModule):
         self._append_meta_timestamps["learnt"] = timestamp
         self._append_meta_embed["learnt"] = meta_embed
 
+    @property
+    def meta_length(self):
+        return self.meta_timestamps.numel()
+
     def append_fixed(self, timestamp, meta_embed):
         with torch.no_grad():
             device = self._super_meta_embed.device
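The new meta_length property counts every stored timestamp (the base buffer plus any appended "fixed"/"learnt" ones) and is what bounds the random window start in the pretrain loop above. A hedged sketch of that precondition; the concrete sizes are assumptions.

import random

meta_length, seq_length = 100, 10       # stand-ins for meta_model.meta_length and xenv.seq_length
assert meta_length >= seq_length + 1    # required so randint's upper bound is non-negative
rand_index = random.randint(0, meta_length - seq_length - 1)
window = range(rand_index, rand_index + seq_length)  # indices of one training window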
@@ -152,15 +157,18 @@ class LFNA_Meta(super_core.SuperModule):
         timestamp_embeds = self._trans_att(
             timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
         )
-        corrected_embeds = self.meta_corrector(timestamp_embeds)
+        # relative_timestamps = timestamps - timestamps[:, :1]
+        # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
+        init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
+        corrected_embeds = self.meta_corrector(init_timestamp_embeds)
         return corrected_embeds
 
     def forward_raw(self, timestamps):
         batch, seq = timestamps.shape
-        meta_embed = self._obtain_time_embed(timestamps)
+        time_embed = self._obtain_time_embed(timestamps)
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
-        meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
+        meta_embed = time_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
         layer_embed = self._super_layer_embed.view(1, 1, num_layer, -1).expand(
             batch, seq, -1, -1
         )
@@ -173,7 +181,7 @@ class LFNA_Meta(super_core.SuperModule):
                 weights = torch.split(weights.squeeze(0), 1)
                 seq_containers.append(self._shape_container.translate(weights))
             batch_containers.append(seq_containers)
-        return batch_containers
+        return batch_containers, time_embed
 
     def forward_candidate(self, input):
         raise NotImplementedError
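Because forward_raw now returns the time embeddings alongside the weight containers, callers unpack two values, as the pretrain loop above does. A self-contained toy illustration of the call convention; the stand-in module below is not the real LFNA_Meta, and its shapes are assumptions.

import torch

class _ToyMeta(torch.nn.Module):
    # stand-in for LFNA_Meta's forward: per-timestamp containers plus time embeddings
    def forward(self, timestamps):  # timestamps: (batch, seq)
        batch, seq = timestamps.shape
        time_embed = torch.randn(batch, seq, 16)
        containers = [[dict() for _ in range(seq)] for _ in range(batch)]
        return containers, time_embed

timestamps = torch.zeros(1, 5)
[seq_containers], time_embeds = _ToyMeta()(timestamps)  # same unpacking as in the pretrain loop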
@@ -68,6 +68,10 @@ class SyntheticDEnv(data.Dataset):
         self._oracle_map = None
         self._seq_length = None
 
+    @property
+    def seq_length(self):
+        return self._seq_length
+
     @property
     def min_timestamp(self):
         return self._timestamp_generator.min_timestamp
@@ -125,6 +129,14 @@ class SyntheticDEnv(data.Dataset):
                 timestamp + i * self.timestamp_interval + noise
                 for i in range(self._seq_length)
             ]
+            # xdata = [self.__call__(timestamp) for timestamp in timestamps]
+            # return zip_sequence(xdata)
+            return self.seq_call(timestamps)
+
+    def seq_call(self, timestamps):
+        with torch.no_grad():
+            if isinstance(timestamps, torch.Tensor):
+                timestamps = timestamps.cpu().tolist()
             xdata = [self.__call__(timestamp) for timestamp in timestamps]
             return zip_sequence(xdata)
 
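A hedged usage sketch for the new seq_call: it accepts either a plain list of timestamps or a torch.Tensor (for example a slice of meta_model.meta_timestamps), converting tensors to a CPU list before querying the environment point by point. The helper name below is hypothetical; the unpacking mirrors the pretrain loop above.

def query_window(xenv, timestamps):
    # timestamps: a list of floats or a torch.Tensor of timestamps
    seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
    return seq_timestamps, seq_inputs, seq_targets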