Update visualization codes
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -133,3 +133,4 @@ outputs | |||||||
|  |  | ||||||
| pytest_cache | pytest_cache | ||||||
| *.pkl | *.pkl | ||||||
|  | *.pth | ||||||
|   | |||||||
| @@ -64,7 +64,7 @@ def extend_transformer_settings(alg2configs, name): | |||||||
|     config = copy.deepcopy(alg2configs[name]) |     config = copy.deepcopy(alg2configs[name]) | ||||||
|     for i in range(1, 9): |     for i in range(1, 9): | ||||||
|         for j in (6, 12, 24, 32, 48, 64): |         for j in (6, 12, 24, 32, 48, 64): | ||||||
|             for k1 in (0, 0.1, 0.2, 0.3): |             for k1 in (0, 0.05, 0.1, 0.2, 0.3): | ||||||
|                 for k2 in (0, 0.1): |                 for k2 in (0, 0.1): | ||||||
|                     alg2configs[ |                     alg2configs[ | ||||||
|                         name + "-{:}x{:}-drop{:}_{:}".format(i, j, k1, k2) |                         name + "-{:}x{:}-drop{:}_{:}".format(i, j, k1, k2) | ||||||
|   | |||||||
| @@ -22,6 +22,7 @@ from qlib.workflow import R | |||||||
|  |  | ||||||
| from utils.qlib_utils import QResult | from utils.qlib_utils import QResult | ||||||
|  |  | ||||||
|  |  | ||||||
| def compare_results( | def compare_results( | ||||||
|     heads, values, names, space=10, separate="& ", verbose=True, sort_key=False |     heads, values, names, space=10, separate="& ", verbose=True, sort_key=False | ||||||
| ): | ): | ||||||
| @@ -69,7 +70,10 @@ def query_info(save_dir, verbose, name_filter, key_map): | |||||||
|     for idx, (key, experiment) in enumerate(experiments.items()): |     for idx, (key, experiment) in enumerate(experiments.items()): | ||||||
|         if experiment.id == "0": |         if experiment.id == "0": | ||||||
|             continue |             continue | ||||||
|         if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None: |         if ( | ||||||
|  |             name_filter is not None | ||||||
|  |             and re.fullmatch(name_filter, experiment.name) is None | ||||||
|  |         ): | ||||||
|             continue |             continue | ||||||
|         recorders = experiment.list_recorders() |         recorders = experiment.list_recorders() | ||||||
|         recorders, not_finished = filter_finished(recorders) |         recorders, not_finished = filter_finished(recorders) | ||||||
|   | |||||||
| @@ -1,3 +1,4 @@ | |||||||
|  | import os | ||||||
| import numpy as np | import numpy as np | ||||||
| from typing import List, Text | from typing import List, Text | ||||||
| from collections import defaultdict, OrderedDict | from collections import defaultdict, OrderedDict | ||||||
| @@ -10,6 +11,7 @@ class QResult: | |||||||
|         self._result = defaultdict(list) |         self._result = defaultdict(list) | ||||||
|         self._name = name |         self._name = name | ||||||
|         self._recorder_paths = [] |         self._recorder_paths = [] | ||||||
|  |         self._date2ICs = [] | ||||||
|  |  | ||||||
|     def append(self, key, value): |     def append(self, key, value): | ||||||
|         self._result[key].append(value) |         self._result[key].append(value) | ||||||
| @@ -17,6 +19,25 @@ class QResult: | |||||||
|     def append_path(self, xpath): |     def append_path(self, xpath): | ||||||
|         self._recorder_paths.append(xpath) |         self._recorder_paths.append(xpath) | ||||||
|  |  | ||||||
|  |     def append_date2ICs(self, date2IC): | ||||||
|  |         if self._date2ICs:  # not empty | ||||||
|  |             keys = sorted(list(date2IC.keys())) | ||||||
|  |             pre_keys = sorted(list(self._date2ICs[0].keys())) | ||||||
|  |             assert len(keys) == len(pre_keys) | ||||||
|  |             for i, (x, y) in enumerate(zip(keys, pre_keys)): | ||||||
|  |                 assert x == y, "[{:}] {:} vs {:}".format(i, x, y) | ||||||
|  |         self._date2ICs.append(date2IC) | ||||||
|  |  | ||||||
|  |     def find_all_dates(self): | ||||||
|  |         dates = self._date2ICs[-1].keys() | ||||||
|  |         return sorted(list(dates)) | ||||||
|  |  | ||||||
|  |     def get_IC_by_date(self, date, scale=1.0): | ||||||
|  |         values = [] | ||||||
|  |         for date2IC in self._date2ICs: | ||||||
|  |             values.append(date2IC[date] * scale) | ||||||
|  |         return float(np.mean(values)), float(np.std(values)) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def name(self): |     def name(self): | ||||||
|         return self._name |         return self._name | ||||||
|   | |||||||
| @@ -18,10 +18,10 @@ | |||||||
|      "name": "stderr", |      "name": "stderr", | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|       "[68147:MainThread](2021-04-12 13:09:24,409) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", |       "[70148:MainThread](2021-04-12 13:23:30,262) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", | ||||||
|       "[68147:MainThread](2021-04-12 13:09:24,411) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", |       "[70148:MainThread](2021-04-12 13:23:30,266) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", | ||||||
|       "[68147:MainThread](2021-04-12 13:09:24,414) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", |       "[70148:MainThread](2021-04-12 13:23:30,269) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", | ||||||
|       "[68147:MainThread](2021-04-12 13:09:24,417) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" |       "[70148:MainThread](2021-04-12 13:23:30,271) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" | ||||||
|      ] |      ] | ||||||
|     } |     } | ||||||
|    ], |    ], | ||||||
| @@ -142,7 +142,7 @@ | |||||||
|      "name": "stderr", |      "name": "stderr", | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|       "[68147:MainThread](2021-04-12 13:09:25,066) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fd449277a30>\n" |       "[70148:MainThread](2021-04-12 13:23:31,137) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7f8c4a47efa0>\n" | ||||||
|      ] |      ] | ||||||
|     }, |     }, | ||||||
|     { |     { | ||||||
| @@ -233,6 +233,7 @@ | |||||||
|     "            cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n", |     "            cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n", | ||||||
|     "        cur_ax.set_xticks(raw_depths)\n", |     "        cur_ax.set_xticks(raw_depths)\n", | ||||||
|     "        cur_ax.set_yticks(raw_channels)\n", |     "        cur_ax.set_yticks(raw_channels)\n", | ||||||
|  |     "        cur_ax.set_zticks(np.arange(4, 11, 2))\n", | ||||||
|     "        cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n", |     "        cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n", | ||||||
|     "        cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n", |     "        cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n", | ||||||
|     "        cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n", |     "        cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n", | ||||||
|   | |||||||
| @@ -18,10 +18,10 @@ | |||||||
|      "name": "stderr", |      "name": "stderr", | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|       "[64660:MainThread](2021-04-11 23:57:38,079) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", |       "[70363:MainThread](2021-04-12 13:25:01,065) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", | ||||||
|       "[64660:MainThread](2021-04-11 23:57:38,081) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", |       "[70363:MainThread](2021-04-12 13:25:01,069) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", | ||||||
|       "[64660:MainThread](2021-04-11 23:57:38,083) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", |       "[70363:MainThread](2021-04-12 13:25:01,085) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", | ||||||
|       "[64660:MainThread](2021-04-11 23:57:38,084) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" |       "[70363:MainThread](2021-04-12 13:25:01,092) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" | ||||||
|      ] |      ] | ||||||
|     } |     } | ||||||
|    ], |    ], | ||||||
| @@ -142,7 +142,7 @@ | |||||||
|      "name": "stderr", |      "name": "stderr", | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|       "[64660:MainThread](2021-04-11 23:57:38,469) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fba2bc7df70>\n" |       "[70363:MainThread](2021-04-12 13:25:01,647) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fa920e56820>\n" | ||||||
|      ] |      ] | ||||||
|     }, |     }, | ||||||
|     { |     { | ||||||
| @@ -182,7 +182,7 @@ | |||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 27, |    "execution_count": 8, | ||||||
|    "id": "supreme-basis", |    "id": "supreme-basis", | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [], |    "outputs": [], | ||||||
| @@ -204,7 +204,7 @@ | |||||||
|     "    \n", |     "    \n", | ||||||
|     "    dpi, width, height = 200, 4000, 2000\n", |     "    dpi, width, height = 200, 4000, 2000\n", | ||||||
|     "    figsize = width / float(dpi), height / float(dpi)\n", |     "    figsize = width / float(dpi), height / float(dpi)\n", | ||||||
|     "    LabelSize, LegendFontsize = 22, 18\n", |     "    LabelSize, LegendFontsize = 22, 22\n", | ||||||
|     "    font_gap = 5\n", |     "    font_gap = 5\n", | ||||||
|     "    colors = ['k', 'r']\n", |     "    colors = ['k', 'r']\n", | ||||||
|     "    markers = ['*', 'o']\n", |     "    markers = ['*', 'o']\n", | ||||||
| @@ -227,6 +227,7 @@ | |||||||
|     "            cur_ax.scatter(x_values, y_values,\n", |     "            cur_ax.scatter(x_values, y_values,\n", | ||||||
|     "                           marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n", |     "                           marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n", | ||||||
|     "                           label=legend)\n", |     "                           label=legend)\n", | ||||||
|  |     "        cur_ax.set_yticks(np.arange(4, 11, 2))\n", | ||||||
|     "        cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n", |     "        cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n", | ||||||
|     "        cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n", |     "        cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n", | ||||||
|     "        for tick in cur_ax.xaxis.get_major_ticks():\n", |     "        for tick in cur_ax.xaxis.get_major_ticks():\n", | ||||||
| @@ -246,7 +247,7 @@ | |||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|    "cell_type": "code", |    "cell_type": "code", | ||||||
|    "execution_count": 28, |    "execution_count": 9, | ||||||
|    "id": "shared-envelope", |    "id": "shared-envelope", | ||||||
|    "metadata": {}, |    "metadata": {}, | ||||||
|    "outputs": [ |    "outputs": [ | ||||||
| @@ -254,7 +255,7 @@ | |||||||
|      "name": "stdout", |      "name": "stdout", | ||||||
|      "output_type": "stream", |      "output_type": "stream", | ||||||
|      "text": [ |      "text": [ | ||||||
|       "{'TSF-8x6', 'TSF-6x6', 'TSF-4x24', 'TSF-3x32', 'TSF-7x6', 'TSF-4x12', 'TSF-2x12', 'TSF-1x24', 'TSF-1x32', 'TSF-6x32', 'TSF-7x48', 'TSF-4x6', 'TSF-5x32', 'TSF-6x24', 'TSF-8x24', 'TSF-5x6', 'TSF-3x24', 'TSF-6x12', 'TSF-3x12', 'TSF-5x64', 'TSF-5x12', 'TSF-7x32', 'TSF-6x48', 'TSF-3x64', 'TSF-5x48', 'TSF-7x24', 'TSF-4x32', 'TSF-4x64', 'TSF-2x64', 'TSF-8x12', 'TSF-7x64', 'TSF-3x6', 'TSF-1x6', 'TSF-8x64', 'TSF-2x6', 'TSF-6x64', 'TSF-7x12', 'TSF-2x24', 'TSF-8x48', 'TSF-1x64', 'TSF-4x48', 'TSF-8x32', 'TSF-2x48', 'TSF-1x12', 'TSF-5x24', 'TSF-3x48', 'TSF-2x32', 'TSF-1x48'}\n", |       "{'TSF-3x48', 'TSF-2x64', 'TSF-2x12', 'TSF-8x48', 'TSF-6x32', 'TSF-4x48', 'TSF-8x6', 'TSF-4x6', 'TSF-2x32', 'TSF-5x12', 'TSF-5x64', 'TSF-1x64', 'TSF-2x24', 'TSF-8x24', 'TSF-4x12', 'TSF-6x12', 'TSF-1x32', 'TSF-5x32', 'TSF-3x24', 'TSF-8x12', 'TSF-5x48', 'TSF-6x64', 'TSF-7x64', 'TSF-7x48', 'TSF-1x6', 'TSF-2x48', 'TSF-7x24', 'TSF-3x32', 'TSF-1x24', 'TSF-4x64', 'TSF-3x12', 'TSF-8x64', 'TSF-4x32', 'TSF-5x6', 'TSF-7x6', 'TSF-7x12', 'TSF-3x6', 'TSF-4x24', 'TSF-6x48', 'TSF-6x6', 'TSF-1x48', 'TSF-1x12', 'TSF-7x32', 'TSF-5x24', 'TSF-2x6', 'TSF-6x24', 'TSF-3x64', 'TSF-8x32'}\n", | ||||||
|       "The Desktop is at: /Users/xuanyidong/Desktop\n", |       "The Desktop is at: /Users/xuanyidong/Desktop\n", | ||||||
|       "There are 104 qlib-results\n" |       "There are 104 qlib-results\n" | ||||||
|      ] |      ] | ||||||
|   | |||||||
							
								
								
									
										208
									
								
								notebooks/TOT/Time-Curve.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										208
									
								
								notebooks/TOT/Time-Curve.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,208 @@ | |||||||
|  | { | ||||||
|  |  "cells": [ | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 1, | ||||||
|  |    "id": "afraid-minutes", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n", | ||||||
|  |       "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "import os\n", | ||||||
|  |     "import re\n", | ||||||
|  |     "import sys\n", | ||||||
|  |     "import torch\n", | ||||||
|  |     "import pprint\n", | ||||||
|  |     "import numpy as np\n", | ||||||
|  |     "import pandas as pd\n", | ||||||
|  |     "from pathlib import Path\n", | ||||||
|  |     "from scipy.interpolate import make_interp_spline\n", | ||||||
|  |     "\n", | ||||||
|  |     "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", | ||||||
|  |     "root_dir = (Path(__file__).parent / \"..\").resolve()\n", | ||||||
|  |     "lib_dir = (root_dir / \"lib\").resolve()\n", | ||||||
|  |     "print(\"The root path: {:}\".format(root_dir))\n", | ||||||
|  |     "print(\"The library path: {:}\".format(lib_dir))\n", | ||||||
|  |     "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", | ||||||
|  |     "if str(lib_dir) not in sys.path:\n", | ||||||
|  |     "    sys.path.insert(0, str(lib_dir))\n", | ||||||
|  |     "from utils.qlib_utils import QResult" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 2, | ||||||
|  |    "id": "continental-drain", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "TSF-2x24-drop0_0s2013-01-01\n", | ||||||
|  |       "TSF-2x24-drop0_0s2012-01-01\n", | ||||||
|  |       "TSF-2x24-drop0_0s2008-01-01\n", | ||||||
|  |       "TSF-2x24-drop0_0s2009-01-01\n", | ||||||
|  |       "TSF-2x24-drop0_0s2010-01-01\n", | ||||||
|  |       "TSF-2x24-drop0_0s2011-01-01\n", | ||||||
|  |       "TSF-2x24-drop0_0s2008-07-01\n", | ||||||
|  |       "TSF-2x24-drop0_0s2009-07-01\n", | ||||||
|  |       "There are 3011 dates\n", | ||||||
|  |       "Dates: 2008-01-02 2008-01-03\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "qresults = torch.load(os.path.join(root_dir, 'notebooks', 'TOT', 'temp-time-x.pth'))\n", | ||||||
|  |     "for qresult in qresults:\n", | ||||||
|  |     "    print(qresult.name)\n", | ||||||
|  |     "all_dates = set()\n", | ||||||
|  |     "for qresult in qresults:\n", | ||||||
|  |     "    dates = qresult.find_all_dates()\n", | ||||||
|  |     "    for date in dates:\n", | ||||||
|  |     "        all_dates.add(date)\n", | ||||||
|  |     "all_dates = sorted(list(all_dates))\n", | ||||||
|  |     "print('There are {:} dates'.format(len(all_dates)))\n", | ||||||
|  |     "print('Dates: {:} {:}'.format(all_dates[0], all_dates[1]))" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 3, | ||||||
|  |    "id": "intimate-approval", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "import matplotlib\n", | ||||||
|  |     "from matplotlib import cm\n", | ||||||
|  |     "matplotlib.use(\"agg\")\n", | ||||||
|  |     "import matplotlib.pyplot as plt\n", | ||||||
|  |     "import matplotlib.ticker as ticker" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 6, | ||||||
|  |    "id": "supreme-basis", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "def vis_time_curve(qresults, dates, use_original, save_path):\n", | ||||||
|  |     "    save_dir = (save_path / '..').resolve()\n", | ||||||
|  |     "    save_dir.mkdir(parents=True, exist_ok=True)\n", | ||||||
|  |     "    print('There are {:} qlib-results'.format(len(qresults)))\n", | ||||||
|  |     "    \n", | ||||||
|  |     "    dpi, width, height = 200, 5000, 2000\n", | ||||||
|  |     "    figsize = width / float(dpi), height / float(dpi)\n", | ||||||
|  |     "    LabelSize, LegendFontsize = 22, 12\n", | ||||||
|  |     "    font_gap = 5\n", | ||||||
|  |     "    linestyles = ['-', '--']\n", | ||||||
|  |     "    colors = ['k', 'r']\n", | ||||||
|  |     "    \n", | ||||||
|  |     "    fig = plt.figure(figsize=figsize)\n", | ||||||
|  |     "    cur_ax = fig.add_subplot(1, 1, 1)\n", | ||||||
|  |     "    for idx, qresult in enumerate(qresults):\n", | ||||||
|  |     "        print('Visualize [{:}] -- {:}'.format(idx, qresult.name))\n", | ||||||
|  |     "        x_axis, y_axis = [], []\n", | ||||||
|  |     "        for idate, date in enumerate(dates):\n", | ||||||
|  |     "            if date in qresult._date2ICs[-1]:\n", | ||||||
|  |     "                mean, std = qresult.get_IC_by_date(date, 100)\n", | ||||||
|  |     "                if not np.isnan(mean):\n", | ||||||
|  |     "                    x_axis.append(idate)\n", | ||||||
|  |     "                    y_axis.append(mean)\n", | ||||||
|  |     "        x_axis, y_axis = np.array(x_axis), np.array(y_axis)\n", | ||||||
|  |     "        if use_original:\n", | ||||||
|  |     "            cur_ax.plot(x_axis, y_axis, linewidth=1, color=colors[idx], linestyle=linestyles[idx])\n", | ||||||
|  |     "        else:\n", | ||||||
|  |     "            xnew = np.linspace(x_axis.min(), x_axis.max(), 200)\n", | ||||||
|  |     "            spl = make_interp_spline(x_axis, y_axis, k=5)\n", | ||||||
|  |     "            ynew = spl(xnew)\n", | ||||||
|  |     "            cur_ax.plot(xnew, ynew, linewidth=2, color=colors[idx], linestyle=linestyles[idx])\n", | ||||||
|  |     "        \n", | ||||||
|  |     "    for tick in cur_ax.xaxis.get_major_ticks():\n", | ||||||
|  |     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||||
|  |     "    for tick in cur_ax.yaxis.get_major_ticks():\n", | ||||||
|  |     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||||
|  |     "    cur_ax.set_ylabel(\"IC (%)\", fontsize=LabelSize)\n", | ||||||
|  |     "    fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", | ||||||
|  |     "    plt.close(\"all\")" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 7, | ||||||
|  |    "id": "shared-envelope", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "The Desktop is at: /Users/xuanyidong/Desktop\n", | ||||||
|  |       "There are 2 qlib-results\n", | ||||||
|  |       "Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n", | ||||||
|  |       "Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n", | ||||||
|  |       "There are 2 qlib-results\n", | ||||||
|  |       "Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n", | ||||||
|  |       "Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "# Visualization\n", | ||||||
|  |     "home_dir = Path.home()\n", | ||||||
|  |     "desktop_dir = home_dir / 'Desktop'\n", | ||||||
|  |     "print('The Desktop is at: {:}'.format(desktop_dir))\n", | ||||||
|  |     "\n", | ||||||
|  |     "vis_time_curve(\n", | ||||||
|  |     "    (qresults[2], qresults[-1]),\n", | ||||||
|  |     "    all_dates,\n", | ||||||
|  |     "    True,\n", | ||||||
|  |     "    desktop_dir / 'es_csi300_time_curve.pdf')\n", | ||||||
|  |     "\n", | ||||||
|  |     "vis_time_curve(\n", | ||||||
|  |     "    (qresults[2], qresults[-1]),\n", | ||||||
|  |     "    all_dates,\n", | ||||||
|  |     "    False,\n", | ||||||
|  |     "    desktop_dir / 'es_csi300_time_curve-inter.pdf')" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "id": "exempt-stable", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [] | ||||||
|  |   } | ||||||
|  |  ], | ||||||
|  |  "metadata": { | ||||||
|  |   "kernelspec": { | ||||||
|  |    "display_name": "Python 3", | ||||||
|  |    "language": "python", | ||||||
|  |    "name": "python3" | ||||||
|  |   }, | ||||||
|  |   "language_info": { | ||||||
|  |    "codemirror_mode": { | ||||||
|  |     "name": "ipython", | ||||||
|  |     "version": 3 | ||||||
|  |    }, | ||||||
|  |    "file_extension": ".py", | ||||||
|  |    "mimetype": "text/x-python", | ||||||
|  |    "name": "python", | ||||||
|  |    "nbconvert_exporter": "python", | ||||||
|  |    "pygments_lexer": "ipython3", | ||||||
|  |    "version": "3.8.8" | ||||||
|  |   } | ||||||
|  |  }, | ||||||
|  |  "nbformat": 4, | ||||||
|  |  "nbformat_minor": 5 | ||||||
|  | } | ||||||
							
								
								
									
										128
									
								
								notebooks/TOT/synthetic.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								notebooks/TOT/synthetic.ipynb
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,128 @@ | |||||||
|  | { | ||||||
|  |  "cells": [ | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 1, | ||||||
|  |    "id": "filled-multiple", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "#\n", | ||||||
|  |     "# %matplotlib notebook\n", | ||||||
|  |     "from pathlib import Path\n", | ||||||
|  |     "import numpy as np\n", | ||||||
|  |     "import matplotlib\n", | ||||||
|  |     "from matplotlib import cm\n", | ||||||
|  |     "matplotlib.use(\"agg\")\n", | ||||||
|  |     "import matplotlib.pyplot as plt\n", | ||||||
|  |     "import matplotlib.ticker as ticker" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 2, | ||||||
|  |    "id": "supreme-basis", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "def visualize_syn(save_path):\n", | ||||||
|  |     "    save_dir = (save_path / '..').resolve()\n", | ||||||
|  |     "    save_dir.mkdir(parents=True, exist_ok=True)\n", | ||||||
|  |     "    \n", | ||||||
|  |     "    dpi, width, height = 50, 2000, 1000\n", | ||||||
|  |     "    figsize = width / float(dpi), height / float(dpi)\n", | ||||||
|  |     "    LabelSize, font_gap = 30, 4\n", | ||||||
|  |     "    \n", | ||||||
|  |     "    fig = plt.figure(figsize=figsize)\n", | ||||||
|  |     "    \n", | ||||||
|  |     "    times = np.arange(0, np.pi * 100, 0.1)\n", | ||||||
|  |     "    num = len(times)\n", | ||||||
|  |     "    x = []\n", | ||||||
|  |     "    for i in range(num):\n", | ||||||
|  |     "        scale = (i + 1.) / num * 4\n", | ||||||
|  |     "        value = times[i] * scale\n", | ||||||
|  |     "        x.append(np.sin(value) * (1.3 - scale))\n", | ||||||
|  |     "    x = np.array(x)\n", | ||||||
|  |     "    y = np.cos( x * x - 0.3 * x )\n", | ||||||
|  |     "    \n", | ||||||
|  |     "    cur_ax = fig.add_subplot(2, 1, 1)\n", | ||||||
|  |     "    cur_ax.plot(times, x)\n", | ||||||
|  |     "    cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n", | ||||||
|  |     "    cur_ax.set_ylabel(\"x\", fontsize=LabelSize)\n", | ||||||
|  |     "    for tick in cur_ax.xaxis.get_major_ticks():\n", | ||||||
|  |     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||||
|  |     "        tick.label.set_rotation(30)\n", | ||||||
|  |     "    for tick in cur_ax.yaxis.get_major_ticks():\n", | ||||||
|  |     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||||
|  |     "        \n", | ||||||
|  |     "    \n", | ||||||
|  |     "    cur_ax = fig.add_subplot(2, 1, 2)\n", | ||||||
|  |     "    cur_ax.plot(times, y)\n", | ||||||
|  |     "    cur_ax.set_xlabel(\"time\", fontsize=LabelSize)\n", | ||||||
|  |     "    cur_ax.set_ylabel(\"f(x)\", fontsize=LabelSize)\n", | ||||||
|  |     "    for tick in cur_ax.xaxis.get_major_ticks():\n", | ||||||
|  |     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||||
|  |     "        tick.label.set_rotation(30)\n", | ||||||
|  |     "    for tick in cur_ax.yaxis.get_major_ticks():\n", | ||||||
|  |     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||||
|  |     "        \n", | ||||||
|  |     "    # fig.tight_layout()\n", | ||||||
|  |     "    # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n", | ||||||
|  |     "    fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", | ||||||
|  |     "    plt.close(\"all\")\n", | ||||||
|  |     "    # plt.show()" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 3, | ||||||
|  |    "id": "shared-envelope", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "The Desktop is at: /Users/xuanyidong/Desktop\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "# Visualization\n", | ||||||
|  |     "home_dir = Path.home()\n", | ||||||
|  |     "desktop_dir = home_dir / 'Desktop'\n", | ||||||
|  |     "print('The Desktop is at: {:}'.format(desktop_dir))\n", | ||||||
|  |     "visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "id": "romantic-ordinance", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [] | ||||||
|  |   } | ||||||
|  |  ], | ||||||
|  |  "metadata": { | ||||||
|  |   "kernelspec": { | ||||||
|  |    "display_name": "Python 3", | ||||||
|  |    "language": "python", | ||||||
|  |    "name": "python3" | ||||||
|  |   }, | ||||||
|  |   "language_info": { | ||||||
|  |    "codemirror_mode": { | ||||||
|  |     "name": "ipython", | ||||||
|  |     "version": 3 | ||||||
|  |    }, | ||||||
|  |    "file_extension": ".py", | ||||||
|  |    "mimetype": "text/x-python", | ||||||
|  |    "name": "python", | ||||||
|  |    "nbconvert_exporter": "python", | ||||||
|  |    "pygments_lexer": "ipython3", | ||||||
|  |    "version": "3.8.8" | ||||||
|  |   } | ||||||
|  |  }, | ||||||
|  |  "nbformat": 4, | ||||||
|  |  "nbformat_minor": 5 | ||||||
|  | } | ||||||
							
								
								
									
										123
									
								
								notebooks/TOT/time-curve.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								notebooks/TOT/time-curve.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,123 @@ | |||||||
|  | import os | ||||||
|  | import re | ||||||
|  | import sys | ||||||
|  | import torch | ||||||
|  | import qlib | ||||||
|  | import pprint | ||||||
|  | from collections import OrderedDict | ||||||
|  | import numpy as np | ||||||
|  | import pandas as pd | ||||||
|  |  | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | # __file__ = os.path.dirname(os.path.realpath("__file__")) | ||||||
|  | note_dir = Path(__file__).parent.resolve() | ||||||
|  | root_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||||
|  | lib_dir = (root_dir / "lib").resolve() | ||||||
|  | print("The root path: {:}".format(root_dir)) | ||||||
|  | print("The library path: {:}".format(lib_dir)) | ||||||
|  | assert lib_dir.exists(), "{:} does not exist".format(lib_dir) | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | import qlib | ||||||
|  | from qlib import config as qconfig | ||||||
|  | from qlib.workflow import R | ||||||
|  | qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN) | ||||||
|  |  | ||||||
|  | from utils.qlib_utils import QResult | ||||||
|  |  | ||||||
|  | def filter_finished(recorders): | ||||||
|  |     returned_recorders = dict() | ||||||
|  |     not_finished = 0 | ||||||
|  |     for key, recorder in recorders.items(): | ||||||
|  |         if recorder.status == "FINISHED": | ||||||
|  |             returned_recorders[key] = recorder | ||||||
|  |         else: | ||||||
|  |             not_finished += 1 | ||||||
|  |     return returned_recorders, not_finished | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def add_to_dict(xdict, timestamp, value): | ||||||
|  |     date = timestamp.date().strftime("%Y-%m-%d") | ||||||
|  |     if date in xdict: | ||||||
|  |       raise ValueError("This date [{:}] is already in the dict".format(date)) | ||||||
|  |     xdict[date] = value | ||||||
|  |  | ||||||
|  | def query_info(save_dir, verbose, name_filter, key_map): | ||||||
|  |     if isinstance(save_dir, list): | ||||||
|  |         results = [] | ||||||
|  |         for x in save_dir: | ||||||
|  |             x = query_info(x, verbose, name_filter, key_map) | ||||||
|  |             results.extend(x) | ||||||
|  |         return results | ||||||
|  |     # Here, the save_dir must be a string | ||||||
|  |     R.set_uri(str(save_dir)) | ||||||
|  |     experiments = R.list_experiments() | ||||||
|  |  | ||||||
|  |     if verbose: | ||||||
|  |         print("There are {:} experiments.".format(len(experiments))) | ||||||
|  |     qresults = [] | ||||||
|  |     for idx, (key, experiment) in enumerate(experiments.items()): | ||||||
|  |         if experiment.id == "0": | ||||||
|  |             continue | ||||||
|  |         if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None: | ||||||
|  |             continue | ||||||
|  |         recorders = experiment.list_recorders() | ||||||
|  |         recorders, not_finished = filter_finished(recorders) | ||||||
|  |         if verbose: | ||||||
|  |             print( | ||||||
|  |                 "====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format( | ||||||
|  |                     idx + 1, | ||||||
|  |                     len(experiments), | ||||||
|  |                     experiment.name, | ||||||
|  |                     len(recorders), | ||||||
|  |                     len(recorders) + not_finished, | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         result = QResult(experiment.name) | ||||||
|  |         for recorder_id, recorder in recorders.items(): | ||||||
|  |             file_names = ['results-train.pkl', 'results-valid.pkl', 'results-test.pkl'] | ||||||
|  |             date2IC = OrderedDict() | ||||||
|  |             for file_name in file_names: | ||||||
|  |                 xtemp = recorder.load_object(file_name)['all-IC'] | ||||||
|  |                 timestamps, values = xtemp.index.tolist(), xtemp.tolist() | ||||||
|  |                 for timestamp, value in zip(timestamps, values): | ||||||
|  |                     add_to_dict(date2IC, timestamp, value) | ||||||
|  |             result.update(recorder.list_metrics(), key_map) | ||||||
|  |             result.append_path( | ||||||
|  |                 os.path.join(recorder.uri, recorder.experiment_id, recorder.id) | ||||||
|  |             ) | ||||||
|  |             result.append_date2ICs(date2IC) | ||||||
|  |         if not len(result): | ||||||
|  |             print("There are no valid recorders for {:}".format(experiment)) | ||||||
|  |             continue | ||||||
|  |         else: | ||||||
|  |             if verbose: | ||||||
|  |                 print( | ||||||
|  |                     "There are {:} valid recorders for {:}".format( | ||||||
|  |                         len(recorders), experiment.name | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|  |         qresults.append(result) | ||||||
|  |     return qresults | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ## | ||||||
|  | paths = [root_dir / 'outputs' / 'qlib-baselines-csi300'] | ||||||
|  | paths = [path.resolve() for path in paths] | ||||||
|  | print(paths) | ||||||
|  |  | ||||||
|  | key_map = dict() | ||||||
|  | for xset in ("train", "valid", "test"): | ||||||
|  |     key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset) | ||||||
|  |     key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset) | ||||||
|  | qresults = query_info(paths, False, 'TSF-2x24-drop0_0s.*-.*-01', key_map) | ||||||
|  | print('Find {:} results'.format(len(qresults))) | ||||||
|  | times = [] | ||||||
|  | for qresult in qresults: | ||||||
|  |     times.append(qresult.name.split('0_0s')[-1]) | ||||||
|  | print(times) | ||||||
|  | save_path = os.path.join(note_dir, 'temp-time-x.pth') | ||||||
|  | torch.save(qresults, save_path) | ||||||
|  | print(save_path) | ||||||
		Reference in New Issue
	
	Block a user