xautodl/notebooks/Q/workflow-demo.py

76 lines
2.4 KiB
Python
Raw Normal View History

2021-03-03 04:52:10 +01:00
import os
import sys
2021-03-03 14:57:48 +01:00
import ruamel.yaml as yaml
2021-03-03 04:52:10 +01:00
import pprint
import numpy as np
import pandas as pd
2021-03-03 14:57:48 +01:00
from pathlib import Path
2021-03-03 04:52:10 +01:00
2021-03-03 14:57:48 +01:00
import qlib
2021-03-03 04:52:10 +01:00
from qlib import config as qconfig
from qlib.utils import init_instance_by_config
2021-03-03 14:57:48 +01:00
from qlib.workflow import R
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
2021-03-03 04:52:10 +01:00
2021-03-03 14:57:48 +01:00
# Initialize qlib with the REG_CN region setting and the local `cn_data`
# provider directory; this runs at import time, before the config below is used.
qlib.init(
    region=qconfig.REG_CN,
    provider_uri="~/.qlib/qlib_data/cn_data",
)
2021-03-03 04:52:10 +01:00
# Configuration for the dataset object that the `__main__` section builds via
# `init_instance_by_config`. Each level follows the same
# {"class", "module_path", "kwargs"} layout.
dataset_config = {
    "class": "DatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": {
            "class": "Alpha360",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": {
                # Overall date range requested from the handler.
                "start_time": "2008-01-01",
                "end_time": "2020-08-01",
                # Fit window; identical to the "train" segment below.
                "fit_start_time": "2008-01-01",
                "fit_end_time": "2014-12-31",
                "instruments": "csi300",
                # Processors listed for the "feature" field group.
                "infer_processors": [
                    {
                        "class": "RobustZScoreNorm",
                        "kwargs": {"fields_group": "feature", "clip_outlier": True},
                    },
                    {"class": "Fillna", "kwargs": {"fields_group": "feature"}},
                ],
                # Processors listed for the learning data; the second one
                # targets the "label" field group.
                "learn_processors": [
                    {"class": "DropnaLabel"},
                    {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
                ],
                # Label expression string (presumably qlib expression syntax
                # over the close price -- confirm against qlib docs).
                "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
            },
        },
        # Named (start, end) date ranges requested later via `dataset.prepare`.
        "segments": {
            "train": ("2008-01-01", "2014-12-31"),
            "valid": ("2015-01-01", "2016-12-31"),
            "test": ("2017-01-01", "2020-08-01"),
        },
    },
}
2021-03-03 04:52:10 +01:00
if __name__ == "__main__":
    # The qlib checkout is looked up two directories above this file's
    # directory, under `.latent-data/qlib`.
    qlib_root_dir = (Path(__file__).parent / ".." / ".." / ".latent-data" / "qlib").resolve()
    demo_yaml_path = (
        qlib_root_dir / "examples" / "benchmarks" / "GRU" / "workflow_config_gru_Alpha360.yaml"
    )
    print("Demo-workflow-yaml: {:}".format(demo_yaml_path))

    # `ruamel.yaml.safe_load` is deprecated and raises on ruamel.yaml >= 0.18;
    # the `YAML` class API is available on both old and new releases.
    yaml_loader = yaml.YAML(typ="safe", pure=True)
    with open(demo_yaml_path, "r", encoding="utf-8") as fp:
        config = yaml_loader.load(fp)
    # Show the reference dataset section from the benchmark YAML for comparison
    # with the hand-written `dataset_config` printed below.
    pprint.pprint(config["task"]["dataset"])

    # Build the dataset from the module-level config dict.
    dataset = init_instance_by_config(dataset_config)
    pprint.pprint(dataset_config)
    pprint.pprint(dataset)

    # Fetch the three segments with both the "feature" and "label" column sets.
    df_train, df_valid, df_test = dataset.prepare(
        ["train", "valid", "test"],
        col_set=["feature", "label"],
        data_key=DataHandlerLP.DK_L,
    )
    x_train, y_train = df_train["feature"], df_train["label"]

    # Deliberate interactive breakpoint: this demo stops here so the prepared
    # frames (df_train / df_valid / df_test, x_train / y_train) can be inspected.
    import pdb

    pdb.set_trace()
    print("Complete")