Refine LFNA vis codes
This commit is contained in:
parent
17955123a0
commit
3d3a04705f
@ -1,7 +1,8 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
||||||
############################################################################
|
############################################################################
|
||||||
# python exps/LFNA/vis-synthetic.py #
|
# python exps/LFNA/vis-synthetic.py --env_version v1 #
|
||||||
|
# python exps/LFNA/vis-synthetic.py --env_version v2 #
|
||||||
############################################################################
|
############################################################################
|
||||||
import os, sys, copy, random
|
import os, sys, copy, random
|
||||||
import torch
|
import torch
|
||||||
@ -223,7 +224,9 @@ def visualize_env(save_dir, version):
|
|||||||
|
|
||||||
def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
|
def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
|
||||||
save_dir = Path(str(save_dir))
|
save_dir = Path(str(save_dir))
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
for substr in ("pdf", "png"):
|
||||||
|
sub_save_dir = save_dir / substr
|
||||||
|
sub_save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
dpi, width, height = 30, 3200, 2000
|
dpi, width, height = 30, 3200, 2000
|
||||||
figsize = width / float(dpi), height / float(dpi)
|
figsize = width / float(dpi), height / float(dpi)
|
||||||
@ -235,10 +238,10 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
|
|||||||
|
|
||||||
alg_name2dir = OrderedDict()
|
alg_name2dir = OrderedDict()
|
||||||
alg_name2dir["Optimal"] = "use-same-timestamp"
|
alg_name2dir["Optimal"] = "use-same-timestamp"
|
||||||
alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
|
# alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
|
||||||
alg_name2dir["MAML"] = "use-maml-s1"
|
# alg_name2dir["MAML"] = "use-maml-s1"
|
||||||
alg_name2dir["LFNA (fix init)"] = "lfna-fix-init"
|
# alg_name2dir["LFNA (fix init)"] = "lfna-fix-init"
|
||||||
alg_name2dir["LFNA (debug)"] = "lfna-debug"
|
alg_name2dir["LFNA (debug)"] = "lfna-tall-hpnet"
|
||||||
alg_name2all_containers = OrderedDict()
|
alg_name2all_containers = OrderedDict()
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
poststr = "v1-d16"
|
poststr = "v1-d16"
|
||||||
@ -246,15 +249,16 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
|
|||||||
raise ValueError("Invalid version: {:}".format(version))
|
raise ValueError("Invalid version: {:}".format(version))
|
||||||
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
||||||
ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth"
|
ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth"
|
||||||
xdata = torch.load(ckp_path)
|
xdata = torch.load(ckp_path, map_location="cpu")
|
||||||
alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
|
alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
|
||||||
# load the basic model
|
# load the basic model
|
||||||
model = get_model(
|
model = get_model(
|
||||||
dict(model_type="simple_mlp"),
|
dict(model_type="norm_mlp"),
|
||||||
act_cls="leaky_relu",
|
|
||||||
norm_cls="identity",
|
|
||||||
input_dim=1,
|
input_dim=1,
|
||||||
output_dim=1,
|
output_dim=1,
|
||||||
|
hidden_dims=[16] * 2,
|
||||||
|
act_cls="gelu",
|
||||||
|
norm_cls="layer_norm_1d",
|
||||||
)
|
)
|
||||||
|
|
||||||
alg2xs, alg2ys = defaultdict(list), defaultdict(list)
|
alg2xs, alg2ys = defaultdict(list), defaultdict(list)
|
||||||
@ -331,13 +335,14 @@ def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
|
|||||||
cur_ax.set_ylim(0, 10)
|
cur_ax.set_ylim(0, 10)
|
||||||
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
cur_ax.legend(loc=1, fontsize=LegendFontsize)
|
||||||
|
|
||||||
save_path = save_dir / "v{:}-{:05d}".format(version, idx)
|
pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
|
||||||
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
|
fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
|
||||||
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
|
png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
|
||||||
|
fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
|
||||||
plt.close("all")
|
plt.close("all")
|
||||||
save_dir = save_dir.resolve()
|
save_dir = save_dir.resolve()
|
||||||
base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
|
base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
|
||||||
xdir=save_dir, w=width, h=height, ver=version
|
xdir=save_dir / "png", w=width, h=height, ver=version
|
||||||
)
|
)
|
||||||
os.system(
|
os.system(
|
||||||
"{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)
|
"{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)
|
||||||
@ -356,9 +361,15 @@ if __name__ == "__main__":
|
|||||||
default="./outputs/vis-synthetic",
|
default="./outputs/vis-synthetic",
|
||||||
help="The save directory.",
|
help="The save directory.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--env_version",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="The synthetic enviornment version.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
|
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
|
||||||
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
|
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
|
||||||
compare_algs(os.path.join(args.save_dir, "compare-alg-v2"), "v1")
|
compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
|
||||||
# compare_cl(os.path.join(args.save_dir, "compare-cl"))
|
# compare_cl(os.path.join(args.save_dir, "compare-cl"))
|
||||||
|
Loading…
Reference in New Issue
Block a user