138 lines
365 KiB
Plaintext
138 lines
365 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "filled-multiple",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
|
||
|
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import os\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"import torch\n",
"import numpy as np\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# realpath() resolves the dummy relative name \"__file__\" against the current\n",
"# working directory, so this evaluates to the (symlink-resolved) working directory.\n",
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
"# The repository root is two directories above that working directory.\n",
"root_dir = Path(__file__).parent.joinpath(\"..\").resolve()\n",
"lib_dir = root_dir.joinpath(\"lib\").resolve()\n",
"print(\"The root path: {:}\".format(root_dir))\n",
"print(\"The library path: {:}\".format(lib_dir))\n",
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
"# Put the project's lib directory on the import path only once, even if this cell is re-run.\n",
"if str(lib_dir) not in sys.path:\n",
"    sys.path.insert(0, str(lib_dir))\n",
"\n",
"from datasets.synthetic_example import create_example_v1"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "consistent-transition",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/Users/xuanyidong/Desktop/AutoDL-Projects/lib/datasets/synthetic_env.py:63: RuntimeWarning: covariance is not positive-semidefinite.\n",
|
||
|
" dataset = np.random.multivariate_normal(\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAEd8AAAjqCAYAAAB+2cVeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9cYyc950m+H1/VV3FLjUlliJzrL2UpdL4LE1mhbju0DMRMJdcMnvHAQJkLdTdzOYwQEJng6tdBEJ2JA//yaWQK2ADRJa8GyhIpvdysW4xEyQz2TeWgUywxGFvkMsFujMFtBMtbiytxyW5ElOibTWlblarqrt++UNim61u0iqJ5FvF/nyAgrqfIoVnKXq23n7f93lTzjkAAAAAAAAAAAAAAAAAAAAAAOAkqZRdAAAAAAAAAAAAAAAAAAAAAAAA7jbjOwAAAAAAAAAAAAAAAAAAAAAAnDjGdwAAAAAAAAAAAAAAAAAAAAAAOHGM7wAAAAAAAAAAAAAAAAAAAAAAcOIY3wEAAAAAAAAAAAAAAAAAAAAA4MRZKbsAt/aFL3wht9vtsmsAAAAAAAAAAAAAAAAAAAAAACyd11577ac557PHvWd8Z8G12+24dOlS2TUAAAAAAAAAAAAAAAAAAAAAAJZOSumtm71XuZtFAAAAAAAAAAAAAAAAAAAAAABgERjfAQAAAAAAAAAAAAAAAAAAAADgxDG+AwAAAAAAAAAAAAAAAAAAAADAiWN8BwAAAAAAAAAAAAAAAAAAAACAE8f4DgAAAAAAAAAAAAAAAAAAAAAAJ47xHQAAAAAAAAAAAAAAAAAAAAAAThzjOwAAAAAAAAAAAAAAAAAAAAAAnDjGdwAAAAAAAAAAAAAAAAAAAAAAOHGM7wAAAAAAAAAAAAAAAAAAAAAAcOIY3wEAAAAAAAAAAAAAAAAAAAAA4MQxvgMAAAAAAAAAAAAAAAAAAAAAwIljfAcAAAAAAAAAAAAAAAAAAAAAgBPH+A4AAAAAAAAAAAAAAAAAAAAAACeO8R0AAAAAAAAAAAAAAAAAAAAAAE4c4zsAAAAAAAAAAAAAAAAAAAAAAJw4xncAAAAAAAAAAAAAAAAAAAAAADhxjO8AAAAAAAAAAAAAAAAAAAAAAHDiGN8BAAAAAAAAAAAAAAAAAAAAAODEMb4DAAAAAAAAAAAAAAAAAAAAAMCJY3wHAAAAAAAAAAAAAAAAAAAAAIATZ6XsAvealNJaRHwtIroR8dcj4l+Kj/6cfxIR/7+IeDUi/pOI+E9zzttl9QQAAAAAAAAAAAAAAAAAAAAAOMmM79xGKaV/OyL+1/HR4M4nffnj138zIv7w49cLd68dAAAAAAAAAAAAAAAAAAAAAADXGd+5TVJK34yIb3wifjcihhGxHREPRcSvRcSpu9sMAAAAAAAAAAAAAAAAAAAAAIBPMr5zG6SU/v04PLzzTyOiHxHfyznnG37dSkT8axHx34+PBnkAAAAAAAAAAAAAAAAAAAAAACiB8Z3PKaX0VET8ezdE/8uc8//suF+bc96LiL/4+AUAAAAAAAAAAAAAAAAAAAAAQEkqZRdYZimlFBH/Qfziz/H/drPhHQAAAAAAAAAAAAAAAAAAAAAAFofxnc/nb0TEkx9/nSPiD0rsAgAAAAAAAAAAAAAAAAAAAADAp2R85/P5H9/w9V/knN8srQkAAAAAAAAAAAAAAAAAAAAAAJ+a8Z3P59+44es/L60FAAAAAAAAAAAAAAAAAAAAAABzMb7zGaWU/uWIeOiG6NWP8381pfS/Syn9IKW0k1K6+vHX/2FK6d8spy0AAAAAAAAAAAAAAAAAAAAAADdaKbvAEvuvf+L7v0opPR8Rz8XRUaMHIuLxiPgfpZT+k4j4d3LO79yFjgAAAAAAAAAAAAAAAAAAAAAAHOOTIzF8eg/d8PVeRPx7EfGH8Ys/0zci4j+JiO9HxOyGX/vfiYj/PKX0xZv9i1NK/25K6VJK6dKVK1
dub2sAAAAAAAAAAAAAAAAAAAAAAIzvfA5nbvh6JSL+7sdf/78i4smc8xM559/OOXciohUR/+cbfv2jEfEnN/sX55z/Uc55Pee8fvbs2dtcGwAAAAAAAAAAAAAAAAAAAAAA4zuf3eox2WsR8W/knP/5jWHO+ScR8e9ExP/xhvhvpJT+jTvYDwAAAAAAAAAAAAAAAAAAAACAmzC+89ntHJP9T3LO4+N+cc45R8QzEXHthvhv34liAAAAAAAAAAAAAAAAAAAAAADcmvGdz277E9+/mXP+z2/1G3LOP4+IP78h+m/d9lYAAAAAAAAAAAAAAAAAAAAAAPxSxnc+u59+4vvXPuXvu/HX/UsppdXb1AcAAAAAAAAAAAAAAAAAAAAAgE/J+M5n919+4vuffcrf98lf9+Bt6AIAAAAAAAAAAAAAAAAAAAAAwByM73x2/yIiJjd8f+pT/r7VT3w/vj11AAAAAAAAAAAAAAAAAAAAAAD4tIzvfEY5572I+H/eEP3qp/ytj93w9YcRcfW2lQIAAAAAAAAAAAAAAAAAAAAA4FMxvvP5/JMbvn4qpXT/p/g9/+YNX//nOed8mzsBAAAAAAAAS6ooiuh0OtFsNqPT6URRFGVXAgAAAAAAAAAAALhnGd/5fP40Ij74+Ov7IuJ/eqtfnFL670XEkzdE37kztQAAAAAAAIBlUxRF9Hq9GI1G0Wg0YjQaRa/XM8ADAAAAAAAAAAAAcIcY3/kccs4/jYhv3hD9z1NKXzvu16aUvhoR/+EN0ZWI+Ed3sB4AAAAAAACwRAaDQeSco16vR0op6vV65JxjMBiUXQ0AAIC7pCiK6HQ60Ww2o9PpGGQFAAAAAACAO2yl7AL3gG9GxH83Ip6KiHpE/F9TSv+XiCgi4v8bEf+ViPidiPjbH78fETGLiP9Bznnn7tcFAAAAAAAAFtFwOIxGo3Eoq9VqMRwOyykEAADAXVUURfR6vcg5R6PRiNFoFL1eLyIiut1uye0AAAAAAADg3pRyzmV3WHoppS9ExMWI+Fc+xS//MCL+ds75Tz7Nv3t9fT1funTp89QDAAAAAAAAlkCn04nRaBT1ev0gm0wm0Wq1YnNzs7xiAAAA3BWOCwEAAAAAAODOSCm9lnNeP+69yt0ucy/KOf80Iv4bEfG/iIgrN/lls4j4bkSsf9rhHQAAAAAAAODk6Pf7kVKKyWQSOeeYTCaRUop+v192NQAAAO6C4XAYtVrtUFar1WI4HJZTCAAAAAAAAE6AlbIL3CtyztOI+PdTSn8/Iv61iPiXI+JsROxExI8j4v+Rc/5ZiRUBAAAAAACABdbtdiMiYjAYxHA4jHa7Hf1+/yAHAADg3tZut2M0GkW9Xj/IptNptNvt8koBAAAAAADAPS7lnMvuwC2sr6/nS5culV0DAAAAAAAAAAAAgDuoKIro9XqRc45arRbT6TRSSrGxsWGYFQAAAAAAAD6HlNJrOef1496r3O0yAAAAAAAAAAAAAMBh3W43NjY2otVqxXg8jlarZXgHAAAAAAAA7rCUcy67A7ewvr6eL126VHYNAAAAAAAAAAAAAAAAAAAAAIClk1J6Lee8ftx7lbtdBgAAAAAAAAAAAABYTkVRRKfTiWazGZ1OJ4qiKLsSAAAAAAAAfGbGdwAAAAAAAAAAAABgASz6sE1RFNHr9WI0GkWj0YjRaBS9Xm/hegIAAAAAAMCnlXLOZXfgFtbX1/OlS5fKrgEAAAAAAAAAAADAHXR92CbnHLVaLabTaaSUYmNjI7rdbtn1IiKi0+nEaDSKer1+kE0mk2i1WrG5uVleMQAAAAAAALiFlNJrOef1496r3O0yAAAAAAAAAAAAAMBhg8Egcs5Rr9cjpRT1ej1yzjEYDMqudmA4HEatVjuU1Wq1GA6H5RQCAAAAAACAz8n4DgAAAAAAAAAAAACUbBmGbdrtdkyn00PZdDqNdrtdTiEAAAAAAAD4nIzvAAAAAAAAAAAAAEDJlmHYpt/vR0opJpNJ5J
xjMplESin6/X7Z1QAAAAAAAOAzMb4DAAAAAAAAAAAAACVbhmGbbrcbGxsb0Wq1YjweR6vVio2Njeh2u2VXAwAAAAA
|
||
|
"text/plain": [
|
||
|
"<Figure size 5760x2880 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"def _style_axis(ax, ylabel, label_size, tick_size, x_rotation=10):\n",
"    \"\"\"Label one time-series panel and restyle its major tick labels.\"\"\"\n",
"    ax.set_xlabel(\"Time\", fontsize=label_size)\n",
"    ax.set_ylabel(ylabel, rotation=0, fontsize=label_size)\n",
"    # `Tick.label` was deprecated and removed in Matplotlib 3.8;\n",
"    # `tick_params` is the supported way to style every major tick label.\n",
"    ax.tick_params(axis=\"x\", which=\"major\", labelsize=tick_size, labelrotation=x_rotation)\n",
"    ax.tick_params(axis=\"y\", which=\"major\", labelsize=tick_size)\n",
"\n",
"\n",
"def visualize_env(num_per_task=250):\n",
"    \"\"\"Scatter-plot the synthetic dynamic environment against time.\n",
"\n",
"    Top panel: column 0 of every sampled dataset (X) vs. its timestamp.\n",
"    Bottom panel: the ground-truth output function(X) at that timestamp (Y).\n",
"\n",
"    Args:\n",
"      num_per_task: forwarded to `create_example_v1` (default 250, the value\n",
"        that was previously hard-coded here).\n",
"    \"\"\"\n",
"    dpi, width, height = 10, 800, 400\n",
"    figsize = width / float(dpi), height / float(dpi)\n",
"    LabelSize, font_gap = 40, 5\n",
"\n",
"    fig = plt.figure(figsize=figsize)\n",
"\n",
"    dynamic_env, function = create_example_v1(num_per_task=num_per_task)\n",
"\n",
"    timeaxis, xaxis, yaxis = [], [], []\n",
"    for timestamp, dataset in dynamic_env:\n",
"        num = dataset.shape[0]\n",
"        timeaxis.append(torch.zeros(num) + timestamp)\n",
"        xaxis.append(dataset[:, 0])\n",
"        # Set the function to this timestamp before evaluating the ground truth.\n",
"        function.set_timestamp(timestamp)\n",
"        yaxis.append(function(dataset[:, 0]))\n",
"\n",
"    timeaxis = torch.cat(timeaxis).numpy()\n",
"    xaxis = torch.cat(xaxis).numpy()\n",
"    yaxis = torch.cat(yaxis).numpy()\n",
"\n",
"    # Two stacked panels (X on top, Y below) sharing identical styling.\n",
"    for index, (values, ylabel) in enumerate([(xaxis, \"X\"), (yaxis, \"Y\")], start=1):\n",
"        cur_ax = fig.add_subplot(2, 1, index)\n",
"        cur_ax.scatter(timeaxis, values, color=\"k\", alpha=0.9)\n",
"        _style_axis(cur_ax, ylabel, LabelSize, LabelSize - font_gap)\n",
"    plt.show()\n",
"\n",
"\n",
"visualize_env()"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.8"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|