#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
# Refer to:
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb
# - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
# python exps/trading/workflow_tt.py --gpu 1 --market csi300
#####################################################
import argparse
import copy
import sys
from pathlib import Path

# Make the repository's `lib` directory (two levels up from this script)
# importable so that `procedures` and `trade_models` can be resolved.
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

from procedures.q_exps import update_gpu
from procedures.q_exps import update_market  # NOTE(review): currently unused in this script
from procedures.q_exps import run_exp

import qlib
from qlib.config import C  # NOTE(review): currently unused in this script
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.workflow import R  # NOTE(review): currently unused in this script


def main(xargs):
    """Train and evaluate the `QuantTransformer` model with qlib several times.

    Builds the dataset / model / record configurations, initializes qlib on
    the local CN data bundle, instantiates the dataset once, and then calls
    `run_exp` `xargs.times` times with a fresh model config for each run.

    Args:
        xargs: parsed command-line namespace providing `save_dir`, `name`,
            `times`, `gpu` and `market` (see the argparse setup below).
    """
    dataset_config = {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha360",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": {
                    "start_time": "2008-01-01",
                    "end_time": "2020-08-01",
                    # Normalization statistics are fit on the training span only.
                    "fit_start_time": "2008-01-01",
                    "fit_end_time": "2014-12-31",
                    "instruments": xargs.market,
                    "infer_processors": [
                        {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}},
                        {"class": "Fillna", "kwargs": {"fields_group": "feature"}},
                    ],
                    "learn_processors": [
                        {"class": "DropnaLabel"},
                        {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}},
                    ],
                    # Per qlib's `Ref` semantics (negative offset = future), this is
                    # the close-to-close return from t+1 to t+2.
                    "label": ["Ref($close, -2) / Ref($close, -1) - 1"],
                },
            },
            "segments": {
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                "test": ("2017-01-01", "2020-08-01"),
            },
        },
    }

    model_config = {
        "class": "QuantTransformer",
        "module_path": "trade_models",
        "kwargs": {
            "net_config": None,
            "opt_config": None,
            "GPU": "0",
            "metric": "loss",
        },
    }

    port_analysis_config = {
        "strategy": {
            "class": "TopkDropoutStrategy",
            "module_path": "qlib.contrib.strategy.strategy",
            "kwargs": {
                "topk": 50,
                "n_drop": 5,
            },
        },
        "backtest": {
            "verbose": False,
            "limit_threshold": 0.095,
            "account": 100000000,
            # NOTE(review): the benchmark is fixed to SH000300 (CSI300) regardless
            # of `--market`; confirm this is intended for other markets.
            "benchmark": "SH000300",
            "deal_price": "close",
            "open_cost": 0.0005,
            "close_cost": 0.0015,
            "min_cost": 5,
        },
    }

    record_config = [
        {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()},
        {
            "class": "SigAnaRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": dict(ana_long_short=False, ann_scaler=252),
        },
        {
            "class": "PortAnaRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": dict(config=port_analysis_config),
        },
    ]

    provider_uri = "~/.qlib/qlib_data/cn_data"
    qlib.init(provider_uri=provider_uri, region=REG_CN)

    save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
    # The dataset is built once and shared by all repeated runs.
    dataset = init_instance_by_config(dataset_config)
    for irun in range(xargs.times):
        # `dict.copy()` is shallow: the nested "kwargs" dict (holding "GPU") would
        # stay shared with `model_config` across runs. Deep-copy so each run gets
        # an independent config no matter how `update_gpu` modifies it.
        xmodel_config = copy.deepcopy(model_config)
        xmodel_config = update_gpu(xmodel_config, xargs.gpu)

        task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
        run_exp(task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Vanilla Transformable Transformer")
    parser.add_argument("--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.")
    parser.add_argument("--name", type=str, default="Transformer", help="The experiment name.")
    parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
    parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
    parser.add_argument("--market", type=str, default="all", help="The market indicator.")
    args = parser.parse_args()

    main(args)