diff --git a/exps/LFNA/vis-synthetic.py b/exps/LFNA/vis-synthetic.py index 395d760..649b890 100644 --- a/exps/LFNA/vis-synthetic.py +++ b/exps/LFNA/vis-synthetic.py @@ -1,7 +1,8 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # ############################################################################ -# python exps/LFNA/vis-synthetic.py # +# python exps/LFNA/vis-synthetic.py --env_version v1 # +# python exps/LFNA/vis-synthetic.py --env_version v2 # ############################################################################ import os, sys, copy, random import torch @@ -223,7 +224,9 @@ def visualize_env(save_dir, version): def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): save_dir = Path(str(save_dir)) - save_dir.mkdir(parents=True, exist_ok=True) + for substr in ("pdf", "png"): + sub_save_dir = save_dir / substr + sub_save_dir.mkdir(parents=True, exist_ok=True) dpi, width, height = 30, 3200, 2000 figsize = width / float(dpi), height / float(dpi) @@ -235,10 +238,10 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): alg_name2dir = OrderedDict() alg_name2dir["Optimal"] = "use-same-timestamp" - alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" - alg_name2dir["MAML"] = "use-maml-s1" - alg_name2dir["LFNA (fix init)"] = "lfna-fix-init" - alg_name2dir["LFNA (debug)"] = "lfna-debug" + # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data" + # alg_name2dir["MAML"] = "use-maml-s1" + # alg_name2dir["LFNA (fix init)"] = "lfna-fix-init" + alg_name2dir["LFNA (debug)"] = "lfna-tall-hpnet" alg_name2all_containers = OrderedDict() if version == "v1": poststr = "v1-d16" @@ -246,15 +249,16 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): raise ValueError("Invalid version: {:}".format(version)) for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()): ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth" - xdata = torch.load(ckp_path) + xdata = torch.load(ckp_path, map_location="cpu") alg_name2all_containers[alg] = xdata["w_container_per_epoch"] # load the basic model model = get_model( - dict(model_type="simple_mlp"), - act_cls="leaky_relu", - norm_cls="identity", + dict(model_type="norm_mlp"), input_dim=1, output_dim=1, + hidden_dims=[16] * 2, + act_cls="gelu", + norm_cls="layer_norm_1d", ) alg2xs, alg2ys = defaultdict(list), defaultdict(list) @@ -331,13 +335,14 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): cur_ax.set_ylim(0, 10) cur_ax.legend(loc=1, fontsize=LegendFontsize) - save_path = save_dir / "v{:}-{:05d}".format(version, idx) - fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf") - fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png") + pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx) + fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf") + png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx) + fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png") plt.close("all") save_dir = save_dir.resolve() base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format( - xdir=save_dir, w=width, h=height, ver=version + xdir=save_dir / "png", w=width, h=height, ver=version ) os.system( "{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version) @@ -356,9 +361,15 @@ if __name__ == "__main__": default="./outputs/vis-synthetic", help="The save directory.", ) + parser.add_argument( + "--env_version", + type=str, + required=True, + help="The synthetic enviornment version.", + ) args = parser.parse_args() # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1") # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2") - compare_algs(os.path.join(args.save_dir, "compare-alg-v2"), "v1") + compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version) # compare_cl(os.path.join(args.save_dir, "compare-cl"))