Update Q workflow

D-X-Y 2021-03-04 13:55:48 +00:00
parent e329b78cf4
commit 192c25eb42
2 changed files with 65 additions and 28 deletions

exps/trading/workflow_tt.py

@@ -1,12 +1,12 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
# Refer to:
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
# python exps/trading/workflow_tt.py
#####################################################
import sys, site, argparse
import sys, argparse
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
@@ -15,19 +15,11 @@ if str(lib_dir) not in sys.path:
import qlib
from qlib.config import C
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.utils import flatten_dict
from qlib.log import set_log_basic_config
def main(xargs):
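Note: the hunk above elides the middle of the path bootstrap; only the lib_dir assignment and the "if str(lib_dir) not in sys.path:" context in the hunk header survive. A minimal sketch of the usual idiom, assuming the standard sys.path completion (the elided line itself is not shown in the diff):

import sys
from pathlib import Path

# Resolve the repository's lib/ directory relative to this script.
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))  # assumed completion of the elided hunk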
@@ -73,13 +65,51 @@ def main(xargs):
},
}
task = {"model": model_config, "dataset": dataset_config}
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.strategy",
"kwargs": {
"topk": 50,
"n_drop": 5,
},
},
"backtest": {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": "SH000300",
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
},
}
record_config = [
{"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": dict(ana_long_short=False, ann_scaler=252),
},
{
"class": "PortAnaRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": dict(config=port_analysis_config),
},
]
task = dict(model=model_config, dataset=dataset_config, record=record_config)
model = init_instance_by_config(model_config)
dataset = init_instance_by_config(dataset_config)
# start exp to train model
with R.start(experiment_name="train_tt_model"):
set_log_basic_config(R.get_recorder().root_uri / 'log.log')
model = init_instance_by_config(model_config)
dataset = init_instance_by_config(dataset_config)
R.log_params(**flatten_dict(task))
model.fit(dataset)
R.save_objects(trained_model=model)
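R.log_params(**flatten_dict(task)) works because qlib.utils.flatten_dict collapses the nested task config into a flat mapping whose keys can be logged as individual parameters. A minimal sketch of the idea (qlib's actual helper may differ in separator and edge-case handling):

def flatten_dict_sketch(d, parent_key="", sep="."):
    # Recursively flatten {"model": {"class": "X"}} into {"model.class": "X"}.
    items = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            items.update(flatten_dict_sketch(value, new_key, sep=sep))
        else:
            items[new_key] = value
    return items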
@@ -87,14 +117,19 @@ def main(xargs):
# prediction
recorder = R.get_recorder()
print(recorder)
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# backtest. If users want to run a backtest based on their own predictions,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()
for record in task["record"]:
record = record.copy()
if record["class"] == "SignalRecord":
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
record["kwargs"].update(srconf)
sr = init_instance_by_config(record)
sr.generate()
else:
rconf = {"recorder": recorder}
record["kwargs"].update(rconf)
ar = init_instance_by_config(record)
ar.generate()
if __name__ == "__main__":
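The rewritten loop above replaces the hard-coded SignalRecord/PortAnaRecord calls with config-driven instantiation: each entry in task["record"] names a class, a module path, and kwargs, and SignalRecord is special-cased because it also needs the trained model and dataset. The mechanism behind qlib.utils.init_instance_by_config is dynamic import plus instantiation; a minimal sketch, not qlib's exact implementation:

import importlib

def init_instance_by_config_sketch(config):
    # Import the module named by "module_path", fetch "class", and
    # instantiate it with the keyword arguments collected in "kwargs".
    module = importlib.import_module(config["module_path"])
    cls = getattr(module, config["class"])
    return cls(**config.get("kwargs", {}))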

View File

@@ -41,8 +41,8 @@ class QuantTransformer(Model):
def __init__(
self,
d_feat=6,
hidden_size=64,
num_layers=2,
hidden_size=48,
depth=5,
dropout=0.0,
n_epochs=200,
lr=0.001,
@@ -62,7 +62,7 @@ class QuantTransformer(Model):
# set hyper-parameters.
self.d_feat = d_feat
self.hidden_size = hidden_size
self.num_layers = num_layers
self.depth = depth
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
@@ -79,7 +79,7 @@ class QuantTransformer(Model):
"Transformer parameters setting:"
"\nd_feat : {}"
"\nhidden_size : {}"
"\nnum_layers : {}"
"\ndepth : {}"
"\ndropout : {}"
"\nn_epochs : {}"
"\nlr : {}"
@@ -93,7 +93,7 @@ class QuantTransformer(Model):
"\nseed : {}".format(
d_feat,
hidden_size,
num_layers,
depth,
dropout,
n_epochs,
lr,
@@ -112,7 +112,9 @@ class QuantTransformer(Model):
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.model = TransformerModel(d_feat=self.d_feat)
self.model = TransformerModel(d_feat=self.d_feat,
                              embed_dim=self.hidden_size,
                              depth=self.depth)
self.logger.info('model: {:}'.format(self.model))
self.logger.info('model size: {:.3f} MB'.format(count_parameters_in_MB(self.model)))
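The model-side change renames num_layers to depth and passes it, together with embed_dim=self.hidden_size, into TransformerModel, so the encoder's width and number of layers are now configurable hyperparameters rather than fixed defaults. The diff does not show TransformerModel itself; a hypothetical skeleton of how such a constructor might consume these arguments (everything beyond the d_feat/embed_dim/depth names is an assumption):

import torch.nn as nn

class TransformerModelSketch(nn.Module):
    # Hypothetical skeleton; the repository's actual TransformerModel may differ.
    def __init__(self, d_feat=6, embed_dim=48, depth=5, num_heads=4, dropout=0.0):
        super().__init__()
        self.input_proj = nn.Linear(d_feat, embed_dim)  # lift raw features to embed_dim
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)  # depth = #encoder layers
        self.head = nn.Linear(embed_dim, 1)  # scalar prediction per sample

    def forward(self, x):
        # x: [batch, seq_len, d_feat] -> [seq_len, batch, embed_dim] (encoder default layout)
        h = self.encoder(self.input_proj(x).transpose(0, 1))
        return self.head(h[-1]).squeeze(-1)  # predict from the last position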