Update LFNA
This commit is contained in:
		| @@ -93,6 +93,38 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): | ||||
|     return loss_meter | ||||
|  | ||||
|  | ||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger):
    """Evaluate the meta-generated weights on every timestamp yielded by ``env``.

    For each timestamp produced by iterating ``env``, a window of
    ``args.seq_length * 2`` evenly spaced timestamps ending at that timestamp
    is fed to ``meta_model``.  The container predicted for the second-to-last
    step of the window is used by ``base_model`` on the data that ``env``
    returns for that same step, and the resulting loss is logged.

    Args:
        env: iterable of ``(timestamp, data)`` pairs that is also callable
            with a scalar time and then returns ``(timestamp, (x, y))``.
            Must expose ``timestamp_interval`` and support ``len()``.
        meta_model: called as ``meta_model(time_tensor, None)``; expected to
            return ``([seq_containers], _)``.
        base_model: must provide ``forward_with_container(x, container)``.
        criterion: loss callable ``criterion(prediction, target)``.
        args: namespace providing ``seq_length`` and ``device``.
        logger: object with a ``log(str)`` method.

    Returns:
        None.  Results are reported through ``logger`` only.
    """
    logger.log("Online evaluate: {:}".format(env))
    # Switching to eval mode does not depend on the loop variable, so do it once.
    meta_model.eval()
    base_model.eval()
    total = len(env)
    # The data yielded by the iterator is not used: the sample that is scored
    # is re-queried from ``env`` at the second-to-last step of the window.
    for idx, (timestamp, _) in enumerate(env):
        future_time = timestamp.item()
        # Oldest timestamp first, ``future_time`` last.
        time_seqs = [
            future_time - iseq * env.timestamp_interval
            for iseq in reversed(range(args.seq_length * 2))
        ]
        with torch.no_grad():
            time_tensor = torch.Tensor(time_seqs).view(1, -1).to(args.device)
            [seq_containers], _ = meta_model(time_tensor, None)
            future_container = seq_containers[-2]
            # Query with the tensor value (not the Python float) so the time
            # passed to ``env`` is exactly the one the meta-model saw.
            _, (future_x, future_y) = env(time_tensor[0, -2].item())
            future_x = future_x.to(args.device)
            future_y = future_y.to(args.device)
            future_y_hat = base_model.forward_with_container(
                future_x, future_container
            )
            future_loss = criterion(future_y_hat, future_y)
            logger.log(
                "[ONLINE] [{:03d}/{:03d}] loss={:.4f}".format(
                    idx, total, future_loss.item()
                )
            )
|  | ||||
|  | ||||
| def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||
|     base_model.train() | ||||
|     meta_model.train() | ||||
| @@ -176,7 +208,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||
|             ) | ||||
|             + ", batch={:}".format(len(total_meta_v1_losses)) | ||||
|             + ", success={:}, best={:.4f}".format(success, -best_score) | ||||
|             + ", LS={:}/{:}".format(last_success_epoch - iepoch, early_stop_thresh) | ||||
|             + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh) | ||||
|             + ", {:}".format(left_time) | ||||
|         ) | ||||
|         if success: | ||||
| @@ -194,77 +226,6 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||
|     logger.log("Save the best model into {:}".format(final_best_name)) | ||||
|  | ||||
|  | ||||
def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
    """Pre-train ``meta_model`` on randomly sampled timestamp windows.

    Every epoch draws ``args.meta_batch`` random windows of
    ``xenv.seq_length`` consecutive meta timestamps, generates one weight
    container per timestamp from the matching slice of
    ``meta_model.super_meta_embed``, evaluates ``base_model`` with each
    container on the data returned by ``xenv.seq_call`` and takes a single
    Adam step on the mean loss.  The best checkpoint (lowest loss) is tracked
    through ``meta_model.save_best`` and restored at the end.

    Args:
        base_model: must provide ``forward_with_container(inputs, container)``.
        meta_model: the model being optimised; all of its ``parameters()``
            are handed to the optimizer.
        criterion: loss callable ``criterion(predictions, targets)``.
        xenv: environment exposing ``seq_length`` and ``seq_call(timestamps)``.
        args: namespace providing ``lr``, ``weight_decay``, ``rand_seed``,
            ``epochs``, ``meta_batch``, ``device`` and
            ``pretrain_early_stop_thresh``.
        logger: object with ``log(str)`` and ``path(None)``.

    Returns:
        None.  ``meta_model`` is left holding its best saved state.
    """
    base_model.train()
    meta_model.train()
    optimizer = torch.optim.Adam(
        meta_model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        amsgrad=True,
    )
    logger.log("Pre-train the meta-model's embeddings")
    logger.log("Using the optimizer: {:}".format(optimizer))

    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1")
    meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
    per_epoch_time, start_time = AverageMeter(), time.time()
    last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
    for iepoch in range(args.epochs):
        left_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        losses = []
        optimizer.zero_grad()
        for ibatch in range(args.meta_batch):
            # Pick a random window of ``seq_length`` consecutive timestamps.
            rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
            timestamps = meta_model.meta_timestamps[
                rand_index : rand_index + xenv.seq_length
            ]

            seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
            time_embeds = meta_model.super_meta_embed[
                rand_index : rand_index + xenv.seq_length
            ]
            # Feed the embeddings directly (timestamps=None), as a batch of one.
            [seq_containers], time_embeds = meta_model(
                None, torch.unsqueeze(time_embeds, dim=0)
            )
            seq_inputs, seq_targets = seq_inputs.to(args.device), seq_targets.to(
                args.device
            )
            for container, inputs, targets in zip(
                seq_containers, seq_inputs, seq_targets
            ):
                predictions = base_model.forward_with_container(inputs, container)
                loss = criterion(predictions, targets)
                losses.append(loss)
        final_loss = torch.stack(losses).mean()
        final_loss.backward()
        optimizer.step()
        # The loss is negated when saved and negated again when printed, so
        # ``best`` in the log line below is shown as a loss value.
        success, best_score = meta_model.save_best(-final_loss.item())
        logger.log(
            "{:} [Pre-V1 {:04d}/{:}] loss : {:.5f}".format(
                time_string(),
                iepoch,
                args.epochs,
                final_loss.item(),
            )
            + ", batch={:}".format(len(losses))
            + ", success={:}, best={:.4f}".format(success, -best_score)
            # Epochs since the last improvement; same orientation as the
            # early-stop test below, so the value counts up to the threshold.
            + ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
            + " {:}".format(left_time)
        )
        if success:
            last_success_epoch = iepoch
        if iepoch - last_success_epoch >= early_stop_thresh:
            logger.log("Early stop the pre-training at {:}".format(iepoch))
            break
        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()
    meta_model.load_best()
|  | ||||
|  | ||||
| def main(args): | ||||
|     logger, env_info, model_kwargs = lfna_setup(args) | ||||
|     train_env = get_synthetic_env(mode="train", version=args.env_version) | ||||
| @@ -290,7 +251,7 @@ def main(args): | ||||
|  | ||||
|     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) | ||||
|     train_env.reset_max_seq_length(args.seq_length) | ||||
|     valid_env.reset_max_seq_length(args.seq_length) | ||||
|     # valid_env.reset_max_seq_length(args.seq_length) | ||||
|     valid_env_loader = torch.utils.data.DataLoader( | ||||
|         valid_env, | ||||
|         batch_size=args.meta_batch, | ||||
| @@ -306,6 +267,11 @@ def main(args): | ||||
|     ) | ||||
|     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) | ||||
|  | ||||
|     # try to evaluate once | ||||
|     online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) | ||||
|     import pdb | ||||
|  | ||||
|     pdb.set_trace() | ||||
|     optimizer = torch.optim.Adam( | ||||
|         meta_model.get_parameters(True, True, False),  # fix hypernet | ||||
|         lr=args.lr, | ||||
| @@ -558,7 +524,7 @@ if __name__ == "__main__": | ||||
|     parser.add_argument( | ||||
|         "--pretrain_early_stop_thresh", | ||||
|         type=int, | ||||
|         default=200, | ||||
|         default=300, | ||||
|         help="The #epochs for early stop.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|   | ||||
| @@ -22,6 +22,7 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|         meta_timestamps, | ||||
|         mha_depth: int = 2, | ||||
|         dropout: float = 0.1, | ||||
|         thresh: float = 0.05, | ||||
|     ): | ||||
|         super(LFNA_Meta, self).__init__() | ||||
|         self._shape_container = shape_container | ||||
| @@ -30,6 +31,7 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|         for ilayer in range(self._num_layers): | ||||
|             self._numel_per_layer.append(shape_container[ilayer].numel()) | ||||
|         self._raw_meta_timestamps = meta_timestamps | ||||
|         self._thresh = thresh | ||||
|  | ||||
|         self.register_parameter( | ||||
|             "_super_layer_embed", | ||||
| @@ -168,7 +170,14 @@ class LFNA_Meta(super_core.SuperModule): | ||||
|         timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) | ||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||
|         # create the mask | ||||
|         mask = torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) | ||||
|         mask = ( | ||||
|             torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1) | ||||
|         ) | ( | ||||
|             torch.abs( | ||||
|                 torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1) | ||||
|             ) | ||||
|             > self._thresh | ||||
|         ) | ||||
|         timestamp_embeds = self._trans_att( | ||||
|             timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask | ||||
|         ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user