Update Q workflow
This commit is contained in:
parent
e329b78cf4
commit
192c25eb42
@ -1,12 +1,12 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
||||
#####################################################
|
||||
# Refer to:
|
||||
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb
|
||||
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
|
||||
# python exps/trading/workflow_tt.py
|
||||
#####################################################
|
||||
import sys, site, argparse
|
||||
import sys, argparse
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
@ -15,19 +15,11 @@ if str(lib_dir) not in sys.path:
|
||||
|
||||
import qlib
|
||||
from qlib.config import C
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.log import set_log_basic_config
|
||||
|
||||
|
||||
def main(xargs):
|
||||
@ -73,13 +65,51 @@ def main(xargs):
|
||||
},
|
||||
}
|
||||
|
||||
task = {"model": model_config, "dataset": dataset_config}
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.strategy",
|
||||
"kwargs": {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": "SH000300",
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
"kwargs": dict(ana_long_short=False, ann_scaler=252),
|
||||
},
|
||||
{
|
||||
"class": "PortAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
"kwargs": dict(config=port_analysis_config),
|
||||
},
|
||||
]
|
||||
|
||||
task = dict(model=model_config, dataset=dataset_config, record=record_config)
|
||||
|
||||
model = init_instance_by_config(model_config)
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
# start exp to train model
|
||||
with R.start(experiment_name="train_tt_model"):
|
||||
set_log_basic_config(R.get_recorder().root_uri / 'log.log')
|
||||
|
||||
model = init_instance_by_config(model_config)
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
R.save_objects(trained_model=model)
|
||||
@ -87,14 +117,19 @@ def main(xargs):
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
print(recorder)
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# backtest. If users want to use backtest based on their own prediction,
|
||||
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
|
||||
for record in task["record"]:
|
||||
record = record.copy()
|
||||
if record["class"] == "SignalRecord":
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
record["kwargs"].update(srconf)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate()
|
||||
else:
|
||||
rconf = {"recorder": recorder}
|
||||
record["kwargs"].update(rconf)
|
||||
ar = init_instance_by_config(record)
|
||||
ar.generate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -41,8 +41,8 @@ class QuantTransformer(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
hidden_size=48,
|
||||
depth=5,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
@ -62,7 +62,7 @@ class QuantTransformer(Model):
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.depth = depth
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
@ -79,7 +79,7 @@ class QuantTransformer(Model):
|
||||
"Transformer parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndepth : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
@ -93,7 +93,7 @@ class QuantTransformer(Model):
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
depth,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
@ -112,7 +112,9 @@ class QuantTransformer(Model):
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.model = TransformerModel(d_feat=self.d_feat)
|
||||
self.model = TransformerModel(d_feat=self.d_feat,
|
||||
embed_dim=self.hidden_size,
|
||||
depth=self.depth)
|
||||
self.logger.info('model: {:}'.format(self.model))
|
||||
self.logger.info('model size: {:.3f} MB'.format(count_parameters_in_MB(self.model)))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user