xautodl/exps/LFNA/vis-synthetic.py

376 lines
14 KiB
Python
Raw Normal View History

2021-04-23 07:15:39 +02:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
############################################################################
2021-05-13 11:43:38 +02:00
# python exps/LFNA/vis-synthetic.py --env_version v1 #
# python exps/LFNA/vis-synthetic.py --env_version v2 #
############################################################################
import os, sys, copy, random
2021-04-23 07:15:39 +02:00
import torch
import numpy as np
import argparse
2021-04-29 17:37:50 +02:00
from collections import OrderedDict, defaultdict
2021-04-23 07:15:39 +02:00
from pathlib import Path
from tqdm import tqdm
from pprint import pprint
import matplotlib
from matplotlib import cm
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# Make the repository's ``lib`` directory (two levels above this script's
# folder) importable regardless of the working directory; insert it only once.
lib_dir = Path(__file__).parent.joinpath("..", "..", "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
2021-05-07 08:27:15 +02:00
from models.xcore import get_model
2021-04-28 17:56:25 +02:00
from datasets.synthetic_core import get_synthetic_env
from utils.temp_sync import optimize_fn, evaluate_fn
2021-04-29 17:37:50 +02:00
from procedures.metric_utils import MSEMetric
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
    """Scatter ``xs``/``ys`` on ``cur_ax`` with a separately styled legend entry.

    Args:
        cur_ax: Axes-like object exposing ``scatter``.
        xs, ys: Coordinates of the points to draw.
        color: Colour shared by the legend entry and the points.
        alpha: Opacity of the data points only.
        linewidths: Marker edge width of the legend entry only; the data
            points always use 1.5.
        label: Legend text, or ``None`` for no legend entry.
    """
    # A single labelled point at (-100, -100) acts as the legend proxy. It is
    # presumably outside the axis limits callers set. Because the proxy has no
    # alpha, the legend marker stays opaque and uses ``linewidths``.
    proxy_style = dict(color=color, linewidths=linewidths, label=label)
    cur_ax.scatter([-100], [-100], **proxy_style)
    # The actual data: translucent, fixed edge width, no legend entry.
    data_style = dict(color=color, alpha=alpha, linewidths=1.5, label=None)
    cur_ax.scatter(xs, ys, **data_style)
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
    """Draw a vertical stack of scatter subplots and save it as PDF and PNG.

    Each entry of ``scatter_list`` becomes one subplot (top to bottom) drawn
    with ``plot_scatter``.

    Args:
        save_dir: ``pathlib.Path`` directory; files are written to
            ``save_dir/<timestamp:04d>.pdf`` and ``save_dir/<timestamp:04d>.png``.
        timestamp: Integer frame index used in the file names.
        scatter_list: Dicts with keys ``xaxis``, ``yaxis``, ``color``,
            ``alpha``, ``linewidths``, ``label``, ``xlim`` and ``ylim``.
        wh: ``(width, height)`` of the figure in pixels at dpi=40 (before the
            ``bbox_inches="tight"`` cropping applied on save).
        fig_title: Optional title drawn above all subplots.
    """
    save_path = save_dir / "{:04d}".format(timestamp)
    dpi, width, height = 40, wh[0], wh[1]
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5
    fig = plt.figure(figsize=figsize)
    try:
        if fig_title is not None:
            fig.suptitle(
                fig_title, fontsize=LegendFontsize, fontweight="bold", x=0.5, y=0.92
            )
        for idx, scatter_dict in enumerate(scatter_list):
            cur_ax = fig.add_subplot(len(scatter_list), 1, idx + 1)
            plot_scatter(
                cur_ax,
                scatter_dict["xaxis"],
                scatter_dict["yaxis"],
                scatter_dict["color"],
                scatter_dict["alpha"],
                scatter_dict["linewidths"],
                scatter_dict["label"],
            )
            cur_ax.set_xlabel("X", fontsize=LabelSize)
            cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
            cur_ax.set_xlim(scatter_dict["xlim"][0], scatter_dict["xlim"][1])
            cur_ax.set_ylim(scatter_dict["ylim"][0], scatter_dict["ylim"][1])
            # ``Tick.label`` no longer exists in recent Matplotlib (removed in
            # 3.8); ``tick_params`` is the supported way to style tick labels.
            cur_ax.tick_params(
                axis="x", labelsize=LabelSize - font_gap, labelrotation=10
            )
            cur_ax.tick_params(axis="y", labelsize=LabelSize - font_gap)
            cur_ax.legend(loc=1, fontsize=LegendFontsize)
        fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
        fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close("all")
def find_min(cur, others):
    """Return the running minimum of ``cur`` and ``others`` as a Python float.

    Args:
        cur: The minimum seen so far, or ``None`` if nothing was seen yet.
        others: A new scalar (Python or NumPy) or array-like of candidates.

    Returns:
        float: The smallest of ``cur`` and every value in ``others``.
    """
    # ``np.min`` handles scalars and arrays alike. ``float(array)`` fails for
    # multi-element arrays, and ``min(cur, array)`` is ambiguous.
    candidate = float(np.min(others))
    if cur is None:
        return candidate
    return float(min(cur, candidate))
def find_max(cur, others):
    """Return the running maximum of ``cur`` and ``others`` as a Python float.

    Args:
        cur: The maximum seen so far, or ``None`` if nothing was seen yet.
        others: A new scalar (Python or NumPy) or array-like of candidates.

    Returns:
        float: The largest of ``cur`` and every value in ``others``.
    """
    # ``np.max`` handles plain Python floats too. The previous ``others.max()``
    # raised AttributeError for them, and ``max(cur, array)`` is ambiguous.
    candidate = float(np.max(others))
    if cur is None:
        return candidate
    return float(max(cur, candidate))
def compare_cl(save_dir):
    """Render frames comparing the LFNA data with a continual-learning curve.

    For every timestamp, one figure with two stacked scatter panels is written
    into ``save_dir`` by ``draw_multi_fig`` (``0000.png``, ``0001.png``, ...).
    ffmpeg then stitches the PNGs into ``compare-cl.mp4`` and
    ``compare-cl.webm``.

    NOTE(review): ``create_example_v1`` is neither defined nor imported in this
    file, so this function raises ``NameError`` as written (its call in
    ``__main__`` is commented out). TODO: restore that helper or port this
    function to ``get_synthetic_env``.

    Args:
        save_dir: Output directory (str or path-like); created if missing.
    """
    import shlex

    save_dir = Path(str(save_dir))
    save_dir.mkdir(parents=True, exist_ok=True)
    dynamic_env, cl_function = create_example_v1(
        # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
        timestamp_config=dict(num=200),
        num_per_task=1000,
    )
    cl_function.set_timestamp(0)

    # Collect each timestamp's data and the global x-range spanned by
    # mean +/- one std of every timestamp's inputs.
    cl_xaxis_min, cl_xaxis_max = None, None
    all_data = OrderedDict()
    for timestamp, dataset in tqdm(dynamic_env, ncols=50):
        xaxis_all = dataset[0][:, 0].numpy()
        yaxis_all = dataset[1][:, 0].numpy()
        all_data[timestamp] = {
            "lfna_xaxis_all": xaxis_all,
            "lfna_yaxis_all": yaxis_all,
        }
        cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
        cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std())

    num_frames = len(all_data)
    for idx, xdata in enumerate(tqdm(all_data.values(), ncols=50)):
        lfna_scatter = {
            "xaxis": xdata["lfna_xaxis_all"],
            "yaxis": xdata["lfna_yaxis_all"],
            "color": "k",
            "linewidths": 15,
            "alpha": 0.99,
            "xlim": (-6, 6),
            "ylim": (-40, 40),
            "label": "LFNA",
        }
        # Reveal the continual-learning curve progressively: frame ``idx``
        # covers the first (idx + 1) / num_frames of the global x-range.
        cur_cl_xaxis_max = cl_xaxis_min + (cl_xaxis_max - cl_xaxis_min) * (
            idx + 1
        ) / num_frames
        cl_xaxis_all = np.arange(cl_xaxis_min, cur_cl_xaxis_max, step=0.01)
        cl_yaxis_all = cl_function.noise_call(cl_xaxis_all, std=0.2)
        cl_scatter = {
            "xaxis": cl_xaxis_all,
            "yaxis": cl_yaxis_all,
            "color": "k",
            "linewidths": 15,
            "xlim": (round(cl_xaxis_min, 1), round(cl_xaxis_max, 1)),
            "ylim": (-20, 6),
            "alpha": 0.99,
            "label": "Continual Learning",
        }
        draw_multi_fig(
            save_dir,
            idx,
            [lfna_scatter, cl_scatter],
            wh=(2200, 1800),
            fig_title="Timestamp={:03d}".format(idx),
        )

    print("Save all figures into {:}".format(save_dir))
    save_dir = save_dir.resolve()
    # ffmpeg applies only the last ``-vf`` given for a stream, so the former
    # ``-vf fps=1`` before the scale filter was silently ignored. Only the
    # filter that actually took effect is passed now. Paths are shell-quoted
    # so directories containing spaces survive ``os.system``.
    frame_pattern = shlex.quote(str(save_dir / "%04d.png"))
    base_cmd = "ffmpeg -y -i {:} -vf scale=2200:1800 -vb 5000k".format(frame_pattern)
    video_cmd = "{:} -pix_fmt yuv420p {:}".format(
        base_cmd, shlex.quote(str(save_dir / "compare-cl.mp4"))
    )
    print(video_cmd + "\n")
    os.system(video_cmd)
    os.system(
        "{:} -pix_fmt yuv420p {:}".format(
            base_cmd, shlex.quote(str(save_dir / "compare-cl.webm"))
        )
    )
def visualize_env(save_dir, version):
    """Plot every timestamp of a synthetic environment and stitch the frames into videos.

    Args:
        save_dir: Output directory (str or path-like); created if missing.
        version: Passed to ``get_synthetic_env``. ``"v1"`` and ``"v2"`` get
            fixed axis limits; other versions keep Matplotlib's autoscaling.

    Side effects:
        Writes ``v{version}-{idx:05d}.pdf``/``.png`` per timestamp (so
        version ``"v1"`` yields ``vv1-00000.png`` ...). It then runs ffmpeg to
        produce ``env-{version}.mp4`` and ``env-{version}.webm`` in ``save_dir``.
    """
    import shlex

    save_dir = Path(str(save_dir))
    save_dir.mkdir(parents=True, exist_ok=True)
    dynamic_env = get_synthetic_env(version=version)
    # Figure settings are identical for every frame.
    dpi, width, height = 30, 1800, 1400
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5
    for idx, (_, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        fig = plt.figure(figsize=figsize)
        try:
            cur_ax = fig.add_subplot(1, 1, 1)
            allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
            plot_scatter(
                cur_ax, allx, ally, "k", 0.99, 15, "timestamp={:05d}".format(idx)
            )
            cur_ax.set_xlabel("X", fontsize=LabelSize)
            cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
            # ``Tick.label`` no longer exists in recent Matplotlib (removed in
            # 3.8); ``tick_params`` is the supported replacement.
            cur_ax.tick_params(
                axis="x", labelsize=LabelSize - font_gap, labelrotation=10
            )
            cur_ax.tick_params(axis="y", labelsize=LabelSize - font_gap)
            # Fixed limits so all frames of one version share the same scale.
            if version == "v1":
                cur_ax.set_xlim(-2, 2)
                cur_ax.set_ylim(-8, 8)
            elif version == "v2":
                cur_ax.set_xlim(-10, 10)
                cur_ax.set_ylim(-60, 60)
            cur_ax.legend(loc=1, fontsize=LegendFontsize)
            save_path = save_dir / "v{:}-{:05d}".format(version, idx)
            fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
            fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
        finally:
            plt.close("all")
    save_dir = save_dir.resolve()
    # Shell-quote paths so directories containing spaces work with os.system.
    frame_pattern = shlex.quote(str(save_dir / "v{:}-%05d.png".format(version)))
    base_cmd = "ffmpeg -y -i {:} -vf scale={:}:{:} -pix_fmt yuv420p -vb 5000k".format(
        frame_pattern, width, height
    )
    print(base_cmd)
    for ext in ("mp4", "webm"):
        target = save_dir / "env-{:}.{:}".format(version, ext)
        os.system("{:} {:}".format(base_cmd, shlex.quote(str(target))))
def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"):
    """Plot each algorithm's predictions and running MSE per timestamp, then make videos.

    Inputs are read from ``alg_dir``:
      * ``env-{version}-info.pth``, a dict whose ``"dynamic_env"`` entry is
        iterated for the data.
      * ``{alg_subdir}-{poststr}/final-ckp.pth`` for every algorithm, whose
        ``"w_container_per_epoch"`` entry is indexed by the frame index.

    Each frame (frame 0 is skipped) is written to
    ``save_dir/pdf/v{version}-{idx:05d}.pdf`` and ``save_dir/png/...png``. The
    PNGs are then stitched into ``com-alg-{version}.mp4``/``.webm``.

    Args:
        save_dir: Output directory (str or path-like).
        version: Environment version; only ``"v1"`` is supported.
        alg_dir: Directory holding the environment cache and the checkpoints.

    Raises:
        AssertionError: If the environment cache file does not exist.
        ValueError: If ``version`` is not ``"v1"``.
    """
    import shlex

    save_dir = Path(str(save_dir))
    for substr in ("pdf", "png"):
        (save_dir / substr).mkdir(parents=True, exist_ok=True)
    dpi, width, height = 30, 3200, 2000
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5
    cache_path = Path(alg_dir) / "env-{:}-info.pth".format(version)
    assert cache_path.exists(), "{:} does not exist".format(cache_path)
    # NOTE(review): this unpickles a whole environment object; torch >= 2.6
    # defaults to ``weights_only=True``, which may reject it — confirm.
    env_info = torch.load(cache_path)
    alg_name2dir = OrderedDict()
    alg_name2dir["Optimal"] = "use-same-timestamp"
    # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
    # alg_name2dir["MAML"] = "use-maml-s1"
    # alg_name2dir["LFNA (fix init)"] = "lfna-fix-init"
    alg_name2dir["LFNA (debug)"] = "lfna-tall-hpnet"
    if version == "v1":
        poststr = "v1-d16"
    else:
        raise ValueError("Invalid version: {:}".format(version))
    alg_name2all_containers = OrderedDict()
    for alg, xdir in alg_name2dir.items():
        ckp_path = Path(alg_dir) / "{:}-{:}".format(xdir, poststr) / "final-ckp.pth"
        xdata = torch.load(ckp_path, map_location="cpu")
        alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
    # Shared base network; each algorithm's per-timestamp weights are supplied
    # via ``forward_with_container``. hidden_dims=[16]*2 presumably matches the
    # "d16" checkpoint suffix — confirm if poststr changes.
    model = get_model(
        dict(model_type="norm_mlp"),
        input_dim=1,
        output_dim=1,
        hidden_dims=[16] * 2,
        act_cls="gelu",
        norm_cls="layer_norm_1d",
    )
    alg2xs, alg2ys = defaultdict(list), defaultdict(list)
    colors = ["r", "g", "b", "m", "y"]
    dynamic_env = env_info["dynamic_env"]
    linewidths = 10
    for idx, (_, (ori_allx, ori_ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        # Frame 0 is skipped; the MSE axis below accordingly starts at 1.
        if idx == 0:
            continue
        fig = plt.figure(figsize=figsize)
        try:
            # Top panel: raw data plus each algorithm's predictions.
            cur_ax = fig.add_subplot(2, 1, 1)
            allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy()
            plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
            for idx_alg, alg in enumerate(alg_name2dir):
                # Wrap around instead of raising IndexError when there are
                # more algorithms than colours.
                color = colors[idx_alg % len(colors)]
                with torch.no_grad():
                    predicts = model.forward_with_container(
                        ori_allx, alg_name2all_containers[alg][idx]
                    )
                    predicts = predicts.cpu()
                    # Record this frame's MSE for the trajectory panel.
                    metric = MSEMetric()
                    metric(predicts, ori_ally)
                    predicts = predicts.view(-1).numpy()
                    alg2xs[alg].append(idx)
                    alg2ys[alg].append(metric.get_info()["mse"])
                plot_scatter(cur_ax, allx, predicts, color, 0.99, linewidths, alg)
            cur_ax.set_xlabel("X", fontsize=LabelSize)
            cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
            # ``Tick.label`` no longer exists in recent Matplotlib (removed in
            # 3.8); ``tick_params`` is the supported replacement.
            cur_ax.tick_params(
                axis="x", labelsize=LabelSize - font_gap, labelrotation=10
            )
            cur_ax.tick_params(axis="y", labelsize=LabelSize - font_gap)
            if version == "v1":
                cur_ax.set_xlim(-2, 2)
                cur_ax.set_ylim(-8, 8)
            elif version == "v2":
                cur_ax.set_xlim(-10, 10)
                cur_ax.set_ylim(-60, 60)
            cur_ax.legend(loc=1, fontsize=LegendFontsize)

            # Bottom panel: MSE trajectory of each algorithm up to this frame.
            cur_ax = fig.add_subplot(2, 1, 2)
            for idx_alg, alg in enumerate(alg_name2dir):
                cur_ax.plot(
                    alg2xs[alg],
                    alg2ys[alg],
                    color=colors[idx_alg % len(colors)],
                    linestyle="-",
                    linewidth=5,
                    label=alg,
                )
            cur_ax.set_xlabel("Timestamp", fontsize=LabelSize)
            cur_ax.set_ylabel("MSE", fontsize=LabelSize)
            cur_ax.tick_params(
                axis="x", labelsize=LabelSize - font_gap, labelrotation=10
            )
            cur_ax.tick_params(axis="y", labelsize=LabelSize - font_gap)
            cur_ax.set_xlim(1, len(dynamic_env))
            cur_ax.set_ylim(0, 10)
            cur_ax.legend(loc=1, fontsize=LegendFontsize)

            for ext in ("pdf", "png"):
                out_path = save_dir / ext / "v{:}-{:05d}.{:}".format(version, idx, ext)
                fig.savefig(str(out_path), dpi=dpi, bbox_inches="tight", format=ext)
        finally:
            plt.close("all")
    save_dir = save_dir.resolve()
    # Shell-quote paths so directories containing spaces work with os.system.
    frame_pattern = shlex.quote(str(save_dir / "png" / "v{:}-%05d.png".format(version)))
    base_cmd = "ffmpeg -y -i {:} -vf scale={:}:{:} -pix_fmt yuv420p -vb 5000k".format(
        frame_pattern, width, height
    )
    for ext in ("mp4", "webm"):
        target = save_dir / "com-alg-{:}.{:}".format(version, ext)
        os.system("{:} {:}".format(base_cmd, shlex.quote(str(target))))
if __name__ == "__main__":
    # Command-line entry point: by default renders the algorithm comparison
    # for the requested environment version.
    parser = argparse.ArgumentParser("Visualize synthetic data.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/vis-synthetic",
        help="The save directory.",
    )
    parser.add_argument(
        "--env_version",
        type=str,
        required=True,
        help="The synthetic environment version.",
    )
    args = parser.parse_args()
    # Alternative visualisations, enable manually when needed:
    # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
    # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
    # compare_cl(os.path.join(args.save_dir, "compare-cl"))
    compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)