Upgrade lfna debug
This commit is contained in:
		| @@ -1,7 +1,7 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 | ||||
| # python exps/LFNA/lfna-tall-hpnet.py --env_version v1 --hidden_dim 16 --epochs 100000 --meta_batch 16 | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| @@ -42,7 +42,7 @@ def main(args): | ||||
|     hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim) | ||||
|     total_bar = env_info["total"] - 1 | ||||
|     task_embeds = [] | ||||
|     for i in range(total_bar): | ||||
|     for i in range(env_info["total"]): | ||||
|         task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim))) | ||||
|     for task_embed in task_embeds: | ||||
|         trunc_normal_(task_embed, std=0.02) | ||||
| @@ -97,7 +97,7 @@ def main(args): | ||||
|         if iepoch % 200 == 0: | ||||
|             logger.log( | ||||
|                 head_str | ||||
|                 + "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}, limit={:}".format( | ||||
|                 + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}, limit={:}".format( | ||||
|                     loss_meter.avg, | ||||
|                     loss_meter.val, | ||||
|                     min(lr_scheduler.get_last_lr()), | ||||
| @@ -109,7 +109,7 @@ def main(args): | ||||
|             save_checkpoint( | ||||
|                 { | ||||
|                     "hypernet": hypernet.state_dict(), | ||||
|                     "task_embed": task_embed, | ||||
|                     "task_embeds": task_embeds, | ||||
|                     "lr_scheduler": lr_scheduler.state_dict(), | ||||
|                     "iepoch": iepoch, | ||||
|                 }, | ||||
| @@ -122,6 +122,25 @@ def main(args): | ||||
|  | ||||
|     print(model) | ||||
|     print(hypernet) | ||||
|     w_container_per_epoch = dict() | ||||
|     for idx in range(0, env_info["total"]): | ||||
|         future_time = env_info["{:}-timestamp".format(idx)] | ||||
|         future_x = env_info["{:}-x".format(idx)] | ||||
|         future_y = env_info["{:}-y".format(idx)] | ||||
|         future_container = hypernet(task_embeds[idx]) | ||||
|         w_container_per_epoch[idx] = future_container.no_grad_clone() | ||||
|         with torch.no_grad(): | ||||
|             future_y_hat = model.forward_with_container( | ||||
|                 future_x, w_container_per_epoch[idx] | ||||
|             ) | ||||
|             future_loss = criterion(future_y_hat, future_y) | ||||
|         logger.log("meta-test: [{:03d}] -> loss={:.4f}".format(idx, future_loss.item())) | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_container_per_epoch": w_container_per_epoch}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|   | ||||
| @@ -34,17 +34,20 @@ def main(args): | ||||
|     logger, env_info, model_kwargs = lfna_setup(args) | ||||
|     dynamic_env = env_info["dynamic_env"] | ||||
|     model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|     model = model.to(args.device) | ||||
|     criterion = torch.nn.MSELoss() | ||||
|  | ||||
|     logger.log("There are {:} weights.".format(model.get_w_container().numel())) | ||||
|  | ||||
|     shape_container = model.get_w_container().to_shape_container() | ||||
|     hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim) | ||||
|     hypernet = hypernet.to(args.device) | ||||
|     # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) | ||||
|     total_bar = 10 | ||||
|     task_embeds = [] | ||||
|     for i in range(total_bar): | ||||
|         task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim))) | ||||
|         tensor = torch.Tensor(1, args.task_dim).to(args.device) | ||||
|         task_embeds.append(torch.nn.Parameter(tensor)) | ||||
|     for task_embed in task_embeds: | ||||
|         trunc_normal_(task_embed, std=0.02) | ||||
|  | ||||
| @@ -79,8 +82,8 @@ def main(args): | ||||
|             # cur_time = random.randint(0, total_bar) | ||||
|             cur_task_embed = task_embeds[cur_time] | ||||
|             cur_container = hypernet(cur_task_embed) | ||||
|             cur_x = env_info["{:}-x".format(cur_time)] | ||||
|             cur_y = env_info["{:}-y".format(cur_time)] | ||||
|             cur_x = env_info["{:}-x".format(cur_time)].to(args.device) | ||||
|             cur_y = env_info["{:}-y".format(cur_time)].to(args.device) | ||||
|             cur_dataset = TimeData(cur_time, cur_x, cur_y) | ||||
|  | ||||
|             preds = model.forward_with_container(cur_dataset.x, cur_container) | ||||
| @@ -98,7 +101,7 @@ def main(args): | ||||
|         if iepoch % 200 == 0: | ||||
|             logger.log( | ||||
|                 head_str | ||||
|                 + "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format( | ||||
|                 + " meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format( | ||||
|                     loss_meter.avg, | ||||
|                     loss_meter.val, | ||||
|                     min(lr_scheduler.get_last_lr()), | ||||
| @@ -166,6 +169,12 @@ if __name__ == "__main__": | ||||
|         default=2000, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         type=str, | ||||
|         default="cpu", | ||||
|         help="", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user