{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "filled-multiple",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
      "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "import torch\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "from matplotlib import cm\n",
    "matplotlib.use(\"agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "\n",
    "# `__file__` is undefined inside a notebook, so paths are resolved from the\n",
    "# current working directory (normally the notebook's own directory).\n",
    "# The repository root is taken to be two levels above that directory.\n",
    "notebook_dir = Path(os.getcwd()).resolve()\n",
    "root_dir = (notebook_dir / \"..\" / \"..\").resolve()\n",
    "lib_dir = (root_dir / \"lib\").resolve()\n",
    "print(\"The root path: {:}\".format(root_dir))\n",
    "print(\"The library path: {:}\".format(lib_dir))\n",
    "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
    "if str(lib_dir) not in sys.path:\n",
    "    sys.path.insert(0, str(lib_dir))\n",
    "\n",
    "from datasets import SynAdaptiveEnv\n",
    "from xlayers.super_core import SuperMLPv1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "supreme-basis",
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimize_fn(xs, ys, test_sets, num_iters=100):\n",
    "    \"\"\"Fit a `SuperMLPv1` regressor on (xs -> ys), then predict each test set.\n",
    "\n",
    "    Args:\n",
    "        xs: training inputs; flattened to a float tensor of shape (N, 1).\n",
    "        ys: training targets aligned with `xs`; flattened to shape (N, 1).\n",
    "        test_sets: list of input sequences to run the trained model on.\n",
    "        num_iters: number of full-batch Adam updates (default 100).\n",
    "\n",
    "    Returns:\n",
    "        A list holding one list of float predictions per entry of `test_sets`.\n",
    "    \"\"\"\n",
    "    xs = torch.FloatTensor(xs).view(-1, 1)\n",
    "    ys = torch.FloatTensor(ys).view(-1, 1)\n",
    "\n",
    "    model = SuperMLPv1(1, 10, 1, torch.nn.ReLU)\n",
    "    optimizer = torch.optim.Adam(\n",
    "        model.parameters(),\n",
    "        lr=0.01, weight_decay=1e-4, amsgrad=True\n",
    "    )\n",
    "    for _iter in range(num_iters):\n",
    "        # Regress the targets from the inputs (x -> y); feeding `ys` here\n",
    "        # would only teach the model the identity map on the targets.\n",
    "        preds = model(xs)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = torch.nn.functional.mse_loss(preds, ys)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        answers = []\n",
    "        for test_set in test_sets:\n",
    "            test_set = torch.FloatTensor(test_set).view(-1, 1)\n",
    "            preds = model(test_set).view(-1).numpy()\n",
    "            answers.append(preds.tolist())\n",
    "    return answers\n",
    "\n",
    "\n",
    "def f(x):\n",
    "    \"\"\"Target transform applied to the synthetic inputs: y = cos(0.5 * x).\"\"\"\n",
    "    return np.cos(0.5 * x + 0.)\n",
    "\n",
    "\n",
    "def get_data(mode):\n",
    "    \"\"\"Return (times, xs, ys) lists for one split of `SynAdaptiveEnv`.\n",
    "\n",
    "    The environment is iterated twice: first untransformed to collect the\n",
    "    timestamps and inputs, then after `set_transform(f)` to collect the\n",
    "    targets. `mode=None` is used below for the whole timeline (presumably\n",
    "    all splits together -- confirm against `SynAdaptiveEnv`).\n",
    "    \"\"\"\n",
    "    dataset = SynAdaptiveEnv(mode=mode)\n",
    "    times, xs, ys = [], [], []\n",
    "    for _, t, x in dataset:\n",
    "        times.append(t)\n",
    "        xs.append(x)\n",
    "    dataset.set_transform(f)\n",
    "    for _, _, y in dataset:\n",
    "        ys.append(y)\n",
    "    return times, xs, ys\n",
    "\n",
    "\n",
    "def visualize_syn(save_path):\n",
    "    \"\"\"Plot the synthetic data with MLP fits and save the figure as a PDF.\n",
    "\n",
    "    Top panel: x over time for the whole environment. Bottom panel: the\n",
    "    ground-truth y (faint black), the train/valid/test segments of y\n",
    "    (red/green/blue) and the MLP predictions for each segment (dashed).\n",
    "\n",
    "    Args:\n",
    "        save_path: output file (Path or str); its directory is created.\n",
    "    \"\"\"\n",
    "    save_path = Path(save_path)\n",
    "    save_path.resolve().parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    dpi, width, height = 40, 2000, 900\n",
    "    figsize = width / float(dpi), height / float(dpi)\n",
    "    LabelSize, LegendFontsize, font_gap = 40, 40, 5\n",
    "\n",
    "    fig = plt.figure(figsize=figsize)\n",
    "\n",
    "    times, xs, ys = get_data(None)\n",
    "\n",
    "    def draw_ax(cur_ax, xaxis, yaxis, xlabel, ylabel,\n",
    "                alpha=0.1, color='k', linestyle='-', legend=None, plot_only=False):\n",
    "        # The legend entry comes from a separate opaque one-point plot so it\n",
    "        # stays visible even when the full curve is drawn with a low alpha.\n",
    "        if legend is not None:\n",
    "            cur_ax.plot(xaxis[:1], yaxis[:1], color=color, linestyle=linestyle, label=legend)\n",
    "        cur_ax.plot(xaxis, yaxis, color=color, linestyle=linestyle, alpha=alpha, label=None)\n",
    "        if not plot_only:\n",
    "            cur_ax.set_xlabel(xlabel, fontsize=LabelSize)\n",
    "            cur_ax.set_ylabel(ylabel, rotation=0, fontsize=LabelSize)\n",
    "            # `Tick.label` is deprecated/removed in newer Matplotlib; tick_params\n",
    "            # also styles ticks that are created later as the limits change.\n",
    "            cur_ax.tick_params(axis=\"x\", labelsize=LabelSize - font_gap, labelrotation=10)\n",
    "            cur_ax.tick_params(axis=\"y\", labelsize=LabelSize - font_gap)\n",
    "\n",
    "    cur_ax = fig.add_subplot(2, 1, 1)\n",
    "    draw_ax(cur_ax, times, xs, \"time\", \"x\", alpha=1.0, legend=None)\n",
    "\n",
    "    cur_ax = fig.add_subplot(2, 1, 2)\n",
    "    draw_ax(cur_ax, times, ys, \"time\", \"y\", alpha=0.1, legend=\"ground truth\")\n",
    "\n",
    "    # Highlight the ground-truth y of each split in that split's color.\n",
    "    split_colors = {\"train\": \"r\", \"valid\": \"g\", \"test\": \"b\"}\n",
    "    splits = {mode: get_data(mode) for mode in split_colors}\n",
    "    for mode, (split_times, _, split_ys) in splits.items():\n",
    "        draw_ax(cur_ax, split_times, split_ys, None, None,\n",
    "                alpha=1.0, color=split_colors[mode], legend=None, plot_only=True)\n",
    "\n",
    "    # Fit the MLP on the training split and draw its (dashed) predictions for\n",
    "    # every split; only the training curve contributes the \"MLP\" legend entry.\n",
    "    _, train_xs, train_ys = splits[\"train\"]\n",
    "    all_preds = optimize_fn(train_xs, train_ys, [split[1] for split in splits.values()])\n",
    "    for (mode, (split_times, _, _)), split_preds in zip(splits.items(), all_preds):\n",
    "        draw_ax(cur_ax, split_times, split_preds, None, None,\n",
    "                alpha=1.0, linestyle='--', color=split_colors[mode],\n",
    "                legend=\"MLP\" if mode == \"train\" else None, plot_only=True)\n",
    "\n",
    "    cur_ax.legend(loc=1, fontsize=LegendFontsize)\n",
    "\n",
    "    fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "shared-envelope",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The Desktop is at: /Users/xuanyidong/Desktop\n"
     ]
    }
   ],
   "source": [
    "# Visualization\n",
    "home_dir = Path.home()\n",
    "desktop_dir = home_dir / 'Desktop'\n",
    "print('The Desktop is at: {:}'.format(desktop_dir))\n",
    "visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}