Update LFNA test
This commit is contained in:
		| @@ -1,7 +1,8 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 50000 | ||||
| # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.02 | ||||
| # python exps/LFNA/lfna-test-hpnet.py --env_version v1 --hidden_dim 16 --layer_dim 32 --epochs 500000 --init_lr 0.02 --device cuda | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| @@ -37,19 +38,31 @@ def main(args): | ||||
|     model = model.to(args.device) | ||||
|     criterion = torch.nn.MSELoss() | ||||
|  | ||||
|     logger.log("There are {:} weights.".format(model.get_w_container().numel())) | ||||
|  | ||||
|     shape_container = model.get_w_container().to_shape_container() | ||||
|     hypernet = HyperNet(shape_container, args.layer_dim, args.task_dim) | ||||
|     hypernet = hypernet.to(args.device) | ||||
|  | ||||
|     logger.log( | ||||
|         "{:} There are {:} weights in the base-model.".format( | ||||
|             time_string(), model.numel() | ||||
|         ) | ||||
|     ) | ||||
|     logger.log( | ||||
|         "{:} There are {:} weights in the meta-model.".format( | ||||
|             time_string(), hypernet.numel() | ||||
|         ) | ||||
|     ) | ||||
|     # task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim)) | ||||
|     total_bar = 16 | ||||
|     total_bar = 100 | ||||
|     task_embeds = [] | ||||
|     for i in range(total_bar): | ||||
|         tensor = torch.Tensor(1, args.task_dim).to(args.device) | ||||
|         task_embeds.append(torch.nn.Parameter(tensor)) | ||||
|     for task_embed in task_embeds: | ||||
|         trunc_normal_(task_embed, std=0.02) | ||||
|     for i in range(total_bar): | ||||
|         env_info["{:}-x".format(i)] = env_info["{:}-x".format(i)].to(args.device) | ||||
|         env_info["{:}-y".format(i)] = env_info["{:}-y".format(i)].to(args.device) | ||||
|  | ||||
|     model.train() | ||||
|     hypernet.train() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user