add autodl

AutoDL-Projects/notebooks/NATS-Bench/BayesOpt.ipynb (new file, 118 lines)
@@ -0,0 +1,118 @@
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 8, | ||||
|    "id": "german-madonna", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "# Implementation for \"A Tutorial on Bayesian Optimization\"\n", | ||||
|     "import numpy as np\n", | ||||
|     "\n", | ||||
|     "def get_data():\n", | ||||
|     "    return np.random.random(2) * 10\n", | ||||
|     "\n", | ||||
|     "def f(x):\n", | ||||
|     "    return float(np.power((x[0] * 3 - x[1]), 3) - np.exp(x[1]) + np.power(x[0], 2))" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 12, | ||||
|    "id": "broke-citizenship", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "# Kernels typically have the property that points closer in the input space are more strongly correlated\n", | ||||
|     "# i.e., if |x1 - x2| < |x1 - x3|, then sigma(x1, x2) > sigma(x1, x3).\n", | ||||
|     "# the commonly used and simple kernel is the power exponential or Gaussian kernel:\n", | ||||
|     "def sigma0(x1, x2, alpha0=1, alpha=[1,1]):\n", | ||||
|     "    \"\"\"alpha could be a vector\"\"\"\n", | ||||
|     "    power = np.array(alpha, dtype=np.float32) * np.power(np.array(x1)-np.array(x2), 2)\n", | ||||
|     "    return alpha0 * np.exp( -np.sum(power) )\n", | ||||
|     "\n", | ||||
|     "# the most common choice for the mean function is a constant value\n", | ||||
|     "def mu0(x, mu):\n", | ||||
|     "    return mu" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 13, | ||||
|    "id": "aerial-carnival", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "K = 5\n", | ||||
|     "X = np.array([get_data() for i in range(K)])\n", | ||||
|     "mu = np.mean(X, axis=0)\n", | ||||
|     "mu0_over_K = [mu0(x, mu) for x in X]" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 14, | ||||
|    "id": "polished-discussion", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "sigma0_over_KK = []\n", | ||||
|     "for i in range(K):\n", | ||||
|     "    sigma0_over_KK.append(np.array([sigma0(X[i], X[j]) for j in range(K)]))\n", | ||||
|     "sigma0_over_KK = np.array(sigma0_over_KK)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 16, | ||||
|    "id": "comic-jesus", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "(20, 20)\n", | ||||
|       "1.1038803861344952e-06\n", | ||||
|       "1.1038803861344952e-06\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "print(sigma0_over_KK.shape)\n", | ||||
|     "print(sigma0_over_KK[1][2])\n", | ||||
|     "print(sigma0_over_KK[2][1])" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "statistical-wrist", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.8" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 5 | ||||
| } | ||||
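The BayesOpt notebook above stops after building the prior covariance matrix over the K sampled points. As a sketch of the natural next step from the tutorial it cites (not part of the commit), the Gaussian-process posterior at a new query point can be computed from the same sigma0 / mu0 / f / X definitions; the scalar prior mean over observed values is an assumption here, since the notebook's mu is the per-dimension mean of the inputs.

# Sketch only: GP posterior mean/variance at a query point, reusing
# sigma0 / mu0 / f / X / get_data from the cells above.
import numpy as np

y = np.array([f(x) for x in X])          # observed objective values at the K samples
mu_y = float(np.mean(y))                 # scalar prior mean (assumption; the notebook's
                                         # mu is the per-dimension mean of the inputs X)

def gp_posterior(x_new, X, y, mu_y, noise=1e-8):
    """Posterior mean and variance of f at x_new under the GP prior defined above."""
    K_mat = np.array([[sigma0(xi, xj) for xj in X] for xi in X])
    K_mat += noise * np.eye(len(X))      # small jitter for numerical stability
    k_star = np.array([sigma0(x_new, xi) for xi in X])
    K_inv = np.linalg.inv(K_mat)
    mean = mu0(x_new, mu_y) + k_star @ K_inv @ (y - mu_y)
    var = sigma0(x_new, x_new) - k_star @ K_inv @ k_star
    return float(mean), max(float(var), 0.0)

print(gp_posterior(get_data(), X, y, mu_y))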
							
								
								
									
AutoDL-Projects/notebooks/NATS-Bench/find-largest.ipynb (new file, 88 lines)
@@ -0,0 +1,88 @@
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[2021-03-27 06:46:38] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "from nats_bench import create\n", | ||||
|     "from pprint import pprint\n", | ||||
|     "# Create the API for tologoy search space\n", | ||||
|     "api = create(None, 'tss', fast_mode=True, verbose=False)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "{'test-accuracy': 22.39999992879232,\n", | ||||
|       " 'test-all-time': 7.7054752962929856,\n", | ||||
|       " 'test-loss': 3.1626377182006835,\n", | ||||
|       " 'test-per-time': 0.6421229413577488,\n", | ||||
|       " 'train-accuracy': 21.68885959195242,\n", | ||||
|       " 'train-all-time': 1260.0195466594694,\n", | ||||
|       " 'train-loss': 3.1863493608815463,\n", | ||||
|       " 'train-per-time': 105.00162888828912,\n", | ||||
|       " 'valid-accuracy': 23.266666631062826,\n", | ||||
|       " 'valid-all-time': 7.7054752962929856,\n", | ||||
|       " 'valid-loss': 3.1219845104217527,\n", | ||||
|       " 'valid-per-time': 0.6421229413577488,\n", | ||||
|       " 'valtest-accuracy': 22.833333323160808,\n", | ||||
|       " 'valtest-all-time': 15.410950592585971,\n", | ||||
|       " 'valtest-loss': 3.142311067581177,\n", | ||||
|       " 'valtest-per-time': 1.2842458827154977}\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "largest_candidate_tss = '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'\n", | ||||
|     "\n", | ||||
|     "arch_index = api.query_index_by_arch(largest_candidate_tss)\n", | ||||
|     "info = api.get_more_info(arch_index, 'ImageNet16-120', hp='12', is_random=False)\n", | ||||
|     "pprint(info)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.8" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 4 | ||||
| } | ||||
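The architecture string above is hard-coded; as a small illustrative sketch (not part of the commit), the same all-3x3-convolution candidate can be assembled from the NATS-Bench topology cell format, which makes it easy to vary the operation.

# Sketch: rebuild the largest topology-search-space candidate from its cell format
# |op~0|+|op~0|op~1|+|op~0|op~1|op~2| (4-node cell, every edge set to the same op).
op = 'nor_conv_3x3'
largest = '|{op}~0|+|{op}~0|{op}~1|+|{op}~0|{op}~1|{op}~2|'.format(op=op)
assert largest == largest_candidate_tss   # matches the hard-coded string above
print(api.query_index_by_arch(largest))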
							
								
								
									
AutoDL-Projects/notebooks/NATS-Bench/issue-96.ipynb (new file, 91 lines)
@@ -0,0 +1,91 @@
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[2021-03-01 12:28:12] Try to use the default NATS-Bench (topology) path from fast_mode=True and path=None.\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "from nats_bench import create\n", | ||||
|     "import numpy as np\n", | ||||
|     "\n", | ||||
|     "def get_correlation(A, B):\n", | ||||
|     "    return float(np.corrcoef(A, B)[0,1])\n", | ||||
|     "\n", | ||||
|     "# Create the API for tologoy search space\n", | ||||
|     "api = create(None, 'tss', fast_mode=True, verbose=False)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "There are 15625 architectures on the topology search space\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "print('There are {:} architectures on the topology search space'.format(len(api)))\n", | ||||
|     "accuracies_12, accuracies_200 = [], []\n", | ||||
|     "for i, arch in enumerate(api):\n", | ||||
|     "    info_a = api.get_more_info(i, dataset='cifar10-valid', hp='12', is_random=False)\n", | ||||
|     "    accuracies_12.append(info_a['valid-accuracy'])\n", | ||||
|     "\n", | ||||
|     "    info_b = api.get_more_info(i, dataset='cifar10-valid', hp='200', is_random=False)\n", | ||||
|     "    accuracies_200.append(info_b['test-accuracy'])" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[CIFAR-10] The correlation between 12-epoch validation accuracy and 200-epoch test accuracy is: 91.18%\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "correlation = get_correlation(accuracies_12, accuracies_200)\n", | ||||
|     "print('[CIFAR-10] The correlation between 12-epoch validation accuracy and 200-epoch test accuracy is: {:.2f}%'.format(correlation * 100))" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.3" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 4 | ||||
| } | ||||
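Pearson correlation is sensitive to the scale of the accuracies; a short sketch (not part of the commit, assuming SciPy is available and reusing the accuracies_12 / accuracies_200 lists above) of the rank-based correlations often reported alongside it for NAS benchmarks:

# Sketch: rank correlations between 12-epoch validation and 200-epoch test accuracy.
from scipy import stats  # assumption: scipy is installed

tau, _ = stats.kendalltau(accuracies_12, accuracies_200)
rho, _ = stats.spearmanr(accuracies_12, accuracies_200)
print('[CIFAR-10] Kendall tau: {:.4f}, Spearman rho: {:.4f}'.format(tau, rho))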
							
								
								
									
AutoDL-Projects/notebooks/NATS-Bench/issue-97.ipynb (new file, 86 lines)
@@ -0,0 +1,86 @@
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[2021-03-09 08:44:19] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "from nats_bench import create\n", | ||||
|     "import numpy as np\n", | ||||
|     "\n", | ||||
|     "# Create the API for size search space\n", | ||||
|     "api = create(None, 'sss', fast_mode=True, verbose=False)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "There are 32768 architectures on the size search space\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "print('There are {:} architectures on the size search space'.format(len(api)))\n", | ||||
|     "\n", | ||||
|     "c2acc = dict()\n", | ||||
|     "for index in range(len(api)):\n", | ||||
|     "    info = api.get_more_info(index, 'cifar10', hp='90')\n", | ||||
|     "    config = api.get_net_config(index, 'cifar10')\n", | ||||
|     "    c2acc[config['channels']] = info['test-accuracy']" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "91.08546417236329\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "print(np.mean(list(c2acc.values())))" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.3" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 4 | ||||
| } | ||||
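Beyond the mean over all channel configurations, the c2acc dictionary built above also yields the best configuration directly; a minimal sketch (not part of the commit):

# Sketch: best channel configuration in the size search space, from c2acc above.
best_channels = max(c2acc, key=c2acc.get)
print('Best channels {:} -> CIFAR-10 test accuracy {:.2f}%'.format(
    best_channels, c2acc[best_channels]))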
							
								
								
									
AutoDL-Projects/notebooks/Q/qlib-data-play.ipynb (new file, 274 lines)
@@ -0,0 +1,274 @@
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[82189:MainThread](2021-03-02 21:02:54,241) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", | ||||
|       "[82189:MainThread](2021-03-02 21:02:54,255) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", | ||||
|       "[82189:MainThread](2021-03-02 21:02:54,828) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", | ||||
|       "[82189:MainThread](2021-03-02 21:02:54,829) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "import os\n", | ||||
|     "import sys\n", | ||||
|     "import qlib\n", | ||||
|     "import pprint\n", | ||||
|     "import numpy as np\n", | ||||
|     "import pandas as pd\n", | ||||
|     "qlib.init(provider_uri='~/.qlib/qlib_data/cn_data')\n", | ||||
|     "\n", | ||||
|     "from qlib.config import C\n", | ||||
|     "from qlib.data import D\n", | ||||
|     "from qlib.data.data import DatasetD, ExpressionD, Inst, Cal, FeatureD\n", | ||||
|     "from qlib.data.cache import H\n", | ||||
|     "from qlib.data.filter import NameDFilter\n", | ||||
|     "from qlib.utils import code_to_fname, read_bin" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "<class 'pandas.core.frame.DataFrame'>\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')\n", | ||||
|     "instruments_config = D.instruments(market='csi300', filter_pipe=[nameDFilter])\n", | ||||
|     "instruments = D.list_instruments(instruments=instruments_config,\n", | ||||
|     "                                 start_time='2015-01-01',\n", | ||||
|     "                                 end_time='2016-02-15',\n", | ||||
|     "                                 as_list=True)\n", | ||||
|     "\n", | ||||
|     "fields = ['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low']\n", | ||||
|     "features = D.features(instruments_config, fields, start_time='2010-01-01', end_time='2017-12-31', freq='day')\n", | ||||
|     "print(type(features))" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 9, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "                         $close     $volume  Ref($close, 1)  Mean($close, 3)  \\\n", | ||||
|       "instrument datetime                                                            \n", | ||||
|       "SH600655   2010-01-04  8.934296  47799352.0        8.667867         8.691138   \n", | ||||
|       "           2010-01-05  8.889880  29791234.0        8.934296         8.830681   \n", | ||||
|       "           2010-01-06  8.845468  29002874.0        8.889880         8.889881   \n", | ||||
|       "           2010-01-07  8.553690  38189440.0        8.845468         8.763013   \n", | ||||
|       "           2010-01-08  8.645658  23417642.0        8.553690         8.681605   \n", | ||||
|       "...                         ...         ...             ...              ...   \n", | ||||
|       "SH601555   2017-12-25  1.393481  80615584.0        1.406559         1.408012   \n", | ||||
|       "           2017-12-26  1.406559  64259856.0        1.393481         1.402200   \n", | ||||
|       "           2017-12-27  1.400747  58551256.0        1.406559         1.400262   \n", | ||||
|       "           2017-12-28  1.412371  96204872.0        1.400747         1.406559   \n", | ||||
|       "           2017-12-29  1.412371  52801024.0        1.412371         1.408496   \n", | ||||
|       "\n", | ||||
|       "                       $high-$low  \n", | ||||
|       "instrument datetime                \n", | ||||
|       "SH600655   2010-01-04    0.412291  \n", | ||||
|       "           2010-01-05    0.203006  \n", | ||||
|       "           2010-01-06    0.250560  \n", | ||||
|       "           2010-01-07    0.412291  \n", | ||||
|       "           2010-01-08    0.275964  \n", | ||||
|       "...                           ...  \n", | ||||
|       "SH601555   2017-12-25    0.020343  \n", | ||||
|       "           2017-12-26    0.018890  \n", | ||||
|       "           2017-12-27    0.017437  \n", | ||||
|       "           2017-12-28    0.045045  \n", | ||||
|       "           2017-12-29    0.013078  \n", | ||||
|       "\n", | ||||
|       "[2867 rows x 5 columns]\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "print(features)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": { | ||||
|     "scrolled": true | ||||
|    }, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "<class 'qlib.data.data.LocalProvider'>\n", | ||||
|       "<class 'qlib.config.QlibConfig'>\n", | ||||
|       "LocalProvider\n", | ||||
|       "Wrapper(provider=<qlib.data.data.LocalProvider object at 0x7ff5601cb370>)\n", | ||||
|       "Wrapper(provider=<qlib.data.data.LocalDatasetProvider object at 0x7ff5601c3b80>)\n", | ||||
|       "<qlib.data.data.LocalDatasetProvider object at 0x7ff5601c3b80>\n", | ||||
|       "LocalDatasetProvider\n", | ||||
|       "--\n", | ||||
|       "Wrapper(provider=<qlib.data.data.LocalInstrumentProvider object at 0x7ff55fb73340>)\n", | ||||
|       "<qlib.data.data.LocalInstrumentProvider object at 0x7ff55fb73340>\n", | ||||
|       "default_disk_cache: 1\n", | ||||
|       "ExpressionD: Wrapper(provider=<qlib.data.data.LocalExpressionProvider object at 0x7ff5601c3bb0>)\n", | ||||
|       "FeatureD   : <qlib.data.data.LocalFeatureProvider object at 0x7ff55fb84430>\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "# Provider:\n", | ||||
|     "print(type(D._provider))\n", | ||||
|     "print(type(C))\n", | ||||
|     "print(C.provider)\n", | ||||
|     "print(D)\n", | ||||
|     "\n", | ||||
|     "# DatasetD Provider\n", | ||||
|     "print(DatasetD)\n", | ||||
|     "print(DatasetD._provider)\n", | ||||
|     "print(C.dataset_provider)\n", | ||||
|     "\n", | ||||
|     "print('--')\n", | ||||
|     "print(Inst)\n", | ||||
|     "print(Inst._provider)\n", | ||||
|     "\n", | ||||
|     "# Default Disk Cache\n", | ||||
|     "print('default_disk_cache: {:}'.format(C.default_disk_cache))\n", | ||||
|     "print('ExpressionD: {:}'.format(ExpressionD))\n", | ||||
|     "print('FeatureD   : {:}'.format(FeatureD._provider))" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "ename": "NameError", | ||||
|      "evalue": "name 'pprint' is not defined", | ||||
|      "output_type": "error", | ||||
|      "traceback": [ | ||||
|       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||||
|       "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)", | ||||
|       "\u001b[0;32m<ipython-input-4-76544a8bb578>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpprint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstruments_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0minstruments_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDatasetD\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_provider\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_instruments_d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstruments_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfreq\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'day'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpprint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mpprint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstruments_d\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||||
|       "\u001b[0;31mNameError\u001b[0m: name 'pprint' is not defined" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "pprint.pprint(instruments_config)\n", | ||||
|     "instruments_d = DatasetD._provider.get_instruments_d(instruments_config, freq='day')\n", | ||||
|     "pprint.pprint(instruments_d)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 19, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "2012-12-31 00:00:00 -> 2019-01-18 00:00:00\n", | ||||
|       "<PandasArray>\n", | ||||
|       "[1.1059314, 1.0935822, 1.1059314, 1.0922102, 1.0839773, 1.0839773, 1.0181155,\n", | ||||
|       " 1.0730004, 1.0867218,  1.068884,\n", | ||||
|       " ...\n", | ||||
|       " 1.1163876, 1.1208236, 1.1119517, 1.0986437, 1.1075157, 1.0971651, 1.1149089,\n", | ||||
|       "  1.083857,  1.083857, 1.0956864]\n", | ||||
|       "Length: 1439, dtype: float32\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "instrument, field, freq = 'SH601555', '$close', 'day'\n", | ||||
|     "all_dates = D.calendar(start_time='2011-12-31', end_time='2019-02-10', freq=freq)\n", | ||||
|     "start_time, end_time = all_dates[0], all_dates[-11]\n", | ||||
|     "print(str(start_time) + ' -> ' + str(end_time))\n", | ||||
|     "obj = ExpressionD.expression(instrument, field, start_time, end_time, freq)\n", | ||||
|     "print(obj.array)\n", | ||||
|     "\n", | ||||
|     "# expression = ExpressionD.get_expression_instance(field)\n", | ||||
|     "# start_time = pd.Timestamp(start_time)\n", | ||||
|     "# end_time = pd.Timestamp(end_time)\n", | ||||
|     "# _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq='day', future=False)\n", | ||||
|     "# print(start_index)\n", | ||||
|     "# print(end_index)\n", | ||||
|     "\n", | ||||
|     "# fname = code_to_fname(instrument)\n", | ||||
|     "# uri_data = FeatureD._uri_data.format(instrument.lower(), field[1:], freq)\n", | ||||
|     "# print(uri_data)\n", | ||||
|     "# # series = read_bin(uri_data, start_index, end_index)\n", | ||||
|     "# series = read_bin(uri_data, 2850, 2870)\n", | ||||
|     "# print(series)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "Wrapper(provider=<qlib.data.data.LocalProvider object at 0x7ff5601cb370>)\n", | ||||
|       "Wrapper(provider=<qlib.data.data.LocalInstrumentProvider object at 0x7ff55fb73340>)\n", | ||||
|       "Wrapper(provider=<qlib.data.data.LocalExpressionProvider object at 0x7ff5601c3bb0>)\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "from qlib.data import D\n", | ||||
|     "from qlib.data.data import ExpressionD, Inst\n", | ||||
|     "print(D)\n", | ||||
|     "print(Inst)\n", | ||||
|     "print(ExpressionD)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.3" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 4 | ||||
| } | ||||
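The NameError in the fourth cell is an execution-order artifact: pprint is imported in the first cell, so the lookup itself is fine. A sketch of the same query written to be order-independent (not part of the commit):

# Sketch: the same instrument lookup with a local import, so the cell runs
# regardless of which cells were executed before it.
import pprint

pprint.pprint(instruments_config)
instruments_d = DatasetD._provider.get_instruments_d(instruments_config, freq='day')
pprint.pprint(instruments_d)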
							
								
								
									
AutoDL-Projects/notebooks/Q/workflow-test.ipynb (new file, 162 lines)
@@ -0,0 +1,162 @@
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "library path: /Users/xuanyidong/Desktop/XAutoDL/lib\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[61704:MainThread](2021-03-22 13:56:38,104) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", | ||||
|       "[61704:MainThread](2021-03-22 13:56:38,106) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", | ||||
|       "[61704:MainThread](2021-03-22 13:56:38,680) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", | ||||
|       "[61704:MainThread](2021-03-22 13:56:38,681) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "{'class': 'DatasetH',\n", | ||||
|       " 'kwargs': {'handler': {'class': 'Alpha158',\n", | ||||
|       "                        'kwargs': {'end_time': '2020-08-01',\n", | ||||
|       "                                   'fit_end_time': '2014-12-31',\n", | ||||
|       "                                   'fit_start_time': '2008-01-01',\n", | ||||
|       "                                   'instruments': 'csi100',\n", | ||||
|       "                                   'start_time': '2008-01-01'},\n", | ||||
|       "                        'module_path': 'qlib.contrib.data.handler'},\n", | ||||
|       "            'segments': {'test': ('2017-01-01', '2020-08-01'),\n", | ||||
|       "                         'train': ('2008-01-01', '2014-12-31'),\n", | ||||
|       "                         'valid': ('2015-01-01', '2016-12-31')}},\n", | ||||
|       " 'module_path': 'qlib.data.dataset'}\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "import os\n", | ||||
|     "import sys\n", | ||||
|     "import qlib\n", | ||||
|     "import pprint\n", | ||||
|     "import numpy as np\n", | ||||
|     "import pandas as pd\n", | ||||
|     "\n", | ||||
|     "from pathlib import Path\n", | ||||
|     "\n", | ||||
|     "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", | ||||
|     "\n", | ||||
|     "lib_dir = (Path(__file__).parent / \"..\" / \"lib\").resolve()\n", | ||||
|     "print(\"library path: {:}\".format(lib_dir))\n", | ||||
|     "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", | ||||
|     "if str(lib_dir) not in sys.path:\n", | ||||
|     "    sys.path.insert(0, str(lib_dir))\n", | ||||
|     "\n", | ||||
|     "from qlib import config as qconfig\n", | ||||
|     "from qlib.utils import init_instance_by_config\n", | ||||
|     "\n", | ||||
|     "qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)\n", | ||||
|     "\n", | ||||
|     "dataset_config = {\n", | ||||
|     "            \"class\": \"DatasetH\",\n", | ||||
|     "            \"module_path\": \"qlib.data.dataset\",\n", | ||||
|     "            \"kwargs\": {\n", | ||||
|     "                \"handler\": {\n", | ||||
|     "                    \"class\": \"Alpha158\",\n", | ||||
|     "                    \"module_path\": \"qlib.contrib.data.handler\",\n", | ||||
|     "                    \"kwargs\": {\n", | ||||
|     "                        \"start_time\": \"2008-01-01\",\n", | ||||
|     "                        \"end_time\": \"2020-08-01\",\n", | ||||
|     "                        \"fit_start_time\": \"2008-01-01\",\n", | ||||
|     "                        \"fit_end_time\": \"2014-12-31\",\n", | ||||
|     "                        \"instruments\": \"csi100\",\n", | ||||
|     "                    },\n", | ||||
|     "                },\n", | ||||
|     "                \"segments\": {\n", | ||||
|     "                    \"train\": (\"2008-01-01\", \"2014-12-31\"),\n", | ||||
|     "                    \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n", | ||||
|     "                    \"test\": (\"2017-01-01\", \"2020-08-01\"),\n", | ||||
|     "                },\n", | ||||
|     "            },\n", | ||||
|     "        }\n", | ||||
|     "pprint.pprint(dataset_config)\n", | ||||
|     "dataset = init_instance_by_config(dataset_config)\n", | ||||
|     "\n", | ||||
|     "df_train, df_valid, df_test = dataset.prepare(\n", | ||||
|     "            [\"train\", \"valid\", \"test\"],\n", | ||||
|     "            col_set=[\"feature\", \"label\"],\n", | ||||
|     "            data_key=DataHandlerLP.DK_L,\n", | ||||
|     "        )" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "{'class': 'DatasetH',\n", | ||||
|       " 'kwargs': {'handler': {'class': 'Alpha158',\n", | ||||
|       "                        'kwargs': {'end_time': '2020-08-01',\n", | ||||
|       "                                   'fit_end_time': '2014-12-31',\n", | ||||
|       "                                   'fit_start_time': '2008-01-01',\n", | ||||
|       "                                   'instruments': 'csi300',\n", | ||||
|       "                                   'start_time': '2008-01-01'},\n", | ||||
|       "                        'module_path': 'qlib.contrib.data.handler'},\n", | ||||
|       "            'segments': {'test': ('2017-01-01', '2020-08-01'),\n", | ||||
|       "                         'train': ('2008-01-01', '2014-12-31'),\n", | ||||
|       "                         'valid': ('2015-01-01', '2016-12-31')}},\n", | ||||
|       " 'module_path': 'qlib.data.dataset'}\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[95290:MainThread](2021-03-03 12:18:43,481) INFO - qlib.timer - [log.py:81] - Time cost: 237.911s | Loading data Done\n", | ||||
|       "[95290:MainThread](2021-03-03 12:18:45,080) INFO - qlib.timer - [log.py:81] - Time cost: 0.465s | DropnaLabel Done\n", | ||||
|       "[95290:MainThread](2021-03-03 12:18:51,572) INFO - qlib.timer - [log.py:81] - Time cost: 6.491s | CSZScoreNorm Done\n", | ||||
|       "[95290:MainThread](2021-03-03 12:18:51,573) INFO - qlib.timer - [log.py:81] - Time cost: 8.090s | fit & process data Done\n", | ||||
|       "[95290:MainThread](2021-03-03 12:18:51,573) INFO - qlib.timer - [log.py:81] - Time cost: 246.003s | Init data Done\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "from trade_models.transformations import get_transformer\n", | ||||
|     "\n", | ||||
|     "model = get_transformer(None)" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.8" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 4 | ||||
| } | ||||
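If the dataset.prepare call in the first cell succeeds (it needs DataHandlerLP imported, as noted above), a quick sanity check of the three splits can look like this sketch (not part of the commit):

# Sketch: shapes of the prepared Alpha158 splits, assuming df_train / df_valid /
# df_test from the dataset.prepare(...) call above.
for name, df in zip(('train', 'valid', 'test'), (df_train, df_valid, df_test)):
    print('{:5s}: {:} rows x {:} columns'.format(name, df.shape[0], df.shape[1]))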
							
								
								
									
AutoDL-Projects/notebooks/TOT/ES-Model-DC.ipynb (new file, 311 lines)
@@ -0,0 +1,311 @@
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "id": "afraid-minutes", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n", | ||||
|       "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[70148:MainThread](2021-04-12 13:23:30,262) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", | ||||
|       "[70148:MainThread](2021-04-12 13:23:30,266) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", | ||||
|       "[70148:MainThread](2021-04-12 13:23:30,269) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", | ||||
|       "[70148:MainThread](2021-04-12 13:23:30,271) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "#\n", | ||||
|     "# Exhaustive Search Results\n", | ||||
|     "#\n", | ||||
|     "import os\n", | ||||
|     "import re\n", | ||||
|     "import sys\n", | ||||
|     "import qlib\n", | ||||
|     "import pprint\n", | ||||
|     "import numpy as np\n", | ||||
|     "import pandas as pd\n", | ||||
|     "\n", | ||||
|     "from pathlib import Path\n", | ||||
|     "\n", | ||||
|     "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", | ||||
|     "root_dir = (Path(__file__).parent / \"..\").resolve()\n", | ||||
|     "lib_dir = (root_dir / \"lib\").resolve()\n", | ||||
|     "print(\"The root path: {:}\".format(root_dir))\n", | ||||
|     "print(\"The library path: {:}\".format(lib_dir))\n", | ||||
|     "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", | ||||
|     "if str(lib_dir) not in sys.path:\n", | ||||
|     "    sys.path.insert(0, str(lib_dir))\n", | ||||
|     "\n", | ||||
|     "import qlib\n", | ||||
|     "from qlib import config as qconfig\n", | ||||
|     "from qlib.workflow import R\n", | ||||
|     "qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "id": "hidden-exemption", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from utils.qlib_utils import QResult" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "id": "continental-drain", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def filter_finished(recorders):\n", | ||||
|     "    returned_recorders = dict()\n", | ||||
|     "    not_finished = 0\n", | ||||
|     "    for key, recorder in recorders.items():\n", | ||||
|     "        if recorder.status == \"FINISHED\":\n", | ||||
|     "            returned_recorders[key] = recorder\n", | ||||
|     "        else:\n", | ||||
|     "            not_finished += 1\n", | ||||
|     "    return returned_recorders, not_finished\n", | ||||
|     "\n", | ||||
|     "def query_info(save_dir, verbose, name_filter, key_map):\n", | ||||
|     "    if isinstance(save_dir, list):\n", | ||||
|     "        results = []\n", | ||||
|     "        for x in save_dir:\n", | ||||
|     "            x = query_info(x, verbose, name_filter, key_map)\n", | ||||
|     "            results.extend(x)\n", | ||||
|     "        return results\n", | ||||
|     "    # Here, the save_dir must be a string\n", | ||||
|     "    R.set_uri(str(save_dir))\n", | ||||
|     "    experiments = R.list_experiments()\n", | ||||
|     "\n", | ||||
|     "    if verbose:\n", | ||||
|     "        print(\"There are {:} experiments.\".format(len(experiments)))\n", | ||||
|     "    qresults = []\n", | ||||
|     "    for idx, (key, experiment) in enumerate(experiments.items()):\n", | ||||
|     "        if experiment.id == \"0\":\n", | ||||
|     "            continue\n", | ||||
|     "        if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n", | ||||
|     "            continue\n", | ||||
|     "        recorders = experiment.list_recorders()\n", | ||||
|     "        recorders, not_finished = filter_finished(recorders)\n", | ||||
|     "        if verbose:\n", | ||||
|     "            print(\n", | ||||
|     "                \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n", | ||||
|     "                    idx + 1,\n", | ||||
|     "                    len(experiments),\n", | ||||
|     "                    experiment.name,\n", | ||||
|     "                    len(recorders),\n", | ||||
|     "                    len(recorders) + not_finished,\n", | ||||
|     "                )\n", | ||||
|     "            )\n", | ||||
|     "        result = QResult(experiment.name)\n", | ||||
|     "        for recorder_id, recorder in recorders.items():\n", | ||||
|     "            result.update(recorder.list_metrics(), key_map)\n", | ||||
|     "            result.append_path(\n", | ||||
|     "                os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n", | ||||
|     "            )\n", | ||||
|     "        if not len(result):\n", | ||||
|     "            print(\"There are no valid recorders for {:}\".format(experiment))\n", | ||||
|     "            continue\n", | ||||
|     "        else:\n", | ||||
|     "            if verbose:\n", | ||||
|     "                print(\n", | ||||
|     "                    \"There are {:} valid recorders for {:}\".format(\n", | ||||
|     "                        len(recorders), experiment.name\n", | ||||
|     "                    )\n", | ||||
|     "                )\n", | ||||
|     "        qresults.append(result)\n", | ||||
|     "    return qresults" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "id": "filled-multiple", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[70148:MainThread](2021-04-12 13:23:31,137) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7f8c4a47efa0>\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n", | ||||
|     "paths = [path.resolve() for path in paths]\n", | ||||
|     "print(paths)\n", | ||||
|     "\n", | ||||
|     "key_map = dict()\n", | ||||
|     "for xset in (\"train\", \"valid\", \"test\"):\n", | ||||
|     "    key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n", | ||||
|     "    key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n", | ||||
|     "qresults = query_info(paths, False, 'TSF-.*-drop0_0', key_map)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 5, | ||||
|    "id": "intimate-approval", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import matplotlib\n", | ||||
|     "from matplotlib import cm\n", | ||||
|     "matplotlib.use(\"agg\")\n", | ||||
|     "import matplotlib.pyplot as plt\n", | ||||
|     "import matplotlib.ticker as ticker" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 6, | ||||
|    "id": "supreme-basis", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def vis_depth_channel(qresults, save_path):\n", | ||||
|     "    save_dir = (save_path / '..').resolve()\n", | ||||
|     "    save_dir.mkdir(parents=True, exist_ok=True)\n", | ||||
|     "    print('There are {:} qlib-results'.format(len(qresults)))\n", | ||||
|     "    \n", | ||||
|     "    dpi, width, height = 200, 4000, 2000\n", | ||||
|     "    figsize = width / float(dpi), height / float(dpi)\n", | ||||
|     "    LabelSize, LegendFontsize = 22, 12\n", | ||||
|     "    font_gap = 5\n", | ||||
|     "    \n", | ||||
|     "    fig = plt.figure(figsize=figsize)\n", | ||||
|     "    # fig, axs = plt.subplots(1, 2, figsize=figsize, projection='3d')\n", | ||||
|     "    \n", | ||||
|     "    def plot_ax(cur_ax, train_or_test):\n", | ||||
|     "        depths, channels = [], []\n", | ||||
|     "        ic_values, xmaps = [], dict()\n", | ||||
|     "        for qresult in qresults:\n", | ||||
|     "            name = qresult.name.split('-')[1]\n", | ||||
|     "            depths.append(float(name.split('x')[0]))\n", | ||||
|     "            channels.append(float(name.split('x')[1]))\n", | ||||
|     "            if train_or_test:\n", | ||||
|     "                ic_values.append(qresult['IC (train)'])\n", | ||||
|     "            else:\n", | ||||
|     "                ic_values.append(qresult['IC (valid)'])\n", | ||||
|     "            xmaps[(depths[-1], channels[-1])] = ic_values[-1]\n", | ||||
|     "        # cur_ax.scatter(depths, channels, ic_values, marker='o', c=\"tab:orange\")\n", | ||||
|     "        raw_depths = np.arange(1, 9, dtype=np.int32)\n", | ||||
|     "        raw_channels = np.array([6, 12, 24, 32, 48, 64], dtype=np.int32)\n", | ||||
|     "        depths, channels = np.meshgrid(raw_depths, raw_channels)\n", | ||||
|     "        ic_values = np.sin(depths)  # initialize\n", | ||||
|     "        # print(ic_values.shape)\n", | ||||
|     "        num_x, num_y = ic_values.shape\n", | ||||
|     "        for i in range(num_x):\n", | ||||
|     "            for j in range(num_y):\n", | ||||
|     "                xkey = (int(depths[i][j]), int(channels[i][j]))\n", | ||||
|     "                if xkey not in xmaps:\n", | ||||
|     "                    raise ValueError(\"Did not find {:}\".format(xkey))\n", | ||||
|     "                ic_values[i][j] = xmaps[xkey]\n", | ||||
|     "        #print(sorted(list(xmaps.keys())))\n", | ||||
|     "        #surf = cur_ax.plot_surface(\n", | ||||
|     "        #    np.array(depths), np.array(channels), np.array(ic_values),\n", | ||||
|     "        #    cmap=cm.coolwarm, linewidth=0, antialiased=False)\n", | ||||
|     "        surf = cur_ax.plot_surface(\n", | ||||
|     "            depths, channels, ic_values,\n", | ||||
|     "            cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n", | ||||
|     "        cur_ax.set_xticks(raw_depths)\n", | ||||
|     "        cur_ax.set_yticks(raw_channels)\n", | ||||
|     "        cur_ax.set_zticks(np.arange(4, 11, 2))\n", | ||||
|     "        cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n", | ||||
|     "        cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n", | ||||
|     "        cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n", | ||||
|     "        for tick in cur_ax.xaxis.get_major_ticks():\n", | ||||
|     "            tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||
|     "        for tick in cur_ax.yaxis.get_major_ticks():\n", | ||||
|     "            tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||
|     "        for tick in cur_ax.zaxis.get_major_ticks():\n", | ||||
|     "            tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||
|     "        # Add a color bar which maps values to colors.\n", | ||||
|     "#         cax = fig.add_axes([cur_ax.get_position().x1 + 0.01,\n", | ||||
|     "#                             cur_ax.get_position().y0,\n", | ||||
|     "#                             0.01,\n", | ||||
|     "#                             cur_ax.get_position().height * 0.9])\n", | ||||
|     "        # fig.colorbar(surf, cax=cax)\n", | ||||
|     "        # fig.colorbar(surf, shrink=0.5, aspect=5)\n", | ||||
|     "        # import pdb; pdb.set_trace()\n", | ||||
|     "        # ax1.legend(loc=4, fontsize=LegendFontsize)\n", | ||||
|     "    ax = fig.add_subplot(1, 2, 1, projection='3d')\n", | ||||
|     "    plot_ax(ax, True)\n", | ||||
|     "    ax = fig.add_subplot(1, 2, 2, projection='3d')\n", | ||||
|     "    plot_ax(ax, False)\n", | ||||
|     "    # fig.tight_layout()\n", | ||||
|     "    plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n", | ||||
|     "    fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", | ||||
|     "    plt.close(\"all\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 7, | ||||
|    "id": "shared-envelope", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "The Desktop is at: /Users/xuanyidong/Desktop\n", | ||||
|       "There are 48 qlib-results\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "# Visualization\n", | ||||
|     "home_dir = Path.home()\n", | ||||
|     "desktop_dir = home_dir / 'Desktop'\n", | ||||
|     "print('The Desktop is at: {:}'.format(desktop_dir))\n", | ||||
|     "\n", | ||||
|     "vis_depth_channel(qresults, desktop_dir / 'es_csi300_d_vs_c.pdf')" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.8" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 5 | ||||
| } | ||||
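Besides the surface plot, the qresults list can be ranked directly; a sketch (not part of the commit) that reuses the QResult name attribute and the 'IC (valid)' / 'IC (train)' keys already accessed inside vis_depth_channel:

# Sketch: top-5 depth-x-channel configurations by validation IC, from qresults above.
ranked = sorted(qresults, key=lambda r: r['IC (valid)'], reverse=True)
for result in ranked[:5]:
    print('{:20s} IC(valid)={:.3f}  IC(train)={:.3f}'.format(
        result.name, result['IC (valid)'], result['IC (train)']))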
							
								
								
									
AutoDL-Projects/notebooks/TOT/ES-Model-Drop.ipynb (new file, 312 lines)
@@ -0,0 +1,312 @@
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "id": "afraid-minutes", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n", | ||||
|       "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[70363:MainThread](2021-04-12 13:25:01,065) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n", | ||||
|       "[70363:MainThread](2021-04-12 13:25:01,069) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", | ||||
|       "[70363:MainThread](2021-04-12 13:25:01,085) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n", | ||||
|       "[70363:MainThread](2021-04-12 13:25:01,092) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "#\n", | ||||
|     "# Exhaustive Search Results\n", | ||||
|     "#\n", | ||||
|     "import os\n", | ||||
|     "import re\n", | ||||
|     "import sys\n", | ||||
|     "import qlib\n", | ||||
|     "import pprint\n", | ||||
|     "import numpy as np\n", | ||||
|     "import pandas as pd\n", | ||||
|     "\n", | ||||
|     "from pathlib import Path\n", | ||||
|     "\n", | ||||
|     "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", | ||||
|     "root_dir = (Path(__file__).parent / \"..\").resolve()\n", | ||||
|     "lib_dir = (root_dir / \"lib\").resolve()\n", | ||||
|     "print(\"The root path: {:}\".format(root_dir))\n", | ||||
|     "print(\"The library path: {:}\".format(lib_dir))\n", | ||||
|     "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", | ||||
|     "if str(lib_dir) not in sys.path:\n", | ||||
|     "    sys.path.insert(0, str(lib_dir))\n", | ||||
|     "\n", | ||||
|     "import qlib\n", | ||||
|     "from qlib import config as qconfig\n", | ||||
|     "from qlib.workflow import R\n", | ||||
|     "qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "id": "hidden-exemption", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from utils.qlib_utils import QResult" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "id": "continental-drain", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def filter_finished(recorders):\n", | ||||
|     "    returned_recorders = dict()\n", | ||||
|     "    not_finished = 0\n", | ||||
|     "    for key, recorder in recorders.items():\n", | ||||
|     "        if recorder.status == \"FINISHED\":\n", | ||||
|     "            returned_recorders[key] = recorder\n", | ||||
|     "        else:\n", | ||||
|     "            not_finished += 1\n", | ||||
|     "    return returned_recorders, not_finished\n", | ||||
|     "\n", | ||||
|     "def query_info(save_dir, verbose, name_filter, key_map):\n", | ||||
|     "    if isinstance(save_dir, list):\n", | ||||
|     "        results = []\n", | ||||
|     "        for x in save_dir:\n", | ||||
|     "            x = query_info(x, verbose, name_filter, key_map)\n", | ||||
|     "            results.extend(x)\n", | ||||
|     "        return results\n", | ||||
|     "    # Here, the save_dir must be a string\n", | ||||
|     "    R.set_uri(str(save_dir))\n", | ||||
|     "    experiments = R.list_experiments()\n", | ||||
|     "\n", | ||||
|     "    if verbose:\n", | ||||
|     "        print(\"There are {:} experiments.\".format(len(experiments)))\n", | ||||
|     "    qresults = []\n", | ||||
|     "    for idx, (key, experiment) in enumerate(experiments.items()):\n", | ||||
|     "        if experiment.id == \"0\":\n", | ||||
|     "            continue\n", | ||||
|     "        if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n", | ||||
|     "            continue\n", | ||||
|     "        recorders = experiment.list_recorders()\n", | ||||
|     "        recorders, not_finished = filter_finished(recorders)\n", | ||||
|     "        if verbose:\n", | ||||
|     "            print(\n", | ||||
|     "                \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n", | ||||
|     "                    idx + 1,\n", | ||||
|     "                    len(experiments),\n", | ||||
|     "                    experiment.name,\n", | ||||
|     "                    len(recorders),\n", | ||||
|     "                    len(recorders) + not_finished,\n", | ||||
|     "                )\n", | ||||
|     "            )\n", | ||||
|     "        result = QResult(experiment.name)\n", | ||||
|     "        for recorder_id, recorder in recorders.items():\n", | ||||
|     "            result.update(recorder.list_metrics(), key_map)\n", | ||||
|     "            result.append_path(\n", | ||||
|     "                os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n", | ||||
|     "            )\n", | ||||
|     "        if not len(result):\n", | ||||
|     "            print(\"There are no valid recorders for {:}\".format(experiment))\n", | ||||
|     "            continue\n", | ||||
|     "        else:\n", | ||||
|     "            if verbose:\n", | ||||
|     "                print(\n", | ||||
|     "                    \"There are {:} valid recorders for {:}\".format(\n", | ||||
|     "                        len(recorders), experiment.name\n", | ||||
|     "                    )\n", | ||||
|     "                )\n", | ||||
|     "        qresults.append(result)\n", | ||||
|     "    return qresults" | ||||
|    ] | ||||
|   }, | ||||
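A quick aside on the name_filter argument of query_info above: re.fullmatch must cover the whole experiment name, so a pattern such as 'TSF-.*' keeps every experiment whose name starts with 'TSF-'. A minimal, dependency-free check:

import re

# re.fullmatch requires the pattern to span the entire string
# (re.match would only anchor at the start).
for name in ("TSF-2x24-drop0_0", "MLP-2x24", "TSF-8x64-drop0.1_0"):
    keep = re.fullmatch("TSF-.*", name) is not None
    print(name, "->", "kept" if keep else "skipped")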
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "id": "filled-multiple", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[70363:MainThread](2021-04-12 13:25:01,647) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fa920e56820>\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n", | ||||
|     "paths = [path.resolve() for path in paths]\n", | ||||
|     "print(paths)\n", | ||||
|     "\n", | ||||
|     "key_map = dict()\n", | ||||
|     "for xset in (\"train\", \"valid\", \"test\"):\n", | ||||
|     "    key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n", | ||||
|     "    key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n", | ||||
|     "\n", | ||||
|     "qresults = query_info(paths, False, 'TSF-.*', key_map)" | ||||
|    ] | ||||
|   }, | ||||
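For readers skimming the diff, the key_map built above only renames qlib's recorded metric keys into the display names used later (e.g. 'valid-mean-IC' becomes 'IC (valid)'). Assuming QResult.update simply looks metrics up through this mapping (an assumption about utils.qlib_utils, which is not part of this diff), the renaming is conceptually:

# Conceptual sketch of the key_map renaming; the metric values below are hypothetical.
metrics = {"train-mean-IC": 0.051, "valid-mean-IC": 0.043, "test-mean-ICIR": 0.38}
key_map = dict()
for xset in ("train", "valid", "test"):
    key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset)
    key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset)
renamed = {key_map[k]: v for k, v in metrics.items() if k in key_map}
print(renamed)  # {'IC (train)': 0.051, 'IC (valid)': 0.043, 'ICIR (test)': 0.38}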
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 5, | ||||
|    "id": "intimate-approval", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import matplotlib\n", | ||||
|     "from matplotlib import cm\n", | ||||
|     "matplotlib.use(\"agg\")\n", | ||||
|     "import matplotlib.pyplot as plt\n", | ||||
|     "import matplotlib.ticker as ticker" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 8, | ||||
|    "id": "supreme-basis", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def vis_dropouts(qresults, basenames, name2suffix, save_path):\n", | ||||
|     "    save_dir = (save_path / '..').resolve()\n", | ||||
|     "    save_dir.mkdir(parents=True, exist_ok=True)\n", | ||||
|     "    print('There are {:} qlib-results'.format(len(qresults)))\n", | ||||
|     "    \n", | ||||
|     "    name2qresult = dict()\n", | ||||
|     "    for qresult in qresults:\n", | ||||
|     "        name2qresult[qresult.name] = qresult\n", | ||||
|     "    # sort architectures by the ICIR (train) of their no-dropout ('-drop0_0') variant\n", | ||||
|     "    accuracies = []\n", | ||||
|     "    for basename in basenames:\n", | ||||
|     "        qresult = name2qresult[basename + '-drop0_0']\n", | ||||
|     "        accuracies.append(qresult['ICIR (train)'])\n", | ||||
|     "    sorted_basenames = sorted(basenames, key=lambda x: accuracies[basenames.index(x)])\n", | ||||
|     "    \n", | ||||
|     "    dpi, width, height = 200, 4000, 2000\n", | ||||
|     "    figsize = width / float(dpi), height / float(dpi)\n", | ||||
|     "    LabelSize, LegendFontsize = 22, 22\n", | ||||
|     "    font_gap = 5\n", | ||||
|     "    colors = ['k', 'r']\n", | ||||
|     "    markers = ['*', 'o']\n", | ||||
|     "    \n", | ||||
|     "    fig = plt.figure(figsize=figsize)\n", | ||||
|     "    \n", | ||||
|     "    def plot_ax(cur_ax, train_or_test):\n", | ||||
|     "        for idx, (legend, suffix) in enumerate(name2suffix.items()):\n", | ||||
|     "            x_values = list(range(len(sorted_basenames)))\n", | ||||
|     "            y_values = []\n", | ||||
|     "            for i, name in enumerate(sorted_basenames):\n", | ||||
|     "                name = '{:}{:}'.format(name, suffix)\n", | ||||
|     "                qresult = name2qresult[name]\n", | ||||
|     "                if train_or_test:\n", | ||||
|     "                    value = qresult['IC (train)']\n", | ||||
|     "                else:\n", | ||||
|     "                    value = qresult['IC (valid)']\n", | ||||
|     "                y_values.append(value)\n", | ||||
|     "            cur_ax.plot(x_values, y_values, c=colors[idx])\n", | ||||
|     "            cur_ax.scatter(x_values, y_values,\n", | ||||
|     "                           marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n", | ||||
|     "                           label=legend)\n", | ||||
|     "        cur_ax.set_yticks(np.arange(4, 11, 2))\n", | ||||
|     "        cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n", | ||||
|     "        cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n", | ||||
|     "        for tick in cur_ax.xaxis.get_major_ticks():\n", | ||||
|     "            tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||
|     "        for tick in cur_ax.yaxis.get_major_ticks():\n", | ||||
|     "            tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||
|     "        cur_ax.legend(loc=4, fontsize=LegendFontsize)\n", | ||||
|     "    ax = fig.add_subplot(1, 2, 1)\n", | ||||
|     "    plot_ax(ax, True)\n", | ||||
|     "    ax = fig.add_subplot(1, 2, 2)\n", | ||||
|     "    plot_ax(ax, False)\n", | ||||
|     "    # fig.tight_layout()\n", | ||||
|     "    # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n", | ||||
|     "    fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", | ||||
|     "    plt.close(\"all\")" | ||||
|    ] | ||||
|   }, | ||||
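The architecture-sorting step inside vis_dropouts calls basenames.index for every comparison key; a behaviour-equivalent, dependency-free sketch (with hypothetical ICIR values) that looks each value up directly:

# Hypothetical stand-in for name2qresult: name -> {'ICIR (train)': value}.
name2result = {
    "TSF-2x24-drop0_0": {"ICIR (train)": 0.41},
    "TSF-3x48-drop0_0": {"ICIR (train)": 0.35},
    "TSF-8x64-drop0_0": {"ICIR (train)": 0.52},
}
basenames = ["TSF-2x24", "TSF-3x48", "TSF-8x64"]
# Sort by the no-dropout ICIR (train), one lookup per basename.
sorted_basenames = sorted(
    basenames, key=lambda name: name2result[name + "-drop0_0"]["ICIR (train)"]
)
print(sorted_basenames)  # ['TSF-3x48', 'TSF-2x24', 'TSF-8x64']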
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 9, | ||||
|    "id": "shared-envelope", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "{'TSF-3x48', 'TSF-2x64', 'TSF-2x12', 'TSF-8x48', 'TSF-6x32', 'TSF-4x48', 'TSF-8x6', 'TSF-4x6', 'TSF-2x32', 'TSF-5x12', 'TSF-5x64', 'TSF-1x64', 'TSF-2x24', 'TSF-8x24', 'TSF-4x12', 'TSF-6x12', 'TSF-1x32', 'TSF-5x32', 'TSF-3x24', 'TSF-8x12', 'TSF-5x48', 'TSF-6x64', 'TSF-7x64', 'TSF-7x48', 'TSF-1x6', 'TSF-2x48', 'TSF-7x24', 'TSF-3x32', 'TSF-1x24', 'TSF-4x64', 'TSF-3x12', 'TSF-8x64', 'TSF-4x32', 'TSF-5x6', 'TSF-7x6', 'TSF-7x12', 'TSF-3x6', 'TSF-4x24', 'TSF-6x48', 'TSF-6x6', 'TSF-1x48', 'TSF-1x12', 'TSF-7x32', 'TSF-5x24', 'TSF-2x6', 'TSF-6x24', 'TSF-3x64', 'TSF-8x32'}\n", | ||||
|       "The Desktop is at: /Users/xuanyidong/Desktop\n", | ||||
|       "There are 104 qlib-results\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "# Visualization\n", | ||||
|     "names = [qresult.name for qresult in qresults]\n", | ||||
|     "base_names = set()\n", | ||||
|     "for name in names:\n", | ||||
|     "    base_name = name.split('-drop')[0]\n", | ||||
|     "    base_names.add(base_name)\n", | ||||
|     "print(base_names)\n", | ||||
|     "# filter\n", | ||||
|     "filtered_base_names = set()\n", | ||||
|     "for base_name in base_names:\n", | ||||
|     "    if (base_name + '-drop0_0') in names and (base_name + '-drop0.1_0') in names:\n", | ||||
|     "        filtered_base_names.add(base_name)\n", | ||||
|     "    else:\n", | ||||
|     "        print('Cannot find all names for {:}'.format(base_name))\n", | ||||
|     "# print(filtered_base_names)\n", | ||||
|     "home_dir = Path.home()\n", | ||||
|     "desktop_dir = home_dir / 'Desktop'\n", | ||||
|     "print('The Desktop is at: {:}'.format(desktop_dir))\n", | ||||
|     "\n", | ||||
|     "vis_dropouts(qresults, list(filtered_base_names),\n", | ||||
|     "             {'No-dropout': '-drop0_0',\n", | ||||
|     "              'Ratio=0.1' : '-drop0.1_0'},\n", | ||||
|     "             desktop_dir / 'es_csi300_drop.pdf')" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.8" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 5 | ||||
| } | ||||
							
								
								
									
AutoDL-Projects/notebooks/TOT/Time-Curve.ipynb (208 lines, Normal file)
							| @@ -0,0 +1,208 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "id": "afraid-minutes", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n", | ||||
|       "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "import os\n", | ||||
|     "import re\n", | ||||
|     "import sys\n", | ||||
|     "import torch\n", | ||||
|     "import pprint\n", | ||||
|     "import numpy as np\n", | ||||
|     "import pandas as pd\n", | ||||
|     "from pathlib import Path\n", | ||||
|     "from scipy.interpolate import make_interp_spline\n", | ||||
|     "\n", | ||||
|     "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", | ||||
|     "root_dir = (Path(__file__).parent / \"..\").resolve()\n", | ||||
|     "lib_dir = (root_dir / \"lib\").resolve()\n", | ||||
|     "print(\"The root path: {:}\".format(root_dir))\n", | ||||
|     "print(\"The library path: {:}\".format(lib_dir))\n", | ||||
|     "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", | ||||
|     "if str(lib_dir) not in sys.path:\n", | ||||
|     "    sys.path.insert(0, str(lib_dir))\n", | ||||
|     "from utils.qlib_utils import QResult" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "id": "continental-drain", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "TSF-2x24-drop0_0s2013-01-01\n", | ||||
|       "TSF-2x24-drop0_0s2012-01-01\n", | ||||
|       "TSF-2x24-drop0_0s2008-01-01\n", | ||||
|       "TSF-2x24-drop0_0s2009-01-01\n", | ||||
|       "TSF-2x24-drop0_0s2010-01-01\n", | ||||
|       "TSF-2x24-drop0_0s2011-01-01\n", | ||||
|       "TSF-2x24-drop0_0s2008-07-01\n", | ||||
|       "TSF-2x24-drop0_0s2009-07-01\n", | ||||
|       "There are 3011 dates\n", | ||||
|       "Dates: 2008-01-02 2008-01-03\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "qresults = torch.load(os.path.join(root_dir, 'notebooks', 'TOT', 'temp-time-x.pth'))\n", | ||||
|     "for qresult in qresults:\n", | ||||
|     "    print(qresult.name)\n", | ||||
|     "all_dates = set()\n", | ||||
|     "for qresult in qresults:\n", | ||||
|     "    dates = qresult.find_all_dates()\n", | ||||
|     "    for date in dates:\n", | ||||
|     "        all_dates.add(date)\n", | ||||
|     "all_dates = sorted(list(all_dates))\n", | ||||
|     "print('There are {:} dates'.format(len(all_dates)))\n", | ||||
|     "print('Dates: {:} {:}'.format(all_dates[0], all_dates[1]))" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "id": "intimate-approval", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import matplotlib\n", | ||||
|     "from matplotlib import cm\n", | ||||
|     "matplotlib.use(\"agg\")\n", | ||||
|     "import matplotlib.pyplot as plt\n", | ||||
|     "import matplotlib.ticker as ticker" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 6, | ||||
|    "id": "supreme-basis", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def vis_time_curve(qresults, dates, use_original, save_path):\n", | ||||
|     "    save_dir = (save_path / '..').resolve()\n", | ||||
|     "    save_dir.mkdir(parents=True, exist_ok=True)\n", | ||||
|     "    print('There are {:} qlib-results'.format(len(qresults)))\n", | ||||
|     "    \n", | ||||
|     "    dpi, width, height = 200, 5000, 2000\n", | ||||
|     "    figsize = width / float(dpi), height / float(dpi)\n", | ||||
|     "    LabelSize, LegendFontsize = 22, 12\n", | ||||
|     "    font_gap = 5\n", | ||||
|     "    linestyles = ['-', '--']\n", | ||||
|     "    colors = ['k', 'r']\n", | ||||
|     "    \n", | ||||
|     "    fig = plt.figure(figsize=figsize)\n", | ||||
|     "    cur_ax = fig.add_subplot(1, 1, 1)\n", | ||||
|     "    for idx, qresult in enumerate(qresults):\n", | ||||
|     "        print('Visualize [{:}] -- {:}'.format(idx, qresult.name))\n", | ||||
|     "        x_axis, y_axis = [], []\n", | ||||
|     "        for idate, date in enumerate(dates):\n", | ||||
|     "            if date in qresult._date2ICs[-1]:\n", | ||||
|     "                mean, std = qresult.get_IC_by_date(date, 100)\n", | ||||
|     "                if not np.isnan(mean):\n", | ||||
|     "                    x_axis.append(idate)\n", | ||||
|     "                    y_axis.append(mean)\n", | ||||
|     "        x_axis, y_axis = np.array(x_axis), np.array(y_axis)\n", | ||||
|     "        if use_original:\n", | ||||
|     "            cur_ax.plot(x_axis, y_axis, linewidth=1, color=colors[idx], linestyle=linestyles[idx])\n", | ||||
|     "        else:\n", | ||||
|     "            xnew = np.linspace(x_axis.min(), x_axis.max(), 200)\n", | ||||
|     "            spl = make_interp_spline(x_axis, y_axis, k=5)\n", | ||||
|     "            ynew = spl(xnew)\n", | ||||
|     "            cur_ax.plot(xnew, ynew, linewidth=2, color=colors[idx], linestyle=linestyles[idx])\n", | ||||
|     "        \n", | ||||
|     "    for tick in cur_ax.xaxis.get_major_ticks():\n", | ||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||
|     "    for tick in cur_ax.yaxis.get_major_ticks():\n", | ||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||
|     "    cur_ax.set_ylabel(\"IC (%)\", fontsize=LabelSize)\n", | ||||
|     "    fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", | ||||
|     "    plt.close(\"all\")" | ||||
|    ] | ||||
|   }, | ||||
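The smoothed branch of vis_time_curve uses scipy's make_interp_spline with k=5, which needs more than five strictly increasing x-values; a standalone sketch with synthetic IC values:

import numpy as np
from scipy.interpolate import make_interp_spline

x_axis = np.arange(10, dtype=float)                  # e.g., trading-day indices
y_axis = np.sin(x_axis / 3.0) * 5.0 + 3.0            # synthetic stand-in for daily IC (%)
xnew = np.linspace(x_axis.min(), x_axis.max(), 200)  # dense grid for a smooth curve
spl = make_interp_spline(x_axis, y_axis, k=5)        # quintic B-spline fit
ynew = spl(xnew)
print(xnew.shape, ynew.shape)                        # (200,) (200,)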
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 7, | ||||
|    "id": "shared-envelope", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "The Desktop is at: /Users/xuanyidong/Desktop\n", | ||||
|       "There are 2 qlib-results\n", | ||||
|       "Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n", | ||||
|       "Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n", | ||||
|       "There are 2 qlib-results\n", | ||||
|       "Visualize [0] -- TSF-2x24-drop0_0s2008-01-01\n", | ||||
|       "Visualize [1] -- TSF-2x24-drop0_0s2009-07-01\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "# Visualization\n", | ||||
|     "home_dir = Path.home()\n", | ||||
|     "desktop_dir = home_dir / 'Desktop'\n", | ||||
|     "print('The Desktop is at: {:}'.format(desktop_dir))\n", | ||||
|     "\n", | ||||
|     "vis_time_curve(\n", | ||||
|     "    (qresults[2], qresults[-1]),\n", | ||||
|     "    all_dates,\n", | ||||
|     "    True,\n", | ||||
|     "    desktop_dir / 'es_csi300_time_curve.pdf')\n", | ||||
|     "\n", | ||||
|     "vis_time_curve(\n", | ||||
|     "    (qresults[2], qresults[-1]),\n", | ||||
|     "    all_dates,\n", | ||||
|     "    False,\n", | ||||
|     "    desktop_dir / 'es_csi300_time_curve-inter.pdf')" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "exempt-stable", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.8" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 5 | ||||
| } | ||||
							
								
								
									
AutoDL-Projects/notebooks/TOT/time-curve.py (129 lines, Normal file)
							| @@ -0,0 +1,129 @@ | ||||
| import os | ||||
| import re | ||||
| import sys | ||||
| import torch | ||||
| import pprint | ||||
| from collections import OrderedDict | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
|  | ||||
| from pathlib import Path | ||||
|  | ||||
| # __file__ = os.path.dirname(os.path.realpath("__file__")) | ||||
| note_dir = Path(__file__).parent.resolve() | ||||
| root_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| lib_dir = (root_dir / "lib").resolve() | ||||
| print("The root path: {:}".format(root_dir)) | ||||
| print("The library path: {:}".format(lib_dir)) | ||||
| assert lib_dir.exists(), "{:} does not exist".format(lib_dir) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| import qlib | ||||
| from qlib import config as qconfig | ||||
| from qlib.workflow import R | ||||
|  | ||||
| qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN) | ||||
|  | ||||
| from utils.qlib_utils import QResult | ||||
|  | ||||
|  | ||||
| def filter_finished(recorders): | ||||
|     returned_recorders = dict() | ||||
|     not_finished = 0 | ||||
|     for key, recorder in recorders.items(): | ||||
|         if recorder.status == "FINISHED": | ||||
|             returned_recorders[key] = recorder | ||||
|         else: | ||||
|             not_finished += 1 | ||||
|     return returned_recorders, not_finished | ||||
|  | ||||
|  | ||||
| def add_to_dict(xdict, timestamp, value): | ||||
|     date = timestamp.date().strftime("%Y-%m-%d") | ||||
|     if date in xdict: | ||||
|         raise ValueError("This date [{:}] is already in the dict".format(date)) | ||||
|     xdict[date] = value | ||||
|  | ||||
|  | ||||
| def query_info(save_dir, verbose, name_filter, key_map): | ||||
|     if isinstance(save_dir, list): | ||||
|         results = [] | ||||
|         for x in save_dir: | ||||
|             x = query_info(x, verbose, name_filter, key_map) | ||||
|             results.extend(x) | ||||
|         return results | ||||
|     # Here, the save_dir must be a string | ||||
|     R.set_uri(str(save_dir)) | ||||
|     experiments = R.list_experiments() | ||||
|  | ||||
|     if verbose: | ||||
|         print("There are {:} experiments.".format(len(experiments))) | ||||
|     qresults = [] | ||||
|     for idx, (key, experiment) in enumerate(experiments.items()): | ||||
|         if experiment.id == "0": | ||||
|             continue | ||||
|         if ( | ||||
|             name_filter is not None | ||||
|             and re.fullmatch(name_filter, experiment.name) is None | ||||
|         ): | ||||
|             continue | ||||
|         recorders = experiment.list_recorders() | ||||
|         recorders, not_finished = filter_finished(recorders) | ||||
|         if verbose: | ||||
|             print( | ||||
|                 "====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format( | ||||
|                     idx + 1, | ||||
|                     len(experiments), | ||||
|                     experiment.name, | ||||
|                     len(recorders), | ||||
|                     len(recorders) + not_finished, | ||||
|                 ) | ||||
|             ) | ||||
|         result = QResult(experiment.name) | ||||
|         for recorder_id, recorder in recorders.items(): | ||||
|             file_names = ["results-train.pkl", "results-valid.pkl", "results-test.pkl"] | ||||
|             date2IC = OrderedDict() | ||||
|             for file_name in file_names: | ||||
|                 xtemp = recorder.load_object(file_name)["all-IC"] | ||||
|                 timestamps, values = xtemp.index.tolist(), xtemp.tolist() | ||||
|                 for timestamp, value in zip(timestamps, values): | ||||
|                     add_to_dict(date2IC, timestamp, value) | ||||
|             result.update(recorder.list_metrics(), key_map) | ||||
|             result.append_path( | ||||
|                 os.path.join(recorder.uri, recorder.experiment_id, recorder.id) | ||||
|             ) | ||||
|             result.append_date2ICs(date2IC) | ||||
|         if not len(result): | ||||
|             print("There are no valid recorders for {:}".format(experiment)) | ||||
|             continue | ||||
|         else: | ||||
|             if verbose: | ||||
|                 print( | ||||
|                     "There are {:} valid recorders for {:}".format( | ||||
|                         len(recorders), experiment.name | ||||
|                     ) | ||||
|                 ) | ||||
|         qresults.append(result) | ||||
|     return qresults | ||||
|  | ||||
|  | ||||
| ## Query the exhaustive-search results and cache them for notebooks/TOT/Time-Curve.ipynb | ||||
| paths = [root_dir / "outputs" / "qlib-baselines-csi300"] | ||||
| paths = [path.resolve() for path in paths] | ||||
| print(paths) | ||||
|  | ||||
| key_map = dict() | ||||
| for xset in ("train", "valid", "test"): | ||||
|     key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset) | ||||
|     key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset) | ||||
| qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map) | ||||
| print("Find {:} results".format(len(qresults))) | ||||
| times = [] | ||||
| for qresult in qresults: | ||||
|     times.append(qresult.name.split("0_0s")[-1]) | ||||
| print(times) | ||||
| save_path = os.path.join(note_dir, "temp-time-x.pth") | ||||
| torch.save(qresults, save_path) | ||||
| print(save_path) | ||||
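For clarity, the date2IC accumulation in query_info above converts each recorder's "all-IC" pandas Series (indexed by Timestamp) into an ordered date-to-IC mapping and refuses duplicate dates; a self-contained sketch with synthetic values:

from collections import OrderedDict
import pandas as pd

def add_to_dict(xdict, timestamp, value):
    # Same duplicate-date guard as in time-curve.py above.
    date = timestamp.date().strftime("%Y-%m-%d")
    if date in xdict:
        raise ValueError("This date [{:}] is already in the dict".format(date))
    xdict[date] = value

ic_series = pd.Series(
    [0.04, 0.06, -0.01],
    index=pd.to_datetime(["2021-04-12", "2021-04-13", "2021-04-14"]),
)  # synthetic daily ICs
date2IC = OrderedDict()
for timestamp, value in zip(ic_series.index.tolist(), ic_series.tolist()):
    add_to_dict(date2IC, timestamp, value)
print(list(date2IC.items()))  # [('2021-04-12', 0.04), ('2021-04-13', 0.06), ('2021-04-14', -0.01)]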
| @@ -0,0 +1,102 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "library path: /Users/xuanyidong/Desktop/XAutoDL/lib\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "#####################################################\n", | ||||
|     "# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #\n", | ||||
|     "#####################################################\n", | ||||
|     "import abc, os, sys\n", | ||||
|     "from pathlib import Path\n", | ||||
|     "\n", | ||||
|     "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", | ||||
|     "\n", | ||||
|     "lib_dir = (Path(__file__).parent / \"..\" / \"lib\").resolve()\n", | ||||
|     "print(\"library path: {:}\".format(lib_dir))\n", | ||||
|     "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", | ||||
|     "if str(lib_dir) not in sys.path:\n", | ||||
|     "    sys.path.insert(0, str(lib_dir))" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "1.7.0\n", | ||||
|       "True\n", | ||||
|       "OrderedDict()\n", | ||||
|       "OrderedDict()\n", | ||||
|       "set()\n", | ||||
|       "OrderedDict()\n", | ||||
|       "OrderedDict()\n", | ||||
|       "OrderedDict()\n", | ||||
|       "OrderedDict()\n", | ||||
|       "OrderedDict()\n", | ||||
|       "OrderedDict()\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "name": "stderr", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "/Users/xuanyidong/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py:551: UserWarning: Setting attributes on ParameterDict is not supported.\n", | ||||
|       "  warnings.warn(\"Setting attributes on ParameterDict is not supported.\")\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "# Test the Linear layer\n", | ||||
|     "import spaces\n", | ||||
|     "import torch\n", | ||||
|     "from xlayers import super_core\n", | ||||
|     "\n", | ||||
|     "print(torch.__version__)\n", | ||||
|     "mlp = super_core.SuperMLPv2(10, 12, 32)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.3" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 4 | ||||
| } | ||||
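A possible follow-up to the SuperMLPv2 cell above: it assumes, based on the printed module repr (in_features=10, hidden_multiplier=12, out_features=32), that the default run mode executes the full super-network and maps the last dimension from 10 to 32. This is an inference from the output in that notebook, not a documented guarantee of the xlayers API:

# Hedged sketch: forward a random batch through the SuperMLPv2 created above.
import torch
from xlayers import super_core   # same import style as the notebook above

mlp = super_core.SuperMLPv2(10, 12, 32)
inputs = torch.rand(2, 10)       # (batch, in_features) -- assumed input layout
outputs = mlp(inputs)
print(outputs.shape)             # expected torch.Size([2, 32]) under the assumptions above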
							
								
								
									
AutoDL-Projects/notebooks/spaces-xmisc/scheduler.ipynb (119 lines, Normal file; diff suppressed because one or more lines are too long)
AutoDL-Projects/notebooks/spaces-xmisc/synthetic-data.ipynb (110 lines, Normal file; diff suppressed because one or more lines are too long)
AutoDL-Projects/notebooks/spaces-xmisc/synthetic-env.ipynb (129 lines, Normal file; diff suppressed because one or more lines are too long)
							| @@ -0,0 +1,152 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "id": "filled-multiple", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n", | ||||
|       "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "import os, sys\n", | ||||
|     "import torch\n", | ||||
|     "from pathlib import Path\n", | ||||
|     "import numpy as np\n", | ||||
|     "import matplotlib\n", | ||||
|     "from matplotlib import cm\n", | ||||
|     "matplotlib.use(\"agg\")\n", | ||||
|     "import matplotlib.pyplot as plt\n", | ||||
|     "import matplotlib.ticker as ticker\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n", | ||||
|     "root_dir = (Path(__file__).parent / \"..\").resolve()\n", | ||||
|     "lib_dir = (root_dir / \"lib\").resolve()\n", | ||||
|     "print(\"The root path: {:}\".format(root_dir))\n", | ||||
|     "print(\"The library path: {:}\".format(lib_dir))\n", | ||||
|     "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n", | ||||
|     "if str(lib_dir) not in sys.path:\n", | ||||
|     "    sys.path.insert(0, str(lib_dir))\n", | ||||
|     "\n", | ||||
|     "from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv\n", | ||||
|     "from datasets import DynamicQuadraticFunc\n", | ||||
|     "from datasets.synthetic_example import create_example_v1" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "id": "detected-second", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def draw_fig(save_dir, timestamp, xaxis, yaxis):\n", | ||||
|     "    save_path = save_dir / '{:04d}'.format(timestamp)\n", | ||||
|     "    # print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))\n", | ||||
|     "    dpi, width, height = 40, 1500, 1500\n", | ||||
|     "    figsize = width / float(dpi), height / float(dpi)\n", | ||||
|     "    LabelSize, LegendFontsize, font_gap = 80, 80, 5\n", | ||||
|     "\n", | ||||
|     "    fig = plt.figure(figsize=figsize)\n", | ||||
|     "    \n", | ||||
|     "    cur_ax = fig.add_subplot(1, 1, 1)\n", | ||||
|     "    cur_ax.scatter(xaxis, yaxis, color=\"k\", s=10, alpha=0.9, label=\"Timestamp={:02d}\".format(timestamp))\n", | ||||
|     "    cur_ax.set_xlabel(\"X\", fontsize=LabelSize)\n", | ||||
|     "    cur_ax.set_ylabel(\"f(X)\", rotation=0, fontsize=LabelSize)\n", | ||||
|     "    cur_ax.set_xlim(-6, 6)\n", | ||||
|     "    cur_ax.set_ylim(-40, 40)\n", | ||||
|     "    for tick in cur_ax.xaxis.get_major_ticks():\n", | ||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||
|     "        tick.label.set_rotation(10)\n", | ||||
|     "    for tick in cur_ax.yaxis.get_major_ticks():\n", | ||||
|     "        tick.label.set_fontsize(LabelSize - font_gap)\n", | ||||
|     "        \n", | ||||
|     "    plt.legend(loc=1, fontsize=LegendFontsize)\n", | ||||
|     "    fig.savefig(str(save_path) + '.pdf', dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n", | ||||
|     "    fig.savefig(str(save_path) + '.png', dpi=dpi, bbox_inches=\"tight\", format=\"png\")\n", | ||||
|     "    plt.close(\"all\")\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "def visualize_env(save_dir):\n", | ||||
|     "    save_dir.mkdir(parents=True, exist_ok=True)\n", | ||||
|     "    dynamic_env, function = create_example_v1(100, num_per_task=500)\n", | ||||
|     "    \n", | ||||
|     "    additional_xaxis = np.arange(-6, 6, 0.1)\n", | ||||
|     "    for timestamp, dataset in dynamic_env:\n", | ||||
|     "        num = dataset.shape[0]\n", | ||||
|     "        # timeaxis = (torch.zeros(num) + timestamp).numpy()\n", | ||||
|     "        xaxis = dataset[:,0].numpy()\n", | ||||
|     "        xaxis = np.concatenate((additional_xaxis, xaxis))\n", | ||||
|     "        # compute the ground truth\n", | ||||
|     "        function.set_timestamp(timestamp)\n", | ||||
|     "        yaxis = function(xaxis)\n", | ||||
|     "        draw_fig(save_dir, timestamp, xaxis, yaxis)\n", | ||||
|     "\n", | ||||
|     "home_dir = Path.home()\n", | ||||
|     "desktop_dir = home_dir / 'Desktop'\n", | ||||
|     "vis_save_dir = desktop_dir / 'vis-synthetic'\n", | ||||
|     "visualize_env(vis_save_dir)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "id": "rapid-uruguay", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "ffmpeg -y -i /Users/xuanyidong/Desktop/vis-synthetic/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k /Users/xuanyidong/Desktop/vis-synthetic/vis.mp4\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "0" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 3, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "# Plot the data\n", | ||||
|     "cmd = 'ffmpeg -y -i {:}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=1000:1000 -vb 5000k {:}/vis.mp4'.format(vis_save_dir, vis_save_dir)\n", | ||||
|     "print(cmd)\n", | ||||
|     "os.system(cmd)" | ||||
|    ] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.8" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 5 | ||||
| } | ||||
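One caveat about the ffmpeg command in the notebook above: when -vf is passed twice, ffmpeg keeps only the last filter, so fps=2 is effectively ignored and only the scale filter runs. If both filters are intended, they can be chained inside a single -vf expression, e.g.:

import os
from pathlib import Path

vis_save_dir = Path.home() / 'Desktop' / 'vis-synthetic'   # same directory as in the cell above
cmd = ('ffmpeg -y -i {:}/%04d.png -pix_fmt yuv420p '
       '-vf "fps=2,scale=1000:1000" -vb 5000k {:}/vis.mp4'.format(vis_save_dir, vis_save_dir))
print(cmd)
# os.system(cmd)  # uncomment to render the video with both filters applied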
| @@ -0,0 +1,277 @@ | ||||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "id": "3f754c96", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "import torch\n", | ||||
|     "from xautodl import spaces\n", | ||||
|     "from xautodl.xlayers import super_core\n", | ||||
|     "\n", | ||||
|     "def _create_stel(input_dim, output_dim, order):\n", | ||||
|     "    return super_core.SuperSequential(\n", | ||||
|     "        super_core.SuperLinear(input_dim, output_dim),\n", | ||||
|     "        super_core.SuperTransformerEncoderLayer(\n", | ||||
|     "            output_dim,\n", | ||||
|     "            num_heads=spaces.Categorical(2, 4, 6),\n", | ||||
|     "            mlp_hidden_multiplier=spaces.Categorical(1, 2, 4),\n", | ||||
|     "            order=order,\n", | ||||
|     "        ),\n", | ||||
|     "    )" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "id": "81d42f4b", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "batch, seq_dim, input_dim = 1, 4, 6\n", | ||||
|     "order = super_core.LayerOrder.PreNorm" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "id": "8056b37c", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "SuperSequential(\n", | ||||
|       "  (0): SuperSequential(\n", | ||||
|       "    (0): SuperLinear(in_features=6, out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=True)\n", | ||||
|       "    (1): SuperTransformerEncoderLayer(\n", | ||||
|       "      (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[12, 24, 36], default_index=None), eps=1e-06, elementwise_affine=True)\n", | ||||
|       "      (mha): SuperSelfAttention(\n", | ||||
|       "        input_dim=Categorical(candidates=[12, 24, 36], default_index=None), proj_dim=Categorical(candidates=[12, 24, 36], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n", | ||||
|       "        (q_fc): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=False)\n", | ||||
|       "        (k_fc): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=False)\n", | ||||
|       "        (v_fc): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), bias=False)\n", | ||||
|       "        (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n", | ||||
|       "      )\n", | ||||
|       "      (drop): Dropout(p=0.0, inplace=False)\n", | ||||
|       "      (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[12, 24, 36], default_index=None), eps=1e-06, elementwise_affine=True)\n", | ||||
|       "      (mlp): SuperMLPv2(\n", | ||||
|       "        in_features=Categorical(candidates=[12, 24, 36], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[12, 24, 36], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n", | ||||
|       "        (_params): ParameterDict(\n", | ||||
|       "            (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 144x36]\n", | ||||
|       "            (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 144]\n", | ||||
|       "            (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 36x144]\n", | ||||
|       "            (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 36]\n", | ||||
|       "        )\n", | ||||
|       "        (act): GELU()\n", | ||||
|       "        (drop): Dropout(p=0.0, inplace=False)\n", | ||||
|       "      )\n", | ||||
|       "    )\n", | ||||
|       "  )\n", | ||||
|       "  (1): SuperSequential(\n", | ||||
|       "    (0): SuperLinear(in_features=Categorical(candidates=[12, 24, 36], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=True)\n", | ||||
|       "    (1): SuperTransformerEncoderLayer(\n", | ||||
|       "      (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[24, 36, 48], default_index=None), eps=1e-06, elementwise_affine=True)\n", | ||||
|       "      (mha): SuperSelfAttention(\n", | ||||
|       "        input_dim=Categorical(candidates=[24, 36, 48], default_index=None), proj_dim=Categorical(candidates=[24, 36, 48], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n", | ||||
|       "        (q_fc): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=False)\n", | ||||
|       "        (k_fc): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=False)\n", | ||||
|       "        (v_fc): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), bias=False)\n", | ||||
|       "        (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n", | ||||
|       "      )\n", | ||||
|       "      (drop): Dropout(p=0.0, inplace=False)\n", | ||||
|       "      (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[24, 36, 48], default_index=None), eps=1e-06, elementwise_affine=True)\n", | ||||
|       "      (mlp): SuperMLPv2(\n", | ||||
|       "        in_features=Categorical(candidates=[24, 36, 48], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[24, 36, 48], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n", | ||||
|       "        (_params): ParameterDict(\n", | ||||
|       "            (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 192x48]\n", | ||||
|       "            (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 192]\n", | ||||
|       "            (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 48x192]\n", | ||||
|       "            (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 48]\n", | ||||
|       "        )\n", | ||||
|       "        (act): GELU()\n", | ||||
|       "        (drop): Dropout(p=0.0, inplace=False)\n", | ||||
|       "      )\n", | ||||
|       "    )\n", | ||||
|       "  )\n", | ||||
|       "  (2): SuperSequential(\n", | ||||
|       "    (0): SuperLinear(in_features=Categorical(candidates=[24, 36, 48], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=True)\n", | ||||
|       "    (1): SuperTransformerEncoderLayer(\n", | ||||
|       "      (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n", | ||||
|       "      (mha): SuperSelfAttention(\n", | ||||
|       "        input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n", | ||||
|       "        (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "        (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "        (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "        (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n", | ||||
|       "      )\n", | ||||
|       "      (drop): Dropout(p=0.0, inplace=False)\n", | ||||
|       "      (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n", | ||||
|       "      (mlp): SuperMLPv2(\n", | ||||
|       "        in_features=Categorical(candidates=[36, 72, 100], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n", | ||||
|       "        (_params): ParameterDict(\n", | ||||
|       "            (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 400x100]\n", | ||||
|       "            (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 400]\n", | ||||
|       "            (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 100x400]\n", | ||||
|       "            (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 100]\n", | ||||
|       "        )\n", | ||||
|       "        (act): GELU()\n", | ||||
|       "        (drop): Dropout(p=0.0, inplace=False)\n", | ||||
|       "      )\n", | ||||
|       "    )\n", | ||||
|       "  )\n", | ||||
|       ")\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "out1_dim = spaces.Categorical(12, 24, 36)\n", | ||||
|     "out2_dim = spaces.Categorical(24, 36, 48)\n", | ||||
|     "out3_dim = spaces.Categorical(36, 72, 100)\n", | ||||
|     "layer1 = _create_stel(input_dim, out1_dim, order)\n", | ||||
|     "layer2 = _create_stel(out1_dim, out2_dim, order)\n", | ||||
|     "layer3 = _create_stel(out2_dim, out3_dim, order)\n", | ||||
|     "model = super_core.SuperSequential(layer1, layer2, layer3)\n", | ||||
|     "print(model)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "4fd53a7c", | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "> \u001b[0;32m/Users/xuanyidong/anaconda3/lib/python3.8/site-packages/xautodl-0.9.9-py3.8.egg/xautodl/xlayers/super_transformer.py\u001b[0m(116)\u001b[0;36mforward_raw\u001b[0;34m()\u001b[0m\n", | ||||
|       "\u001b[0;32m    114 \u001b[0;31m              \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||||
|       "\u001b[0m\u001b[0;32m    115 \u001b[0;31m            \u001b[0;31m# feed-forward layer -- MLP\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||||
|       "\u001b[0m\u001b[0;32m--> 116 \u001b[0;31m            \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||||
|       "\u001b[0m\u001b[0;32m    117 \u001b[0;31m            \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmlp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||||
|       "\u001b[0m\u001b[0;32m    118 \u001b[0;31m        \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_order\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mLayerOrder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPostNorm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||||
|       "\u001b[0m\n", | ||||
|       "ipdb> print(self)\n", | ||||
|       "SuperTransformerEncoderLayer(\n", | ||||
|       "  (norm1): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n", | ||||
|       "  (mha): SuperSelfAttention(\n", | ||||
|       "    input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n", | ||||
|       "    (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "    (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "    (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "    (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n", | ||||
|       "  )\n", | ||||
|       "  (drop): Dropout(p=0.0, inplace=False)\n", | ||||
|       "  (norm2): SuperLayerNorm1D(shape=Categorical(candidates=[36, 72, 100], default_index=None), eps=1e-06, elementwise_affine=True)\n", | ||||
|       "  (mlp): SuperMLPv2(\n", | ||||
|       "    in_features=Categorical(candidates=[36, 72, 100], default_index=None), hidden_multiplier=Categorical(candidates=[1, 2, 4], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), drop=None, fc1 -> act -> drop -> fc2 -> drop,\n", | ||||
|       "    (_params): ParameterDict(\n", | ||||
|       "        (fc1_super_weight): Parameter containing: [torch.FloatTensor of size 400x100]\n", | ||||
|       "        (fc1_super_bias): Parameter containing: [torch.FloatTensor of size 400]\n", | ||||
|       "        (fc2_super_weight): Parameter containing: [torch.FloatTensor of size 100x400]\n", | ||||
|       "        (fc2_super_bias): Parameter containing: [torch.FloatTensor of size 100]\n", | ||||
|       "    )\n", | ||||
|       "    (act): GELU()\n", | ||||
|       "    (drop): Dropout(p=0.0, inplace=False)\n", | ||||
|       "  )\n", | ||||
|       ")\n", | ||||
|       "ipdb> print(inputs.shape)\n", | ||||
|       "torch.Size([1, 4, 100])\n", | ||||
|       "ipdb> print(x.shape)\n", | ||||
|       "torch.Size([1, 4, 96])\n", | ||||
|       "ipdb> print(self.mha)\n", | ||||
|       "SuperSelfAttention(\n", | ||||
|       "  input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n", | ||||
|       "  (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "  (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "  (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "  (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n", | ||||
|       ")\n", | ||||
|       "ipdb> print(self.mha.candidate)\n", | ||||
|       "*** AttributeError: 'SuperSelfAttention' object has no attribute 'candidate'\n", | ||||
|       "ipdb> print(self.mha.abstract_candidate)\n", | ||||
|       "*** AttributeError: 'SuperSelfAttention' object has no attribute 'abstract_candidate'\n", | ||||
|       "ipdb> print(self.mha._abstract_child)\n", | ||||
|       "None\n", | ||||
|       "ipdb> print(self.abstract_child)\n", | ||||
|       "None\n", | ||||
|       "ipdb> print(self.abstract_child.abstract_child)\n", | ||||
|       "*** AttributeError: 'NoneType' object has no attribute 'abstract_child'\n", | ||||
|       "ipdb> print(self.mha)\n", | ||||
|       "SuperSelfAttention(\n", | ||||
|       "  input_dim=Categorical(candidates=[36, 72, 100], default_index=None), proj_dim=Categorical(candidates=[36, 72, 100], default_index=None), num_heads=Categorical(candidates=[2, 4, 6], default_index=None), mask=False, infinity=1000000000.0\n", | ||||
|       "  (q_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "  (k_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "  (v_fc): SuperLinear(in_features=Categorical(candidates=[36, 72, 100], default_index=None), out_features=Categorical(candidates=[36, 72, 100], default_index=None), bias=False)\n", | ||||
|       "  (attn_drop): SuperDrop(p=0.0, dims=[-1, -1, -1, -1], recover=True)\n", | ||||
|       ")\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "inputs = torch.rand(batch, seq_dim, input_dim)\n", | ||||
|     "outputs = model(inputs)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "05332b98", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "abstract_space = model.abstract_search_space\n", | ||||
|     "abstract_space.clean_last()\n", | ||||
|     "abstract_child = abstract_space.random(reuse_last=True)\n", | ||||
|     "# print(\"The abstract child program is:\\n{:}\".format(abstract_child))\n", | ||||
|     "model.enable_candidate()\n", | ||||
|     "model.set_super_run_type(super_core.SuperRunMode.Candidate)\n", | ||||
|     "model.apply_candidate(abstract_child)\n", | ||||
|     "outputs = model(inputs)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "3289f938", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "print(outputs.shape)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "id": "36951cdf", | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": "Python 3", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.8.8" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 5 | ||||
| } | ||||