Update vis
This commit is contained in:
		| @@ -37,7 +37,9 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): | |||||||
|  |  | ||||||
|     fig = plt.figure(figsize=figsize) |     fig = plt.figure(figsize=figsize) | ||||||
|     if fig_title is not None: |     if fig_title is not None: | ||||||
|         fig.suptitle(fig_title, fontsize=LegendFontsize) |         fig.suptitle( | ||||||
|  |             fig_title, fontsize=LegendFontsize, fontweight="bold", x=0.5, y=0.92 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     for idx, scatter_dict in enumerate(scatter_list): |     for idx, scatter_dict in enumerate(scatter_list): | ||||||
|         cur_ax = fig.add_subplot(len(scatter_list), 1, idx + 1) |         cur_ax = fig.add_subplot(len(scatter_list), 1, idx + 1) | ||||||
| @@ -83,7 +85,7 @@ def compare_cl(save_dir): | |||||||
|     save_dir.mkdir(parents=True, exist_ok=True) |     save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|     dynamic_env, function = create_example_v1( |     dynamic_env, function = create_example_v1( | ||||||
|         # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), |         # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), | ||||||
|         timestamp_config=None, |         timestamp_config=dict(num=200), | ||||||
|         num_per_task=1000, |         num_per_task=1000, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
| @@ -107,16 +109,18 @@ def compare_cl(save_dir): | |||||||
|  |  | ||||||
|         # compute cl-min |         # compute cl-min | ||||||
|         cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std()) |         cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std()) | ||||||
|         cl_xaxis_max = ( |         cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) | ||||||
|             find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) + idx * 0.1 |         """ | ||||||
|         ) |  | ||||||
|         cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05) |         cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05) | ||||||
|  |  | ||||||
|         cl_yaxis_all = cl_function.noise_call(cl_xaxis_all) |         cl_yaxis_all = cl_function.noise_call(cl_xaxis_all) | ||||||
|         current_data["cl_xaxis_all"] = cl_xaxis_all |         current_data["cl_xaxis_all"] = cl_xaxis_all | ||||||
|         current_data["cl_yaxis_all"] = cl_yaxis_all |         current_data["cl_yaxis_all"] = cl_yaxis_all | ||||||
|  |         """ | ||||||
|         all_data[timestamp] = current_data |         all_data[timestamp] = current_data | ||||||
|  |  | ||||||
|  |     global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1) | ||||||
|  |     global_cl_yaxis_all = cl_function.noise_call(global_cl_xaxis_all) | ||||||
|  |  | ||||||
|     for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)): |     for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)): | ||||||
|         scatter_list = [] |         scatter_list = [] | ||||||
|         scatter_list.append( |         scatter_list.append( | ||||||
| @@ -124,7 +128,7 @@ def compare_cl(save_dir): | |||||||
|                 "xaxis": xdata["lfna_xaxis_all"], |                 "xaxis": xdata["lfna_xaxis_all"], | ||||||
|                 "yaxis": xdata["lfna_yaxis_all"], |                 "yaxis": xdata["lfna_yaxis_all"], | ||||||
|                 "color": "k", |                 "color": "k", | ||||||
|                 "s": 10, |                 "s": 12, | ||||||
|                 "alpha": 0.99, |                 "alpha": 0.99, | ||||||
|                 "xlim": (-6, 6), |                 "xlim": (-6, 6), | ||||||
|                 "ylim": (-40, 40), |                 "ylim": (-40, 40), | ||||||
| @@ -132,17 +136,21 @@ def compare_cl(save_dir): | |||||||
|             } |             } | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         cl_xaxis_all = current_data["cl_xaxis_all"] |         cur_cl_xaxis_min = cl_xaxis_min | ||||||
|         cl_yaxis_all = current_data["cl_yaxis_all"] |         cur_cl_xaxis_max = cl_xaxis_min + (cl_xaxis_max - cl_xaxis_min) * ( | ||||||
|  |             idx + 1 | ||||||
|  |         ) / len(all_data) | ||||||
|  |         cl_xaxis_all = np.arange(cur_cl_xaxis_min, cur_cl_xaxis_max, step=0.01) | ||||||
|  |         cl_yaxis_all = cl_function.noise_call(cl_xaxis_all, std=0.2) | ||||||
|  |  | ||||||
|         scatter_list.append( |         scatter_list.append( | ||||||
|             { |             { | ||||||
|                 "xaxis": cl_xaxis_all, |                 "xaxis": cl_xaxis_all, | ||||||
|                 "yaxis": cl_yaxis_all, |                 "yaxis": cl_yaxis_all, | ||||||
|                 "color": "r", |                 "color": "k", | ||||||
|                 "s": 10, |                 "s": 12, | ||||||
|                 "xlim": (round(cl_xaxis_all.min(), 1), round(cl_xaxis_all.max(), 1)), |                 "xlim": (round(cl_xaxis_min, 1), round(cl_xaxis_max, 1)), | ||||||
|                 "ylim": (round(cl_xaxis_all.min(), 1), round(cl_yaxis_all.max(), 1)), |                 "ylim": (-18, 2), | ||||||
|                 "alpha": 0.99, |                 "alpha": 0.99, | ||||||
|                 "label": "Continual Learning", |                 "label": "Continual Learning", | ||||||
|             } |             } | ||||||
| @@ -152,18 +160,20 @@ def compare_cl(save_dir): | |||||||
|             save_dir, |             save_dir, | ||||||
|             idx, |             idx, | ||||||
|             scatter_list, |             scatter_list, | ||||||
|             wh=(2000, 1300), |             wh=(2200, 1800), | ||||||
|             fig_title="Timestamp={:03d}".format(idx), |             fig_title="Timestamp={:03d}".format(idx), | ||||||
|         ) |         ) | ||||||
|     print("Save all figures into {:}".format(save_dir)) |     print("Save all figures into {:}".format(save_dir)) | ||||||
|     save_dir = save_dir.resolve() |     save_dir = save_dir.resolve() | ||||||
|     base_cmd = ( |     base_cmd = ( | ||||||
|         "ffmpeg -y -i {xdir}/%04d.png -vf fps=2 -vf scale=2000:1300 -vb 5000k".format( |         "ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format( | ||||||
|             xdir=save_dir |             xdir=save_dir | ||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
|     os.system("{:} -pix_fmt yuv420p {xdir}/vis.mp4".format(base_cmd, xdir=save_dir)) |     video_cmd = "{:} -pix_fmt yuv420p {xdir}/vis.mp4".format(base_cmd, xdir=save_dir) | ||||||
|     os.system("{:} -c:a libvorbis {xdir}/vis.webm".format(base_cmd, xdir=save_dir)) |     print(video_cmd + "\n") | ||||||
|  |     os.system(video_cmd) | ||||||
|  |     # os.system("{:} {xdir}/vis.webm".format(base_cmd, xdir=save_dir)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user