diff --git a/exps/trading/workflow_tt.py b/exps/trading/workflow_tt.py index f49ef78..5249ed9 100644 --- a/exps/trading/workflow_tt.py +++ b/exps/trading/workflow_tt.py @@ -1,12 +1,12 @@ ##################################################### -# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # ##################################################### # Refer to: # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py # python exps/trading/workflow_tt.py ##################################################### -import sys, site, argparse +import sys, argparse from pathlib import Path lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() @@ -15,19 +15,11 @@ if str(lib_dir) not in sys.path: import qlib from qlib.config import C -import pandas as pd from qlib.config import REG_CN -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) -from qlib.utils import exists_qlib_data, init_instance_by_config +from qlib.utils import init_instance_by_config from qlib.workflow import R -from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.utils import flatten_dict +from qlib.log import set_log_basic_config def main(xargs): @@ -73,13 +65,51 @@ def main(xargs): }, } - task = {"model": model_config, "dataset": dataset_config} + port_analysis_config = { + "strategy": { + "class": "TopkDropoutStrategy", + "module_path": "qlib.contrib.strategy.strategy", + "kwargs": { + "topk": 50, + "n_drop": 5, + }, + }, + "backtest": { + "verbose": False, + "limit_threshold": 0.095, + "account": 100000000, + "benchmark": "SH000300", + "deal_price": "close", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, + }, + } + + record_config = [ + {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()}, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + "kwargs": dict(ana_long_short=False, ann_scaler=252), + }, + { + "class": "PortAnaRecord", + "module_path": "qlib.workflow.record_temp", + "kwargs": dict(config=port_analysis_config), + }, + ] + + task = dict(model=model_config, dataset=dataset_config, record=record_config) - model = init_instance_by_config(model_config) - dataset = init_instance_by_config(dataset_config) # start exp to train model with R.start(experiment_name="train_tt_model"): + set_log_basic_config(R.get_recorder().root_uri / 'log.log') + + model = init_instance_by_config(model_config) + dataset = init_instance_by_config(dataset_config) + R.log_params(**flatten_dict(task)) model.fit(dataset) R.save_objects(trained_model=model) @@ -87,14 +117,19 @@ def main(xargs): # prediction recorder = R.get_recorder() print(recorder) - sr = SignalRecord(model, dataset, recorder) - sr.generate() - # backtest. If users want to use backtest based on their own prediction, - # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template. - par = PortAnaRecord(recorder, port_analysis_config) - par.generate() - + for record in task["record"]: + record = record.copy() + if record["class"] == "SignalRecord": + srconf = {"model": model, "dataset": dataset, "recorder": recorder} + record["kwargs"].update(srconf) + sr = init_instance_by_config(record) + sr.generate() + else: + rconf = {"recorder": recorder} + record["kwargs"].update(rconf) + ar = init_instance_by_config(record) + ar.generate() if __name__ == "__main__": diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py index eda6c97..dae4ad7 100755 --- a/lib/trade_models/quant_transformer.py +++ b/lib/trade_models/quant_transformer.py @@ -41,8 +41,8 @@ class QuantTransformer(Model): def __init__( self, d_feat=6, - hidden_size=64, - num_layers=2, + hidden_size=48, + depth=5, dropout=0.0, n_epochs=200, lr=0.001, @@ -62,7 +62,7 @@ class QuantTransformer(Model): # set hyper-parameters. self.d_feat = d_feat self.hidden_size = hidden_size - self.num_layers = num_layers + self.depth = depth self.dropout = dropout self.n_epochs = n_epochs self.lr = lr @@ -79,7 +79,7 @@ class QuantTransformer(Model): "Transformer parameters setting:" "\nd_feat : {}" "\nhidden_size : {}" - "\nnum_layers : {}" + "\ndepth : {}" "\ndropout : {}" "\nn_epochs : {}" "\nlr : {}" @@ -93,7 +93,7 @@ class QuantTransformer(Model): "\nseed : {}".format( d_feat, hidden_size, - num_layers, + depth, dropout, n_epochs, lr, @@ -112,7 +112,9 @@ class QuantTransformer(Model): np.random.seed(self.seed) torch.manual_seed(self.seed) - self.model = TransformerModel(d_feat=self.d_feat) + self.model = TransformerModel(d_feat=self.d_feat, + embed_dim=self.hidden_size, + depth=self.depth) self.logger.info('model: {:}'.format(self.model)) self.logger.info('model size: {:.3f} MB'.format(count_parameters_in_MB(self.model)))