##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # ##################################################### # Refer to: # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py # python exps/trading/workflow_tt.py --gpu 1 --market csi300 ##################################################### import yaml import argparse from xautodl.procedures.q_exps import update_gpu from xautodl.procedures.q_exps import update_market from xautodl.procedures.q_exps import run_exp import qlib from qlib.config import C from qlib.config import REG_CN from qlib.utils import init_instance_by_config from qlib.workflow import R def main(xargs): dataset_config = { "class": "DatasetH", "module_path": "qlib.data.dataset", "kwargs": { "handler": { "class": "Alpha360", "module_path": "qlib.contrib.data.handler", "kwargs": { "start_time": "2008-01-01", "end_time": "2020-08-01", "fit_start_time": "2008-01-01", "fit_end_time": "2014-12-31", "instruments": xargs.market, "infer_processors": [ { "class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}, }, {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, ], "learn_processors": [ {"class": "DropnaLabel"}, {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}}, ], "label": ["Ref($close, -2) / Ref($close, -1) - 1"], }, }, "segments": { "train": ("2008-01-01", "2014-12-31"), "valid": ("2015-01-01", "2016-12-31"), "test": ("2017-01-01", "2020-08-01"), }, }, } model_config = { "class": "QuantTransformer", "module_path": "xautodl.trade_models.quant_transformer", "kwargs": { "net_config": None, "opt_config": None, "GPU": "0", "metric": "loss", }, } port_analysis_config = { "strategy": { "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.strategy", "kwargs": { "topk": 50, "n_drop": 5, }, }, "backtest": { "verbose": False, "limit_threshold": 0.095, "account": 100000000, "benchmark": "SH000300", "deal_price": "close", "open_cost": 0.0005, "close_cost": 0.0015, "min_cost": 5, }, } record_config = [ { "class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict(), }, { "class": "SigAnaRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict(ana_long_short=False, ann_scaler=252), }, { "class": "PortAnaRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict(config=port_analysis_config), }, ] provider_uri = "~/.qlib/qlib_data/cn_data" qlib.init(provider_uri=provider_uri, region=REG_CN) from qlib.utils import init_instance_by_config xconfig = """ model: class: SFM module_path: qlib.contrib.model.pytorch_sfm kwargs: d_feat: 6 hidden_size: 64 output_dim: 32 freq_dim: 25 dropout_W: 0.5 dropout_U: 0.5 n_epochs: 20 lr: 1e-3 batch_size: 1600 early_stop: 20 eval_steps: 5 loss: mse optimizer: adam GPU: 0 """ xconfig = """ model: class: TabnetModel module_path: qlib.contrib.model.pytorch_tabnet kwargs: d_feat: 360 pretrain: True """ xconfig = """ model: class: GRU module_path: qlib.contrib.model.pytorch_gru kwargs: d_feat: 6 hidden_size: 64 num_layers: 4 dropout: 0.0 n_epochs: 200 lr: 0.001 early_stop: 20 batch_size: 800 metric: loss loss: mse GPU: 0 """ xconfig = yaml.safe_load(xconfig) model = init_instance_by_config(xconfig["model"]) from xautodl.utils.flop_benchmark import count_parameters_in_MB # print(count_parameters_in_MB(model.tabnet_model)) import pdb pdb.set_trace() save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market) dataset = init_instance_by_config(dataset_config) for irun in range(xargs.times): xmodel_config = model_config.copy() xmodel_config = update_gpu(xmodel_config, xargs.gpu) task_config = dict( model=xmodel_config, dataset=dataset_config, record=record_config ) run_exp( task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir, ) if __name__ == "__main__": parser = argparse.ArgumentParser("Vanilla Transformable Transformer") parser.add_argument( "--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.", ) parser.add_argument( "--name", type=str, default="Transformer", help="The experiment name." ) parser.add_argument("--times", type=int, default=10, help="The repeated run times.") parser.add_argument( "--gpu", type=int, default=0, help="The GPU ID used for train / test." ) parser.add_argument( "--market", type=str, default="all", help="The market indicator." ) args = parser.parse_args() main(args)