| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  | ##################################################### | 
					
						
							| 
									
										
										
										
											2021-03-04 13:55:48 +00:00
										 |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | 
					
						
							| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  | ##################################################### | 
					
						
							|  |  |  | # Refer to: | 
					
						
							|  |  |  | # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb | 
					
						
							|  |  |  | # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py | 
					
						
							| 
									
										
										
										
											2021-03-06 21:35:26 -08:00
										 |  |  | # python exps/trading/workflow_tt.py --gpu 1 --market csi300 | 
					
						
							| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  | ##################################################### | 
					
						
							| 
									
										
										
										
											2021-03-04 13:55:48 +00:00
										 |  |  | import sys, argparse | 
					
						
							| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  | from pathlib import Path | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: | 
					
						
							|  |  |  |     sys.path.insert(0, str(lib_dir)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  | from procedures.q_exps import update_gpu | 
					
						
							|  |  |  | from procedures.q_exps import update_market | 
					
						
							|  |  |  | from procedures.q_exps import run_exp | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  | import qlib | 
					
						
							|  |  |  | from qlib.config import C | 
					
						
							|  |  |  | from qlib.config import REG_CN | 
					
						
							| 
									
										
										
										
											2021-03-04 13:55:48 +00:00
										 |  |  | from qlib.utils import init_instance_by_config | 
					
						
							| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  | from qlib.workflow import R | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def main(xargs): | 
					
						
							|  |  |  |     dataset_config = { | 
					
						
							|  |  |  |         "class": "DatasetH", | 
					
						
							|  |  |  |         "module_path": "qlib.data.dataset", | 
					
						
							|  |  |  |         "kwargs": { | 
					
						
							|  |  |  |             "handler": { | 
					
						
							|  |  |  |                 "class": "Alpha360", | 
					
						
							|  |  |  |                 "module_path": "qlib.contrib.data.handler", | 
					
						
							|  |  |  |                 "kwargs": { | 
					
						
							|  |  |  |                     "start_time": "2008-01-01", | 
					
						
							|  |  |  |                     "end_time": "2020-08-01", | 
					
						
							|  |  |  |                     "fit_start_time": "2008-01-01", | 
					
						
							|  |  |  |                     "fit_end_time": "2014-12-31", | 
					
						
							|  |  |  |                     "instruments": xargs.market, | 
					
						
							|  |  |  |                     "infer_processors": [ | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |                         { | 
					
						
							|  |  |  |                             "class": "RobustZScoreNorm", | 
					
						
							|  |  |  |                             "kwargs": {"fields_group": "feature", "clip_outlier": True}, | 
					
						
							|  |  |  |                         }, | 
					
						
							| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  |                         {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, | 
					
						
							|  |  |  |                     ], | 
					
						
							|  |  |  |                     "learn_processors": [ | 
					
						
							|  |  |  |                         {"class": "DropnaLabel"}, | 
					
						
							|  |  |  |                         {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}}, | 
					
						
							|  |  |  |                     ], | 
					
						
							|  |  |  |                     "label": ["Ref($close, -2) / Ref($close, -1) - 1"], | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |             "segments": { | 
					
						
							|  |  |  |                 "train": ("2008-01-01", "2014-12-31"), | 
					
						
							|  |  |  |                 "valid": ("2015-01-01", "2016-12-31"), | 
					
						
							|  |  |  |                 "test": ("2017-01-01", "2020-08-01"), | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     model_config = { | 
					
						
							|  |  |  |         "class": "QuantTransformer", | 
					
						
							|  |  |  |         "module_path": "trade_models", | 
					
						
							|  |  |  |         "kwargs": { | 
					
						
							| 
									
										
										
										
											2021-03-06 21:35:26 -08:00
										 |  |  |             "net_config": None, | 
					
						
							|  |  |  |             "opt_config": None, | 
					
						
							| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  |             "GPU": "0", | 
					
						
							|  |  |  |             "metric": "loss", | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-04 13:55:48 +00:00
										 |  |  |     port_analysis_config = { | 
					
						
							|  |  |  |         "strategy": { | 
					
						
							|  |  |  |             "class": "TopkDropoutStrategy", | 
					
						
							|  |  |  |             "module_path": "qlib.contrib.strategy.strategy", | 
					
						
							|  |  |  |             "kwargs": { | 
					
						
							|  |  |  |                 "topk": 50, | 
					
						
							|  |  |  |                 "n_drop": 5, | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |         "backtest": { | 
					
						
							|  |  |  |             "verbose": False, | 
					
						
							|  |  |  |             "limit_threshold": 0.095, | 
					
						
							|  |  |  |             "account": 100000000, | 
					
						
							|  |  |  |             "benchmark": "SH000300", | 
					
						
							|  |  |  |             "deal_price": "close", | 
					
						
							|  |  |  |             "open_cost": 0.0005, | 
					
						
							|  |  |  |             "close_cost": 0.0015, | 
					
						
							|  |  |  |             "min_cost": 5, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     record_config = [ | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         { | 
					
						
							|  |  |  |             "class": "SignalRecord", | 
					
						
							|  |  |  |             "module_path": "qlib.workflow.record_temp", | 
					
						
							|  |  |  |             "kwargs": dict(), | 
					
						
							|  |  |  |         }, | 
					
						
							| 
									
										
										
										
											2021-03-04 13:55:48 +00:00
										 |  |  |         { | 
					
						
							|  |  |  |             "class": "SigAnaRecord", | 
					
						
							|  |  |  |             "module_path": "qlib.workflow.record_temp", | 
					
						
							|  |  |  |             "kwargs": dict(ana_long_short=False, ann_scaler=252), | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "class": "PortAnaRecord", | 
					
						
							|  |  |  |             "module_path": "qlib.workflow.record_temp", | 
					
						
							|  |  |  |             "kwargs": dict(config=port_analysis_config), | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  |     provider_uri = "~/.qlib/qlib_data/cn_data" | 
					
						
							|  |  |  |     qlib.init(provider_uri=provider_uri, region=REG_CN) | 
					
						
							| 
									
										
										
										
											2021-03-04 13:55:48 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-06 21:35:26 -08:00
										 |  |  |     save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market) | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  |     dataset = init_instance_by_config(dataset_config) | 
					
						
							|  |  |  |     for irun in range(xargs.times): | 
					
						
							|  |  |  |         xmodel_config = model_config.copy() | 
					
						
							| 
									
										
										
										
											2021-03-06 21:35:26 -08:00
										 |  |  |         xmodel_config = update_gpu(xmodel_config, xargs.gpu) | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         task_config = dict( | 
					
						
							|  |  |  |             model=xmodel_config, dataset=dataset_config, record=record_config | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-06 21:35:26 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         run_exp( | 
					
						
							|  |  |  |             task_config, | 
					
						
							|  |  |  |             dataset, | 
					
						
							|  |  |  |             xargs.name, | 
					
						
							|  |  |  |             "recorder-{:02d}-{:02d}".format(irun, xargs.times), | 
					
						
							|  |  |  |             save_dir, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser("Vanilla Transformable Transformer") | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--save_dir", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="./outputs/vtt-runs", | 
					
						
							|  |  |  |         help="The checkpoint directory.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--name", type=str, default="Transformer", help="The experiment name." | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  |     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--gpu", type=int, default=0, help="The GPU ID used for train / test." | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--market", type=str, default="all", help="The market indicator." | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-03 13:57:48 +00:00
										 |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     main(args) |