{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "afraid-minutes",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
      "The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[70363:MainThread](2021-04-12 13:25:01,065) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
      "[70363:MainThread](2021-04-12 13:25:01,069) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
      "[70363:MainThread](2021-04-12 13:25:01,085) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
      "[70363:MainThread](2021-04-12 13:25:01,092) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
     ]
    }
   ],
   "source": [
    "#\n",
    "# Exhaustive Search Results\n",
    "#\n",
    "import os\n",
    "import re\n",
    "import sys\n",
    "import qlib\n",
    "import pprint\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
    "root_dir = (Path(__file__).parent / \"..\").resolve()\n",
    "lib_dir = (root_dir / \"lib\").resolve()\n",
    "print(\"The root path: {:}\".format(root_dir))\n",
    "print(\"The library path: {:}\".format(lib_dir))\n",
    "assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
    "if str(lib_dir) not in sys.path:\n",
    "    sys.path.insert(0, str(lib_dir))\n",
    "\n",
    "import qlib\n",
    "from qlib import config as qconfig\n",
    "from qlib.workflow import R\n",
    "qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "hidden-exemption",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.qlib_utils import QResult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "continental-drain",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_finished(recorders):\n",
    "    returned_recorders = dict()\n",
    "    not_finished = 0\n",
    "    for key, recorder in recorders.items():\n",
    "        if recorder.status == \"FINISHED\":\n",
    "            returned_recorders[key] = recorder\n",
    "        else:\n",
    "            not_finished += 1\n",
    "    return returned_recorders, not_finished\n",
    "\n",
    "def query_info(save_dir, verbose, name_filter, key_map):\n",
    "    if isinstance(save_dir, list):\n",
    "        results = []\n",
    "        for x in save_dir:\n",
    "            x = query_info(x, verbose, name_filter, key_map)\n",
    "            results.extend(x)\n",
    "        return results\n",
    "    # Here, the save_dir must be a string\n",
    "    R.set_uri(str(save_dir))\n",
    "    experiments = R.list_experiments()\n",
    "\n",
    "    if verbose:\n",
    "        print(\"There are {:} experiments.\".format(len(experiments)))\n",
    "    qresults = []\n",
    "    for idx, (key, experiment) in enumerate(experiments.items()):\n",
    "        if experiment.id == \"0\":\n",
    "            continue\n",
    "        if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n",
    "            continue\n",
    "        recorders = experiment.list_recorders()\n",
    "        recorders, not_finished = filter_finished(recorders)\n",
    "        if verbose:\n",
    "            print(\n",
    "                \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n",
    "                    idx + 1,\n",
    "                    len(experiments),\n",
    "                    experiment.name,\n",
    "                    len(recorders),\n",
    "                    len(recorders) + not_finished,\n",
    "                )\n",
    "            )\n",
    "        result = QResult(experiment.name)\n",
    "        for recorder_id, recorder in recorders.items():\n",
    "            result.update(recorder.list_metrics(), key_map)\n",
    "            result.append_path(\n",
    "                os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n",
    "            )\n",
    "        if not len(result):\n",
    "            print(\"There are no valid recorders for {:}\".format(experiment))\n",
    "            continue\n",
    "        else:\n",
    "            if verbose:\n",
    "                print(\n",
    "                    \"There are {:} valid recorders for {:}\".format(\n",
    "                        len(recorders), experiment.name\n",
    "                    )\n",
    "                )\n",
    "        qresults.append(result)\n",
    "    return qresults"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "filled-multiple",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[70363:MainThread](2021-04-12 13:25:01,647) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fa920e56820>\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n"
     ]
    }
   ],
   "source": [
    "paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n",
    "paths = [path.resolve() for path in paths]\n",
    "print(paths)\n",
    "\n",
    "key_map = dict()\n",
    "for xset in (\"train\", \"valid\", \"test\"):\n",
    "    key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n",
    "    key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n",
    "\n",
    "qresults = query_info(paths, False, 'TSF-.*', key_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "intimate-approval",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "from matplotlib import cm\n",
    "matplotlib.use(\"agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "supreme-basis",
   "metadata": {},
   "outputs": [],
   "source": [
    "def vis_dropouts(qresults, basenames, name2suffix, save_path):\n",
    "    save_dir = (save_path / '..').resolve()\n",
    "    save_dir.mkdir(parents=True, exist_ok=True)\n",
    "    print('There are {:} qlib-results'.format(len(qresults)))\n",
    "    \n",
    "    name2qresult = dict()\n",
    "    for qresult in qresults:\n",
    "        name2qresult[qresult.name] = qresult\n",
    "    # sort architectures\n",
    "    accuracies = []\n",
    "    for basename in basenames:\n",
    "        qresult = name2qresult[basename + '-drop0_0']\n",
    "        accuracies.append(qresult['ICIR (train)'])\n",
    "    sorted_basenames = sorted(basenames, key=lambda x: accuracies[basenames.index(x)])\n",
    "    \n",
    "    dpi, width, height = 200, 4000, 2000\n",
    "    figsize = width / float(dpi), height / float(dpi)\n",
    "    LabelSize, LegendFontsize = 22, 22\n",
    "    font_gap = 5\n",
    "    colors = ['k', 'r']\n",
    "    markers = ['*', 'o']\n",
    "    \n",
    "    fig = plt.figure(figsize=figsize)\n",
    "    \n",
    "    def plot_ax(cur_ax, train_or_test):\n",
    "        for idx, (legend, suffix) in enumerate(name2suffix.items()):\n",
    "            x_values = list(range(len(sorted_basenames)))\n",
    "            y_values = []\n",
    "            for i, name in enumerate(sorted_basenames):\n",
    "                name = '{:}{:}'.format(name, suffix)\n",
    "                qresult = name2qresult[name]\n",
    "                if train_or_test:\n",
    "                    value = qresult['IC (train)']\n",
    "                else:\n",
    "                    value = qresult['IC (valid)']\n",
    "                y_values.append(value)\n",
    "            cur_ax.plot(x_values, y_values, c=colors[idx])\n",
    "            cur_ax.scatter(x_values, y_values,\n",
    "                           marker=markers[idx], s=3, c=colors[idx], alpha=0.9,\n",
    "                           label=legend)\n",
    "        cur_ax.set_yticks(np.arange(4, 11, 2))\n",
    "        cur_ax.set_xlabel(\"sorted architectures\", fontsize=LabelSize)\n",
    "        cur_ax.set_ylabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
    "        for tick in cur_ax.xaxis.get_major_ticks():\n",
    "            tick.label.set_fontsize(LabelSize - font_gap)\n",
    "        for tick in cur_ax.yaxis.get_major_ticks():\n",
    "            tick.label.set_fontsize(LabelSize - font_gap)\n",
    "        cur_ax.legend(loc=4, fontsize=LegendFontsize)\n",
    "    ax = fig.add_subplot(1, 2, 1)\n",
    "    plot_ax(ax, True)\n",
    "    ax = fig.add_subplot(1, 2, 2)\n",
    "    plot_ax(ax, False)\n",
    "    # fig.tight_layout()\n",
    "    # plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
    "    fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
    "    plt.close(\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "shared-envelope",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'TSF-3x48', 'TSF-2x64', 'TSF-2x12', 'TSF-8x48', 'TSF-6x32', 'TSF-4x48', 'TSF-8x6', 'TSF-4x6', 'TSF-2x32', 'TSF-5x12', 'TSF-5x64', 'TSF-1x64', 'TSF-2x24', 'TSF-8x24', 'TSF-4x12', 'TSF-6x12', 'TSF-1x32', 'TSF-5x32', 'TSF-3x24', 'TSF-8x12', 'TSF-5x48', 'TSF-6x64', 'TSF-7x64', 'TSF-7x48', 'TSF-1x6', 'TSF-2x48', 'TSF-7x24', 'TSF-3x32', 'TSF-1x24', 'TSF-4x64', 'TSF-3x12', 'TSF-8x64', 'TSF-4x32', 'TSF-5x6', 'TSF-7x6', 'TSF-7x12', 'TSF-3x6', 'TSF-4x24', 'TSF-6x48', 'TSF-6x6', 'TSF-1x48', 'TSF-1x12', 'TSF-7x32', 'TSF-5x24', 'TSF-2x6', 'TSF-6x24', 'TSF-3x64', 'TSF-8x32'}\n",
      "The Desktop is at: /Users/xuanyidong/Desktop\n",
      "There are 104 qlib-results\n"
     ]
    }
   ],
   "source": [
    "# Visualization\n",
    "names = [qresult.name for qresult in qresults]\n",
    "base_names = set()\n",
    "for name in names:\n",
    "    base_name = name.split('-drop')[0]\n",
    "    base_names.add(base_name)\n",
    "print(base_names)\n",
    "# filter\n",
    "filtered_base_names = set()\n",
    "for base_name in base_names:\n",
    "    if (base_name + '-drop0_0') in names and (base_name + '-drop0.1_0') in names:\n",
    "        filtered_base_names.add(base_name)\n",
    "    else:\n",
    "        print('Cannot find all names for {:}'.format(base_name))\n",
    "# print(filtered_base_names)\n",
    "home_dir = Path.home()\n",
    "desktop_dir = home_dir / 'Desktop'\n",
    "print('The Desktop is at: {:}'.format(desktop_dir))\n",
    "\n",
    "vis_dropouts(qresults, list(filtered_base_names),\n",
    "             {'No-dropout': '-drop0_0',\n",
    "              'Ratio=0.1' : '-drop0.1_0'},\n",
    "             desktop_dir / 'es_csi300_drop.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}