Update LFNA -- refine
This commit is contained in:
		| @@ -93,8 +93,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger): | |||||||
|     return loss_meter |     return loss_meter | ||||||
|  |  | ||||||
|  |  | ||||||
| def online_evaluate(env, meta_model, base_model, criterion, args, logger): | def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False): | ||||||
|     logger.log("Online evaluate: {:}".format(env)) |     logger.log("Online evaluate: {:}".format(env)) | ||||||
|  |     loss_meter = AverageMeter() | ||||||
|  |     w_containers = dict() | ||||||
|     for idx, (future_time, (future_x, future_y)) in enumerate(env): |     for idx, (future_time, (future_x, future_y)) in enumerate(env): | ||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
|             meta_model.eval() |             meta_model.eval() | ||||||
| @@ -102,9 +104,12 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): | |||||||
|             _, [future_container], time_embeds = meta_model( |             _, [future_container], time_embeds = meta_model( | ||||||
|                 future_time.to(args.device).view(1, 1), None, True |                 future_time.to(args.device).view(1, 1), None, True | ||||||
|             ) |             ) | ||||||
|  |             if save: | ||||||
|  |                 w_containers[idx] = future_container.no_grad_clone() | ||||||
|             future_x, future_y = future_x.to(args.device), future_y.to(args.device) |             future_x, future_y = future_x.to(args.device), future_y.to(args.device) | ||||||
|             future_y_hat = base_model.forward_with_container(future_x, future_container) |             future_y_hat = base_model.forward_with_container(future_x, future_container) | ||||||
|             future_loss = criterion(future_y_hat, future_y) |             future_loss = criterion(future_y_hat, future_y) | ||||||
|  |             loss_meter.update(future_loss.item()) | ||||||
|         refine, post_refine_loss = meta_model.adapt( |         refine, post_refine_loss = meta_model.adapt( | ||||||
|             base_model, |             base_model, | ||||||
|             criterion, |             criterion, | ||||||
| @@ -123,6 +128,7 @@ def online_evaluate(env, meta_model, base_model, criterion, args, logger): | |||||||
|         ) |         ) | ||||||
|     meta_model.clear_fixed() |     meta_model.clear_fixed() | ||||||
|     meta_model.clear_learnt() |     meta_model.clear_learnt() | ||||||
|  |     return w_containers, loss_meter | ||||||
|  |  | ||||||
|  |  | ||||||
| def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger): | ||||||
| @@ -219,8 +225,10 @@ def main(args): | |||||||
|     logger, env_info, model_kwargs = lfna_setup(args) |     logger, env_info, model_kwargs = lfna_setup(args) | ||||||
|     train_env = get_synthetic_env(mode="train", version=args.env_version) |     train_env = get_synthetic_env(mode="train", version=args.env_version) | ||||||
|     valid_env = get_synthetic_env(mode="valid", version=args.env_version) |     valid_env = get_synthetic_env(mode="valid", version=args.env_version) | ||||||
|     logger.log("training enviornment: {:}".format(train_env)) |     all_env = get_synthetic_env(mode=None, version=args.env_version) | ||||||
|     logger.log("validation enviornment: {:}".format(valid_env)) |     logger.log("The training enviornment: {:}".format(train_env)) | ||||||
|  |     logger.log("The validation enviornment: {:}".format(valid_env)) | ||||||
|  |     logger.log("The total enviornment: {:}".format(all_env)) | ||||||
|  |  | ||||||
|     base_model = get_model(**model_kwargs) |     base_model = get_model(**model_kwargs) | ||||||
|     base_model = base_model.to(args.device) |     base_model = base_model.to(args.device) | ||||||
| @@ -249,10 +257,20 @@ def main(args): | |||||||
|     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) |     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) | ||||||
|  |  | ||||||
|     # try to evaluate once |     # try to evaluate once | ||||||
|     online_evaluate(train_env, meta_model, base_model, criterion, args, logger) |     # online_evaluate(train_env, meta_model, base_model, criterion, args, logger) | ||||||
|     online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) |     # online_evaluate(valid_env, meta_model, base_model, criterion, args, logger) | ||||||
|  |     w_containers, loss_meter = online_evaluate( | ||||||
|  |         all_env, meta_model, base_model, criterion, args, logger, True | ||||||
|  |     ) | ||||||
|  |     logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter)) | ||||||
|  |  | ||||||
|     pdb.set_trace() |     save_checkpoint( | ||||||
|  |         {"w_containers": w_containers}, | ||||||
|  |         logger.path(None) / "final-ckp.pth", | ||||||
|  |         logger, | ||||||
|  |     ) | ||||||
|  |     return | ||||||
|  |     """ | ||||||
|     optimizer = torch.optim.Adam( |     optimizer = torch.optim.Adam( | ||||||
|         meta_model.get_parameters(True, True, False),  # fix hypernet |         meta_model.get_parameters(True, True, False),  # fix hypernet | ||||||
|         lr=args.lr, |         lr=args.lr, | ||||||
| @@ -364,7 +382,6 @@ def main(args): | |||||||
|     # meta-test |     # meta-test | ||||||
|     meta_model.load_best() |     meta_model.load_best() | ||||||
|     eval_env = env_info["dynamic_env"] |     eval_env = env_info["dynamic_env"] | ||||||
|     w_container_per_epoch = dict() |  | ||||||
|     for idx in range(args.seq_length, len(eval_env)): |     for idx in range(args.seq_length, len(eval_env)): | ||||||
|         # build-timestamp |         # build-timestamp | ||||||
|         future_time = env_info["{:}-timestamp".format(idx)].item() |         future_time = env_info["{:}-timestamp".format(idx)].item() | ||||||
| @@ -424,6 +441,7 @@ def main(args): | |||||||
|         logger.path(None) / "final-ckp.pth", |         logger.path(None) / "final-ckp.pth", | ||||||
|         logger, |         logger, | ||||||
|     ) |     ) | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     logger.log("-" * 200 + "\n") |     logger.log("-" * 200 + "\n") | ||||||
|     logger.close() |     logger.close() | ||||||
| @@ -494,7 +512,7 @@ if __name__ == "__main__": | |||||||
|         help="The learning rate for the optimizer, during refine", |         help="The learning rate for the optimizer, during refine", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--refine_epochs", type=int, default=40, help="The final refine #epochs." |         "--refine_epochs", type=int, default=50, help="The final refine #epochs." | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--early_stop_thresh", |         "--early_stop_thresh", | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user