Update LFNA
@@ -93,6 +93,67 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
     return loss_meter
 
 
+def pretrain(base_model, meta_model, criterion, xenv, args, logger):
+    optimizer = torch.optim.Adam(
+        meta_model.parameters(),
+        lr=args.lr,
+        weight_decay=args.weight_decay,
+        amsgrad=True,
+    )
+
+    meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain")
+    for iepoch in range(args.epochs):
+        total_meta_losses, total_match_losses = [], []
+        for ibatch in range(args.meta_batch):
+            rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
+            timestamps = meta_model.meta_timestamps[
+                rand_index : rand_index + xenv.seq_length
+            ]
+
+            seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
+            [seq_containers], time_embeds = meta_model(
+                torch.unsqueeze(timestamps, dim=0)
+            )
+            # performance loss
+            losses = []
+            seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
+                args.device
+            )
+            for container, inputs, targets in zip(
+                seq_containers, seq_inputs, seq_targets
+            ):
+                predictions = base_model.forward_with_container(inputs, container)
+                loss = criterion(predictions, targets)
+                losses.append(loss)
+            meta_loss = torch.stack(losses).mean()
+            match_loss = criterion(
+                torch.squeeze(time_embeds, dim=0),
+                meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
+            )
+            # batch_loss = meta_loss + match_loss * 0.1
+            # total_losses.append(batch_loss)
+            total_meta_losses.append(meta_loss)
+            total_match_losses.append(match_loss)
+        final_meta_loss = torch.stack(total_meta_losses).mean()
+        final_match_loss = torch.stack(total_match_losses).mean()
+        total_loss = final_meta_loss + final_match_loss
+        optimizer.zero_grad()  # clear gradients from the previous epoch before this update
+        total_loss.backward()
+        optimizer.step()
+        # success
+        success, best_score = meta_model.save_best(-total_loss.item())
+        logger.log(
+            "{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format(
+                time_string(),
+                iepoch,
+                args.epochs,
+                total_loss.item(),
+                final_meta_loss.item(),
+                final_match_loss.item(),
+            )
+            + ", batch={:}".format(len(total_meta_losses))
+        )
+
+
 def main(args):
     logger, env_info, model_kwargs = lfna_setup(args)
     train_env = get_synthetic_env(mode="train", version=args.env_version)
@@ -148,6 +209,8 @@ def main(args):
     logger.log("The scheduler is\n{:}".format(lr_scheduler))
     logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
 
+    pretrain(base_model, meta_model, criterion, train_env, args, logger)
+
     if logger.path("model").exists():
         ckp_data = torch.load(logger.path("model"))
         base_model.load_state_dict(ckp_data["base_model"])
@@ -345,7 +408,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--lr",
         type=float,
-        default=0.005,
+        default=0.002,
         help="The initial learning rate for the optimizer (default is Adam)",
     )
     parser.add_argument(

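As a side note on the sampling in pretrain above: each meta-batch draws a random contiguous window of xenv.seq_length timestamps from the meta model's timeline. A minimal standalone sketch of that indexing (the tensor and lengths here are illustrative stand-ins, not repo code):

import random

import torch

meta_timestamps = torch.arange(100, dtype=torch.float32)  # stand-in for meta_model.meta_timestamps
seq_length = 10  # stand-in for xenv.seq_length

# random.randint is inclusive on both ends, so the slice below always stays in range
rand_index = random.randint(0, meta_timestamps.numel() - seq_length - 1)
window = meta_timestamps[rand_index : rand_index + seq_length]
assert window.numel() == seq_length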
@@ -63,7 +63,7 @@ class LFNA_Meta(super_core.SuperModule):
         for ilayer in range(mha_depth):
             layers.append(
                 super_core.SuperTransformerEncoderLayer(
-                    time_embedding,
+                    time_embedding * 2,
                     4,
                     True,
                     4,
@@ -72,7 +72,7 @@ class LFNA_Meta(super_core.SuperModule):
                     order=super_core.LayerOrder.PostNorm,
                 )
             )
-        layers.append(super_core.SuperLinear(time_embedding, time_embedding))
+        layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
         self.meta_corrector = super_core.SuperSequential(*layers)
 
         model_kwargs = dict(
@@ -95,6 +95,7 @@ class LFNA_Meta(super_core.SuperModule):
 
     @property
     def meta_timestamps(self):
+        with torch.no_grad():
             meta_timestamps = [self._meta_timestamps]
             for key in ("fixed", "learnt"):
                 if self._append_meta_timestamps[key] is not None:
@@ -125,6 +126,10 @@ class LFNA_Meta(super_core.SuperModule):
         self._append_meta_timestamps["learnt"] = timestamp
         self._append_meta_embed["learnt"] = meta_embed
 
+    @property
+    def meta_length(self):
+        return self.meta_timestamps.numel()
+
     def append_fixed(self, timestamp, meta_embed):
         with torch.no_grad():
             device = self._super_meta_embed.device
@@ -152,15 +157,18 @@ class LFNA_Meta(super_core.SuperModule):
         timestamp_embeds = self._trans_att(
             timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
         )
-        corrected_embeds = self.meta_corrector(timestamp_embeds)
+        # relative_timestamps = timestamps - timestamps[:, :1]
+        # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
+        init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
+        corrected_embeds = self.meta_corrector(init_timestamp_embeds)
         return corrected_embeds
 
     def forward_raw(self, timestamps):
         batch, seq = timestamps.shape
-        meta_embed = self._obtain_time_embed(timestamps)
+        time_embed = self._obtain_time_embed(timestamps)
         # create joint embed
         num_layer, _ = self._super_layer_embed.shape
-        meta_embed = meta_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
+        meta_embed = time_embed.view(batch, seq, 1, -1).expand(-1, -1, num_layer, -1)
         layer_embed = self._super_layer_embed.view(1, 1, num_layer, -1).expand(
             batch, seq, -1, -1
         )
@@ -173,7 +181,7 @@ class LFNA_Meta(super_core.SuperModule):
                 weights = torch.split(weights.squeeze(0), 1)
                 seq_containers.append(self._shape_container.translate(weights))
             batch_containers.append(seq_containers)
-        return batch_containers
+        return batch_containers, time_embed
 
     def forward_candidate(self, input):
         raise NotImplementedError

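The time_embedding * 2 widths above follow from _obtain_time_embed concatenating the query embedding with the attention output along the last dimension before feeding the meta_corrector, whose final SuperLinear projects back down to time_embedding. A rough shape check with illustrative dimensions (plain torch tensors, not the super_core modules):

import torch

batch, seq, time_embedding = 2, 5, 32
timestamp_q_embed = torch.randn(batch, seq, time_embedding)
timestamp_embeds = torch.randn(batch, seq, time_embedding)  # attention output

init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
assert init_timestamp_embeds.shape == (batch, seq, time_embedding * 2)  # matches the widened encoder/linear input

Note also that forward_raw now returns the pair (batch_containers, time_embed), so call sites unpack two values; with a batch of one sequence, pretrain above does [seq_containers], time_embeds = meta_model(...).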
@@ -68,6 +68,10 @@ class SyntheticDEnv(data.Dataset):
         self._oracle_map = None
         self._seq_length = None
 
+    @property
+    def seq_length(self):
+        return self._seq_length
+
     @property
     def min_timestamp(self):
         return self._timestamp_generator.min_timestamp
@@ -125,6 +129,14 @@ class SyntheticDEnv(data.Dataset):
                 timestamp + i * self.timestamp_interval + noise
                 for i in range(self._seq_length)
             ]
+            # xdata = [self.__call__(timestamp) for timestamp in timestamps]
+            # return zip_sequence(xdata)
+            return self.seq_call(timestamps)
+
+    def seq_call(self, timestamps):
+        with torch.no_grad():
+            if isinstance(timestamps, torch.Tensor):
+                timestamps = timestamps.cpu().tolist()
+            xdata = [self.__call__(timestamp) for timestamp in timestamps]
+            return zip_sequence(xdata)

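seq_call accepts either a tensor or a plain list of timestamps; the isinstance guard at its top moves tensors to the CPU and converts them to Python floats before the per-timestamp __call__ loop. A small runnable sketch of just that conversion (the helper name is hypothetical, not from the repo):

import torch

def to_timestamp_list(timestamps):
    # mirrors the guard in seq_call: tensors become a list of CPU python floats
    if isinstance(timestamps, torch.Tensor):
        timestamps = timestamps.cpu().tolist()
    return timestamps

print(to_timestamp_list(torch.linspace(0.0, 1.0, steps=4)))  # [0.0, 0.3333..., 0.6666..., 1.0]

At the call site in pretrain, the return value unpacks as seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps).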