# xautodl/notebooks/TOT/time-curve.py
# 2021-05-26 01:53:44 -07:00
#
# 130 lines
# 4.2 KiB
# Python

import os
import re
import sys
import torch
import qlib
import pprint
from collections import OrderedDict
import numpy as np
import pandas as pd
from pathlib import Path
# __file__ = os.path.dirname(os.path.realpath("__file__"))
note_dir = Path(__file__).parent.resolve()
root_dir = (Path(__file__).parent / ".." / "..").resolve()
lib_dir = (root_dir / "lib").resolve()
print("The root path: {:}".format(root_dir))
print("The library path: {:}".format(lib_dir))
assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
import qlib
from qlib import config as qconfig
from qlib.workflow import R
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
from utils.qlib_utils import QResult
def filter_finished(recorders):
    """Keep only the recorders whose status is "FINISHED".

    Args:
        recorders: mapping from a key to a recorder object exposing a
            `status` attribute.

    Returns:
        A tuple `(finished, not_finished)`: `finished` is a new dict with
        the entries whose `status` equals "FINISHED" (original order kept),
        and `not_finished` is the number of entries that were dropped.
    """
    finished = {
        key: recorder
        for key, recorder in recorders.items()
        if recorder.status == "FINISHED"
    }
    # Every entry that did not make it into `finished` is unfinished.
    not_finished = len(recorders) - len(finished)
    return finished, not_finished
def add_to_dict(xdict, timestamp, value):
    """Store `value` in `xdict` under the calendar date of `timestamp`.

    The key is the date formatted as "YYYY-MM-DD"; the time-of-day part of
    `timestamp` is discarded.

    Args:
        xdict: dict that is updated in place.
        timestamp: object providing `.date()` (e.g. `datetime.datetime`).
        value: the value to store.

    Raises:
        ValueError: if the date is already a key of `xdict`; the existing
            entry is left untouched.
    """
    date = timestamp.date().strftime("%Y-%m-%d")
    if date in xdict:
        raise ValueError("This date [{:}] is already in the dict".format(date))
    xdict[date] = value
# Result files read from every finished recorder.
_RESULT_FILE_NAMES = ("results-train.pkl", "results-valid.pkl", "results-test.pkl")


def _load_date2ic(recorder):
    """Merge the "all-IC" entries of all result files of `recorder` by date."""
    date2ic = OrderedDict()
    for file_name in _RESULT_FILE_NAMES:
        series = recorder.load_object(file_name)["all-IC"]
        timestamps, values = series.index.tolist(), series.tolist()
        for timestamp, value in zip(timestamps, values):
            # Raises ValueError if two entries fall on the same date.
            add_to_dict(date2ic, timestamp, value)
    return date2ic


def query_info(save_dir, verbose, name_filter, key_map):
    """Collect one QResult per matching experiment stored under `save_dir`.

    Args:
        save_dir: a single location (converted with `str` and passed to
            `R.set_uri`), or a list/tuple of such locations; sequences are
            processed element by element and their results concatenated.
        verbose: when true, print progress information.
        name_filter: regular expression that must fully match the experiment
            name (`re.fullmatch`), or None to accept every experiment.
        key_map: forwarded to `QResult.update` together with each recorder's
            metrics.

    Returns:
        A list of QResult objects, one for every experiment that passes the
        filter and ends up with a non-zero `len(result)`.
    """
    if isinstance(save_dir, (list, tuple)):
        results = []
        for single_dir in save_dir:
            results.extend(query_info(single_dir, verbose, name_filter, key_map))
        return results
    # From here on, save_dir is a single location.
    R.set_uri(str(save_dir))
    experiments = R.list_experiments()
    if verbose:
        print("There are {:} experiments.".format(len(experiments)))
    qresults = []
    for idx, (_, experiment) in enumerate(experiments.items()):
        # The experiment with id "0" is always skipped
        # (presumably the tracking backend's default experiment -- confirm).
        if experiment.id == "0":
            continue
        if (
            name_filter is not None
            and re.fullmatch(name_filter, experiment.name) is None
        ):
            continue
        recorders, not_finished = filter_finished(experiment.list_recorders())
        if verbose:
            print(
                "====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format(
                    idx + 1,
                    len(experiments),
                    experiment.name,
                    len(recorders),
                    len(recorders) + not_finished,
                )
            )
        result = QResult(experiment.name)
        for recorder in recorders.values():
            date2ic = _load_date2ic(recorder)
            result.update(recorder.list_metrics(), key_map)
            result.append_path(
                os.path.join(recorder.uri, recorder.experiment_id, recorder.id)
            )
            result.append_date2ICs(date2ic)
        if not len(result):
            print("There are no valid recorders for {:}".format(experiment))
            continue
        if verbose:
            print(
                "There are {:} valid recorders for {:}".format(
                    len(recorders), experiment.name
                )
            )
        qresults.append(result)
    return qresults
# Driver: gather the results of the selected experiments and save them next
# to this script.
paths = [(root_dir / "outputs" / "qlib-baselines-csi300").resolve()]
print(paths)

# Map the raw metric names to the display names used in the results.
key_map = dict()
for xset in ("train", "valid", "test"):
    key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset)
    key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset)

qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map)
print("Find {:} results".format(len(qresults)))

# The part of each experiment name that follows the last "0_0s".
times = [qresult.name.split("0_0s")[-1] for qresult in qresults]
print(times)

save_path = os.path.join(note_dir, "temp-time-x.pth")
torch.save(qresults, save_path)
print(save_path)