130 lines
4.2 KiB
Python
130 lines
4.2 KiB
Python
import os
|
|
import re
|
|
import sys
|
|
import torch
|
|
import qlib
|
|
import pprint
|
|
from collections import OrderedDict
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from pathlib import Path
|
|
|
|
# __file__ = os.path.dirname(os.path.realpath("__file__"))
|
|
# Directory holding this script, the directory two levels above it, and the
# "lib" directory underneath that root.
note_dir = Path(__file__).parent.resolve()
root_dir = Path(__file__).parent.joinpath("..", "..").resolve()
lib_dir = root_dir.joinpath("lib").resolve()
print("The root path: {:}".format(root_dir))
print("The library path: {:}".format(lib_dir))
assert lib_dir.exists(), "{:} does not exist".format(lib_dir)

# Put the library directory at the front of the import search path, once.
_lib_dir_entry = str(lib_dir)
if _lib_dir_entry not in sys.path:
    sys.path.insert(0, _lib_dir_entry)

import qlib
from qlib import config as qconfig
from qlib.workflow import R

# Initialise qlib against the local CN dataset before any recorder is queried.
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)

# NOTE(review): this import is deliberately placed after the sys.path update
# above; presumably `utils` is resolved from lib_dir -- confirm.
from utils.qlib_utils import QResult
|
|
|
|
|
|
def filter_finished(recorders):
    """Keep only the recorders whose status is "FINISHED".

    Args:
        recorders: mapping from a key to an object exposing a ``status``
            attribute.

    Returns:
        A ``(finished, not_finished)`` tuple: ``finished`` is a new dict with
        the entries whose ``status == "FINISHED"`` (original order kept), and
        ``not_finished`` is the number of entries that were left out.
    """
    finished = {
        name: entry
        for name, entry in recorders.items()
        if entry.status == "FINISHED"
    }
    # Every entry is either kept or counted as unfinished, so the number of
    # unfinished ones is simply what is missing from the filtered dict.
    return finished, len(recorders) - len(finished)
|
|
|
|
|
|
def add_to_dict(xdict, timestamp, value):
    """Store ``value`` in ``xdict`` under the calendar day of ``timestamp``.

    Args:
        xdict: dict mutated in place; keys are "YYYY-MM-DD" strings.
        timestamp: object providing ``.date()`` (the result is formatted with
            ``strftime("%Y-%m-%d")``).
        value: the value to store for that day.

    Raises:
        ValueError: if the day is already present in ``xdict``; existing
            entries are never overwritten.
    """
    day_key = timestamp.date().strftime("%Y-%m-%d")
    if day_key not in xdict:
        xdict[day_key] = value
        return
    raise ValueError("This date [{:}] is already in the dict".format(day_key))
|
|
|
|
|
|
def query_info(save_dir, verbose, name_filter, key_map):
    """Collect one QResult per matching experiment stored under ``save_dir``.

    Args:
        save_dir: a single location (passed through ``str()`` to
            ``R.set_uri``) or a list of such locations. A list is handled by
            querying every element and concatenating the results.
        verbose: when truthy, print progress information.
        name_filter: regular expression that must fully match the experiment
            name (``re.fullmatch``); ``None`` disables the filtering.
        key_map: forwarded unchanged to ``QResult.update`` together with the
            recorder's metrics.

    Returns:
        A list of QResult objects, one for each experiment that passed the
        filters and ended up with a non-zero ``len``.

    Raises:
        ValueError: propagated from ``add_to_dict`` when the same date shows
            up twice among the train/valid/test "all-IC" series of a recorder.

    Note:
        The "no valid recorders" message is printed regardless of ``verbose``
        and formats the experiment object itself, not its name.
    """
    # A list of locations: recurse on each one and flatten.
    if isinstance(save_dir, list):
        merged = []
        for one_dir in save_dir:
            merged.extend(query_info(one_dir, verbose, name_filter, key_map))
        return merged

    # From here on save_dir is a single location.
    R.set_uri(str(save_dir))
    experiments = R.list_experiments()
    num_experiments = len(experiments)
    if verbose:
        print("There are {:} experiments.".format(num_experiments))

    # Per-recorder artifacts that each carry an "all-IC" series.
    result_files = ("results-train.pkl", "results-valid.pkl", "results-test.pkl")

    collected = []
    for position, experiment in enumerate(experiments.values(), start=1):
        # NOTE(review): id "0" is presumably the default experiment -- confirm.
        if experiment.id == "0":
            continue
        name_rejected = (
            name_filter is not None
            and re.fullmatch(name_filter, experiment.name) is None
        )
        if name_rejected:
            continue

        finished, unfinished_count = filter_finished(experiment.list_recorders())
        if verbose:
            print(
                "====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format(
                    position,
                    num_experiments,
                    experiment.name,
                    len(finished),
                    len(finished) + unfinished_count,
                )
            )

        summary = QResult(experiment.name)
        for recorder in finished.values():
            # Merge the IC series of all three splits into one date-keyed,
            # insertion-ordered mapping for this recorder.
            ic_by_date = OrderedDict()
            for file_name in result_files:
                ic_series = recorder.load_object(file_name)["all-IC"]
                stamps = ic_series.index.tolist()
                ic_values = ic_series.tolist()
                for timestamp, value in zip(stamps, ic_values):
                    add_to_dict(ic_by_date, timestamp, value)
            summary.update(recorder.list_metrics(), key_map)
            recorder_path = os.path.join(
                recorder.uri, recorder.experiment_id, recorder.id
            )
            summary.append_path(recorder_path)
            summary.append_date2ICs(ic_by_date)

        if len(summary) == 0:
            print("There are no valid recorders for {:}".format(experiment))
            continue
        if verbose:
            print(
                "There are {:} valid recorders for {:}".format(
                    len(finished), experiment.name
                )
            )
        collected.append(summary)
    return collected
|
|
|
|
|
|
##
|
|
# Output directories to scan, as resolved absolute paths.
paths = [root_dir / "outputs" / "qlib-baselines-csi300"]
paths = list(map(Path.resolve, paths))
print(paths)

# Map the raw metric names of every split to the labels handed to query_info.
key_map = {}
for xset in ("train", "valid", "test"):
    key_map.update(
        {
            "{:}-mean-IC".format(xset): "IC ({:})".format(xset),
            "{:}-mean-ICIR".format(xset): "ICIR ({:})".format(xset),
        }
    )

qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map)
print("Find {:} results".format(len(qresults)))

# The part of each experiment name that follows the last "0_0s" marker.
times = [qresult.name.split("0_0s")[-1] for qresult in qresults]
print(times)

# Persist the collected results next to this script.
save_path = str(note_dir / "temp-time-x.pth")
torch.save(qresults, save_path)
print(save_path)
|