add autodl
This commit is contained in:
		
							
								
								
									
										261
									
								
								AutoDL-Projects/exps/trading/baselines.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										261
									
								
								AutoDL-Projects/exps/trading/baselines.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,261 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ##################################################### | ||||
| # python exps/trading/baselines.py --alg MLP        # | ||||
| # python exps/trading/baselines.py --alg GRU        # | ||||
| # python exps/trading/baselines.py --alg LSTM       # | ||||
| # python exps/trading/baselines.py --alg ALSTM      # | ||||
| # python exps/trading/baselines.py --alg NAIVE-V1   # | ||||
| # python exps/trading/baselines.py --alg NAIVE-V2   # | ||||
| #                                                   # | ||||
| # python exps/trading/baselines.py --alg SFM        # | ||||
| # python exps/trading/baselines.py --alg XGBoost    # | ||||
| # python exps/trading/baselines.py --alg LightGBM   # | ||||
| # python exps/trading/baselines.py --alg DoubleE    # | ||||
| # python exps/trading/baselines.py --alg TabNet     # | ||||
| #                                                   ############################# | ||||
| # python exps/trading/baselines.py --alg Transformer | ||||
| # python exps/trading/baselines.py --alg TSF | ||||
| # python exps/trading/baselines.py --alg TSF-2x24-drop0_0 --market csi300 | ||||
| # python exps/trading/baselines.py --alg TSF-6x32-drop0_0 --market csi300 | ||||
| ################################################################################# | ||||
| import sys | ||||
| import copy | ||||
| from datetime import datetime | ||||
| import argparse | ||||
| from collections import OrderedDict | ||||
| from pprint import pprint | ||||
| import ruamel.yaml as yaml | ||||
|  | ||||
| from xautodl.config_utils import arg_str2bool | ||||
| from xautodl.procedures.q_exps import update_gpu | ||||
| from xautodl.procedures.q_exps import update_market | ||||
| from xautodl.procedures.q_exps import run_exp | ||||
|  | ||||
| import qlib | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
|  | ||||
|  | ||||
| def to_drop(config, pos_drop, other_drop): | ||||
|     config = copy.deepcopy(config) | ||||
|     net = config["task"]["model"]["kwargs"]["net_config"] | ||||
|     net["pos_drop"] = pos_drop | ||||
|     net["other_drop"] = other_drop | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def to_layer(config, embed_dim, depth): | ||||
|     config = copy.deepcopy(config) | ||||
|     net = config["task"]["model"]["kwargs"]["net_config"] | ||||
|     net["embed_dim"] = embed_dim | ||||
|     net["num_heads"] = [4] * depth | ||||
|     net["mlp_hidden_multipliers"] = [4] * depth | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def extend_transformer_settings(alg2configs, name): | ||||
|     config = copy.deepcopy(alg2configs[name]) | ||||
|     for i in range(1, 9): | ||||
|         for j in (6, 12, 24, 32, 48, 64): | ||||
|             for k1 in (0, 0.05, 0.1, 0.2, 0.3): | ||||
|                 for k2 in (0, 0.1): | ||||
|                     alg2configs[ | ||||
|                         name + "-{:}x{:}-drop{:}_{:}".format(i, j, k1, k2) | ||||
|                     ] = to_layer(to_drop(config, k1, k2), j, i) | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def replace_start_time(config, start_time): | ||||
|     config = copy.deepcopy(config) | ||||
|     xtime = datetime.strptime(start_time, "%Y-%m-%d") | ||||
|     config["data_handler_config"]["start_time"] = xtime.date() | ||||
|     config["data_handler_config"]["fit_start_time"] = xtime.date() | ||||
|     config["task"]["dataset"]["kwargs"]["segments"]["train"][0] = xtime.date() | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def extend_train_data(alg2configs, name): | ||||
|     config = copy.deepcopy(alg2configs[name]) | ||||
|     start_times = ( | ||||
|         "2008-01-01", | ||||
|         "2008-07-01", | ||||
|         "2009-01-01", | ||||
|         "2009-07-01", | ||||
|         "2010-01-01", | ||||
|         "2011-01-01", | ||||
|         "2012-01-01", | ||||
|         "2013-01-01", | ||||
|     ) | ||||
|     for start_time in start_times: | ||||
|         config = replace_start_time(config, start_time) | ||||
|         alg2configs[name + "s{:}".format(start_time)] = config | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def refresh_record(alg2configs): | ||||
|     alg2configs = copy.deepcopy(alg2configs) | ||||
|     for key, config in alg2configs.items(): | ||||
|         xlist = config["task"]["record"] | ||||
|         new_list = [] | ||||
|         for x in xlist: | ||||
|             # remove PortAnaRecord and SignalMseRecord | ||||
|             if x["class"] != "PortAnaRecord" and x["class"] != "SignalMseRecord": | ||||
|                 new_list.append(x) | ||||
|         ## add MultiSegRecord | ||||
|         new_list.append( | ||||
|             { | ||||
|                 "class": "MultiSegRecord", | ||||
|                 "module_path": "qlib.contrib.workflow", | ||||
|                 "generate_kwargs": { | ||||
|                     "segments": {"train": "train", "valid": "valid", "test": "test"}, | ||||
|                     "save": True, | ||||
|                 }, | ||||
|             } | ||||
|         ) | ||||
|         config["task"]["record"] = new_list | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def retrieve_configs(): | ||||
|     # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/ | ||||
|     config_dir = (lib_dir / ".." / "configs" / "qlib").resolve() | ||||
|     # algorithm to file names | ||||
|     alg2names = OrderedDict() | ||||
|     alg2names["GRU"] = "workflow_config_gru_Alpha360.yaml" | ||||
|     alg2names["LSTM"] = "workflow_config_lstm_Alpha360.yaml" | ||||
|     alg2names["MLP"] = "workflow_config_mlp_Alpha360.yaml" | ||||
|     # A dual-stage attention-based recurrent neural network for time series prediction, IJCAI-2017 | ||||
|     alg2names["ALSTM"] = "workflow_config_alstm_Alpha360.yaml" | ||||
|     # XGBoost: A Scalable Tree Boosting System, KDD-2016 | ||||
|     alg2names["XGBoost"] = "workflow_config_xgboost_Alpha360.yaml" | ||||
|     # LightGBM: A Highly Efficient Gradient Boosting Decision Tree, NeurIPS-2017 | ||||
|     alg2names["LightGBM"] = "workflow_config_lightgbm_Alpha360.yaml" | ||||
|     # State Frequency Memory (SFM): Stock Price Prediction via Discovering Multi-Frequency Trading Patterns, KDD-2017 | ||||
|     alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml" | ||||
|     # DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf | ||||
|     alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml" | ||||
|     alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml" | ||||
|     alg2names["NAIVE-V1"] = "workflow_config_naive_v1_Alpha360.yaml" | ||||
|     alg2names["NAIVE-V2"] = "workflow_config_naive_v2_Alpha360.yaml" | ||||
|     alg2names["Transformer"] = "workflow_config_transformer_Alpha360.yaml" | ||||
|     alg2names["TSF"] = "workflow_config_transformer_basic_Alpha360.yaml" | ||||
|  | ||||
|     # find the yaml paths | ||||
|     alg2configs = OrderedDict() | ||||
|     print("Start retrieving the algorithm configurations") | ||||
|     for idx, (alg, name) in enumerate(alg2names.items()): | ||||
|         path = config_dir / name | ||||
|         assert path.exists(), "{:} does not exist.".format(path) | ||||
|         with open(path) as fp: | ||||
|             alg2configs[alg] = yaml.safe_load(fp) | ||||
|         print( | ||||
|             "The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format( | ||||
|                 idx, len(alg2configs), alg, path | ||||
|             ) | ||||
|         ) | ||||
|     alg2configs = extend_transformer_settings(alg2configs, "TSF") | ||||
|     alg2configs = refresh_record(alg2configs) | ||||
|     # extend the algorithms by different train-data | ||||
|     for name in ("TSF-2x24-drop0_0", "TSF-6x32-drop0_0"): | ||||
|         alg2configs = extend_train_data(alg2configs, name) | ||||
|     print( | ||||
|         "There are {:} algorithms : {:}".format( | ||||
|             len(alg2configs), list(alg2configs.keys()) | ||||
|         ) | ||||
|     ) | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def main(alg_name, market, config, times, save_dir, gpu): | ||||
|  | ||||
|     pprint("Run {:}".format(alg_name)) | ||||
|     config = update_market(config, market) | ||||
|     config = update_gpu(config, gpu) | ||||
|  | ||||
|     qlib.init(**config.get("qlib_init")) | ||||
|     dataset_config = config.get("task").get("dataset") | ||||
|     dataset = init_instance_by_config(dataset_config) | ||||
|     pprint(dataset_config) | ||||
|     pprint(dataset) | ||||
|  | ||||
|     for irun in range(times): | ||||
|         run_exp( | ||||
|             config.get("task"), | ||||
|             dataset, | ||||
|             alg_name, | ||||
|             "recorder-{:02d}-{:02d}".format(irun, times), | ||||
|             "{:}-{:}".format(save_dir, market), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     alg2configs = retrieve_configs() | ||||
|  | ||||
|     parser = argparse.ArgumentParser("Baselines") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/qlib-baselines", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--market", | ||||
|         type=str, | ||||
|         default="all", | ||||
|         choices=["csi100", "csi300", "all"], | ||||
|         help="The market indicator.", | ||||
|     ) | ||||
|     parser.add_argument("--times", type=int, default=5, help="The repeated run times.") | ||||
|     parser.add_argument( | ||||
|         "--shared_dataset", | ||||
|         type=arg_str2bool, | ||||
|         default=False, | ||||
|         help="Whether to share the dataset for all algorithms?", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--gpu", type=int, default=0, help="The GPU ID used for train / test." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--alg", | ||||
|         type=str, | ||||
|         choices=list(alg2configs.keys()), | ||||
|         nargs="+", | ||||
|         required=True, | ||||
|         help="The algorithm name(s).", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if len(args.alg) == 1: | ||||
|         main( | ||||
|             args.alg[0], | ||||
|             args.market, | ||||
|             alg2configs[args.alg[0]], | ||||
|             args.times, | ||||
|             args.save_dir, | ||||
|             args.gpu, | ||||
|         ) | ||||
|     elif len(args.alg) > 1: | ||||
|         assert args.shared_dataset, "Must allow share dataset" | ||||
|         pprint(args) | ||||
|         configs = [ | ||||
|             update_gpu(update_market(alg2configs[name], args.market), args.gpu) | ||||
|             for name in args.alg | ||||
|         ] | ||||
|         qlib.init(**configs[0].get("qlib_init")) | ||||
|         dataset_config = configs[0].get("task").get("dataset") | ||||
|         dataset = init_instance_by_config(dataset_config) | ||||
|         pprint(dataset_config) | ||||
|         pprint(dataset) | ||||
|         for alg_name, config in zip(args.alg, configs): | ||||
|             print("Run {:} over {:}".format(alg_name, args.alg)) | ||||
|             for irun in range(args.times): | ||||
|                 run_exp( | ||||
|                     config.get("task"), | ||||
|                     dataset, | ||||
|                     alg_name, | ||||
|                     "recorder-{:02d}-{:02d}".format(irun, args.times), | ||||
|                     "{:}-{:}".format(args.save_dir, args.market), | ||||
|                 ) | ||||
							
								
								
									
										175
									
								
								AutoDL-Projects/exps/trading/organize_results.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								AutoDL-Projects/exps/trading/organize_results.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,175 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ##################################################### | ||||
| # python exps/trading/organize_results.py --save_dir outputs/qlib-baselines-all | ||||
| ##################################################### | ||||
| import os, re, sys, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text | ||||
| from collections import defaultdict, OrderedDict | ||||
| from pprint import pprint | ||||
| from pathlib import Path | ||||
| import ruamel.yaml as yaml | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "..").resolve() | ||||
| print("LIB-DIR: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from xautodl.config_utils import arg_str2bool | ||||
| from xautodl.utils.qlib_utils import QResult | ||||
|  | ||||
| import qlib | ||||
| from qlib.config import REG_CN | ||||
| from qlib.workflow import R | ||||
|  | ||||
|  | ||||
| def compare_results( | ||||
|     heads, values, names, space=10, separate="& ", verbose=True, sort_key=False | ||||
| ): | ||||
|     for idx, x in enumerate(heads): | ||||
|         assert x == heads[0], "[{:}] \n{:}\nvs\n{:}".format(idx, x, heads[0]) | ||||
|     new_head = QResult.full_str("Name", space) + separate + heads[0] | ||||
|     info_str_dict = dict(head=new_head, lines=[]) | ||||
|     for name, value in zip(names, values): | ||||
|         xline = QResult.full_str(name, space) + separate + value | ||||
|         info_str_dict["lines"].append(xline) | ||||
|     if verbose: | ||||
|         print("\nThere are {:} algorithms.".format(len(values))) | ||||
|         print(info_str_dict["head"]) | ||||
|         if sort_key: | ||||
|             lines = sorted( | ||||
|                 list(zip(values, info_str_dict["lines"])), | ||||
|                 key=lambda x: float(x[0].split(" ")[0]), | ||||
|             ) | ||||
|             lines = [x[1] for x in lines] | ||||
|         else: | ||||
|             lines = info_str_dict["lines"] | ||||
|         for xline in lines: | ||||
|             print(xline + "\\\\") | ||||
|     return info_str_dict | ||||
|  | ||||
|  | ||||
| def filter_finished(recorders): | ||||
|     returned_recorders = dict() | ||||
|     not_finished = 0 | ||||
|     for key, recorder in recorders.items(): | ||||
|         if recorder.status == "FINISHED": | ||||
|             returned_recorders[key] = recorder | ||||
|         else: | ||||
|             not_finished += 1 | ||||
|     return returned_recorders, not_finished | ||||
|  | ||||
|  | ||||
| def query_info(save_dir, verbose, name_filter, key_map): | ||||
|     R.set_uri(save_dir) | ||||
|     experiments = R.list_experiments() | ||||
|  | ||||
|     if verbose: | ||||
|         print("There are {:} experiments.".format(len(experiments))) | ||||
|     qresults = [] | ||||
|     for idx, (key, experiment) in enumerate(experiments.items()): | ||||
|         if experiment.id == "0": | ||||
|             continue | ||||
|         if ( | ||||
|             name_filter is not None | ||||
|             and re.fullmatch(name_filter, experiment.name) is None | ||||
|         ): | ||||
|             continue | ||||
|         recorders = experiment.list_recorders() | ||||
|         recorders, not_finished = filter_finished(recorders) | ||||
|         if verbose: | ||||
|             print( | ||||
|                 "====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format( | ||||
|                     idx + 1, | ||||
|                     len(experiments), | ||||
|                     experiment.name, | ||||
|                     len(recorders), | ||||
|                     len(recorders) + not_finished, | ||||
|                 ) | ||||
|             ) | ||||
|         result = QResult(experiment.name) | ||||
|         for recorder_id, recorder in recorders.items(): | ||||
|             result.update(recorder.list_metrics(), key_map) | ||||
|             result.append_path( | ||||
|                 os.path.join(recorder.uri, recorder.experiment_id, recorder.id) | ||||
|             ) | ||||
|         if not len(result): | ||||
|             print("There are no valid recorders for {:}".format(experiment)) | ||||
|             continue | ||||
|         else: | ||||
|             print( | ||||
|                 "There are {:} valid recorders for {:}".format( | ||||
|                     len(recorders), experiment.name | ||||
|                 ) | ||||
|             ) | ||||
|         qresults.append(result) | ||||
|     return qresults | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     parser = argparse.ArgumentParser("Show Results") | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         nargs="+", | ||||
|         default=[], | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--verbose", | ||||
|         type=arg_str2bool, | ||||
|         default=False, | ||||
|         help="Print detailed log information or not.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--name_filter", type=str, default=".*", help="Filter experiment names." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     print("Show results of {:}".format(args.save_dir)) | ||||
|     if not args.save_dir: | ||||
|         raise ValueError("Receive no input directory for [args.save_dir]") | ||||
|  | ||||
|     provider_uri = "~/.qlib/qlib_data/cn_data" | ||||
|     qlib.init(provider_uri=provider_uri, region=REG_CN) | ||||
|  | ||||
|     """ | ||||
|     key_map = { | ||||
|         # "RMSE": "RMSE", | ||||
|         "IC": "IC", | ||||
|         "ICIR": "ICIR", | ||||
|         "Rank IC": "Rank_IC", | ||||
|         "Rank ICIR": "Rank_ICIR", | ||||
|         # "excess_return_with_cost.annualized_return": "Annualized_Return", | ||||
|         # "excess_return_with_cost.information_ratio": "Information_Ratio", | ||||
|         # "excess_return_with_cost.max_drawdown": "Max_Drawdown", | ||||
|     } | ||||
|     """ | ||||
|     key_map = dict() | ||||
|     for xset in ("train", "valid", "test"): | ||||
|         key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset) | ||||
|         # key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset) | ||||
|         key_map["{:}-mean-Rank-IC".format(xset)] = "Rank IC ({:})".format(xset) | ||||
|         # key_map["{:}-mean-Rank-ICIR".format(xset)] = "Rank ICIR ({:})".format(xset) | ||||
|  | ||||
|     all_qresults = [] | ||||
|     for save_dir in args.save_dir: | ||||
|         qresults = query_info(save_dir, args.verbose, args.name_filter, key_map) | ||||
|         all_qresults.extend(qresults) | ||||
|     names, head_strs, value_strs = [], [], [] | ||||
|     for result in all_qresults: | ||||
|         head_str, value_str = result.info(list(key_map.values()), verbose=args.verbose) | ||||
|         head_strs.append(head_str) | ||||
|         value_strs.append(value_str) | ||||
|         names.append(result.name) | ||||
|     compare_results( | ||||
|         head_strs, | ||||
|         value_strs, | ||||
|         names, | ||||
|         space=18, | ||||
|         verbose=True, | ||||
|         sort_key=True, | ||||
|     ) | ||||
							
								
								
									
										206
									
								
								AutoDL-Projects/exps/trading/workflow_tt.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								AutoDL-Projects/exps/trading/workflow_tt.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,206 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | ||||
| ##################################################### | ||||
| # Refer to: | ||||
| # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb | ||||
| # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py | ||||
| # python exps/trading/workflow_tt.py --gpu 1 --market csi300 | ||||
| ##################################################### | ||||
| import yaml | ||||
| import argparse | ||||
|  | ||||
| from xautodl.procedures.q_exps import update_gpu | ||||
| from xautodl.procedures.q_exps import update_market | ||||
| from xautodl.procedures.q_exps import run_exp | ||||
|  | ||||
| import qlib | ||||
| from qlib.config import C | ||||
| from qlib.config import REG_CN | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
|  | ||||
|  | ||||
| def main(xargs): | ||||
|     dataset_config = { | ||||
|         "class": "DatasetH", | ||||
|         "module_path": "qlib.data.dataset", | ||||
|         "kwargs": { | ||||
|             "handler": { | ||||
|                 "class": "Alpha360", | ||||
|                 "module_path": "qlib.contrib.data.handler", | ||||
|                 "kwargs": { | ||||
|                     "start_time": "2008-01-01", | ||||
|                     "end_time": "2020-08-01", | ||||
|                     "fit_start_time": "2008-01-01", | ||||
|                     "fit_end_time": "2014-12-31", | ||||
|                     "instruments": xargs.market, | ||||
|                     "infer_processors": [ | ||||
|                         { | ||||
|                             "class": "RobustZScoreNorm", | ||||
|                             "kwargs": {"fields_group": "feature", "clip_outlier": True}, | ||||
|                         }, | ||||
|                         {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, | ||||
|                     ], | ||||
|                     "learn_processors": [ | ||||
|                         {"class": "DropnaLabel"}, | ||||
|                         {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}}, | ||||
|                     ], | ||||
|                     "label": ["Ref($close, -2) / Ref($close, -1) - 1"], | ||||
|                 }, | ||||
|             }, | ||||
|             "segments": { | ||||
|                 "train": ("2008-01-01", "2014-12-31"), | ||||
|                 "valid": ("2015-01-01", "2016-12-31"), | ||||
|                 "test": ("2017-01-01", "2020-08-01"), | ||||
|             }, | ||||
|         }, | ||||
|     } | ||||
|  | ||||
|     model_config = { | ||||
|         "class": "QuantTransformer", | ||||
|         "module_path": "xautodl.trade_models.quant_transformer", | ||||
|         "kwargs": { | ||||
|             "net_config": None, | ||||
|             "opt_config": None, | ||||
|             "GPU": "0", | ||||
|             "metric": "loss", | ||||
|         }, | ||||
|     } | ||||
|  | ||||
|     port_analysis_config = { | ||||
|         "strategy": { | ||||
|             "class": "TopkDropoutStrategy", | ||||
|             "module_path": "qlib.contrib.strategy.strategy", | ||||
|             "kwargs": { | ||||
|                 "topk": 50, | ||||
|                 "n_drop": 5, | ||||
|             }, | ||||
|         }, | ||||
|         "backtest": { | ||||
|             "verbose": False, | ||||
|             "limit_threshold": 0.095, | ||||
|             "account": 100000000, | ||||
|             "benchmark": "SH000300", | ||||
|             "deal_price": "close", | ||||
|             "open_cost": 0.0005, | ||||
|             "close_cost": 0.0015, | ||||
|             "min_cost": 5, | ||||
|         }, | ||||
|     } | ||||
|  | ||||
|     record_config = [ | ||||
|         { | ||||
|             "class": "SignalRecord", | ||||
|             "module_path": "qlib.workflow.record_temp", | ||||
|             "kwargs": dict(), | ||||
|         }, | ||||
|         { | ||||
|             "class": "SigAnaRecord", | ||||
|             "module_path": "qlib.workflow.record_temp", | ||||
|             "kwargs": dict(ana_long_short=False, ann_scaler=252), | ||||
|         }, | ||||
|         { | ||||
|             "class": "PortAnaRecord", | ||||
|             "module_path": "qlib.workflow.record_temp", | ||||
|             "kwargs": dict(config=port_analysis_config), | ||||
|         }, | ||||
|     ] | ||||
|  | ||||
|     provider_uri = "~/.qlib/qlib_data/cn_data" | ||||
|     qlib.init(provider_uri=provider_uri, region=REG_CN) | ||||
|  | ||||
|     from qlib.utils import init_instance_by_config | ||||
|  | ||||
|     xconfig = """ | ||||
| model: | ||||
|         class: SFM | ||||
|         module_path: qlib.contrib.model.pytorch_sfm | ||||
|         kwargs: | ||||
|             d_feat: 6 | ||||
|             hidden_size: 64 | ||||
|             output_dim: 32 | ||||
|             freq_dim: 25 | ||||
|             dropout_W: 0.5 | ||||
|             dropout_U: 0.5 | ||||
|             n_epochs: 20 | ||||
|             lr: 1e-3 | ||||
|             batch_size: 1600 | ||||
|             early_stop: 20 | ||||
|             eval_steps: 5 | ||||
|             loss: mse | ||||
|             optimizer: adam | ||||
|             GPU: 0 | ||||
| """ | ||||
|     xconfig = """ | ||||
| model: | ||||
|         class: TabnetModel | ||||
|         module_path: qlib.contrib.model.pytorch_tabnet | ||||
|         kwargs: | ||||
|             d_feat: 360 | ||||
|             pretrain: True | ||||
| """ | ||||
|     xconfig = """ | ||||
| model: | ||||
|         class: GRU | ||||
|         module_path: qlib.contrib.model.pytorch_gru | ||||
|         kwargs: | ||||
|             d_feat: 6 | ||||
|             hidden_size: 64 | ||||
|             num_layers: 4 | ||||
|             dropout: 0.0 | ||||
|             n_epochs: 200 | ||||
|             lr: 0.001 | ||||
|             early_stop: 20 | ||||
|             batch_size: 800 | ||||
|             metric: loss | ||||
|             loss: mse | ||||
|             GPU: 0 | ||||
| """ | ||||
|     xconfig = yaml.safe_load(xconfig) | ||||
|     model = init_instance_by_config(xconfig["model"]) | ||||
|     from xautodl.utils.flop_benchmark import count_parameters_in_MB | ||||
|  | ||||
|     # print(count_parameters_in_MB(model.tabnet_model)) | ||||
|     import pdb | ||||
|  | ||||
|     pdb.set_trace() | ||||
|  | ||||
|     save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market) | ||||
|     dataset = init_instance_by_config(dataset_config) | ||||
|     for irun in range(xargs.times): | ||||
|         xmodel_config = model_config.copy() | ||||
|         xmodel_config = update_gpu(xmodel_config, xargs.gpu) | ||||
|         task_config = dict( | ||||
|             model=xmodel_config, dataset=dataset_config, record=record_config | ||||
|         ) | ||||
|  | ||||
|         run_exp( | ||||
|             task_config, | ||||
|             dataset, | ||||
|             xargs.name, | ||||
|             "recorder-{:02d}-{:02d}".format(irun, xargs.times), | ||||
|             save_dir, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Vanilla Transformable Transformer") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/vtt-runs", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--name", type=str, default="Transformer", help="The experiment name." | ||||
|     ) | ||||
|     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | ||||
|     parser.add_argument( | ||||
|         "--gpu", type=int, default=0, help="The GPU ID used for train / test." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--market", type=str, default="all", help="The market indicator." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     main(args) | ||||
		Reference in New Issue
	
	Block a user