Update to accomendate last updates of qlib
This commit is contained in:
		 Submodule .latent-data/qlib updated: d13c9ae018...0ef7c8e0e6
									
								
							
							
								
								
									
										74
									
								
								configs/qlib/workflow_config_TabNet_Alpha360.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								configs/qlib/workflow_config_TabNet_Alpha360.yaml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,74 @@ | ||||
| qlib_init: | ||||
|     provider_uri: "~/.qlib/qlib_data/cn_data" | ||||
|     region: cn | ||||
| market: &market all | ||||
| benchmark: &benchmark SH000300 | ||||
| data_handler_config: &data_handler_config | ||||
|     start_time: 2008-01-01 | ||||
|     end_time: 2020-08-01 | ||||
|     fit_start_time: 2008-01-01 | ||||
|     fit_end_time: 2014-12-31 | ||||
|     instruments: *market | ||||
|     infer_processors: | ||||
|         - class: RobustZScoreNorm | ||||
|           kwargs: | ||||
|               fields_group: feature | ||||
|               clip_outlier: true | ||||
|         - class: Fillna | ||||
|           kwargs: | ||||
|               fields_group: feature | ||||
|     learn_processors: | ||||
|         - class: DropnaLabel | ||||
|         - class: CSRankNorm | ||||
|           kwargs: | ||||
|               fields_group: label | ||||
|     label: ["Ref($close, -2) / Ref($close, -1) - 1"] | ||||
| port_analysis_config: &port_analysis_config | ||||
|     strategy: | ||||
|         class: TopkDropoutStrategy | ||||
|         module_path: qlib.contrib.strategy.strategy | ||||
|         kwargs: | ||||
|             topk: 50 | ||||
|             n_drop: 5 | ||||
|     backtest: | ||||
|         verbose: False | ||||
|         limit_threshold: 0.095 | ||||
|         account: 100000000 | ||||
|         benchmark: *benchmark | ||||
|         deal_price: close | ||||
|         open_cost: 0.0005 | ||||
|         close_cost: 0.0015 | ||||
|         min_cost: 5 | ||||
| task: | ||||
|     model: | ||||
|         class: TabnetModel | ||||
|         module_path: qlib.contrib.model.pytorch_tabnet | ||||
|         kwargs: | ||||
|             pretrain: True | ||||
|     dataset: | ||||
|         class: DatasetH | ||||
|         module_path: qlib.data.dataset | ||||
|         kwargs: | ||||
|             handler: | ||||
|                 class: Alpha360 | ||||
|                 module_path: qlib.contrib.data.handler | ||||
|                 kwargs: *data_handler_config | ||||
|             segments: | ||||
|                 pretrain: [2008-01-01, 2014-12-31] | ||||
|                 pretrain_validation: [2015-01-01, 2020-08-01] | ||||
|                 train: [2008-01-01, 2014-12-31] | ||||
|                 valid: [2015-01-01, 2016-12-31] | ||||
|                 test: [2017-01-01, 2020-08-01] | ||||
|     record:  | ||||
|         - class: SignalRecord | ||||
|           module_path: qlib.workflow.record_temp | ||||
|           kwargs: {} | ||||
|         - class: SigAnaRecord | ||||
|           module_path: qlib.workflow.record_temp | ||||
|           kwargs:  | ||||
|             ana_long_short: False | ||||
|             ann_scaler: 252 | ||||
|         - class: PortAnaRecord | ||||
|           module_path: qlib.workflow.record_temp | ||||
|           kwargs:  | ||||
|             config: *port_analysis_config | ||||
| @@ -22,13 +22,13 @@ if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from procedures.q_exps import update_gpu | ||||
| from procedures.q_exps import update_market | ||||
| from procedures.q_exps import run_exp | ||||
|  | ||||
| import qlib | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
| from qlib.log import set_log_basic_config | ||||
|  | ||||
|  | ||||
| def retrieve_configs(): | ||||
| @@ -49,6 +49,7 @@ def retrieve_configs(): | ||||
|     alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml" | ||||
|     # DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf | ||||
|     alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml" | ||||
|     alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml" | ||||
|  | ||||
|     # find the yaml paths | ||||
|     alg2paths = OrderedDict() | ||||
| @@ -66,6 +67,7 @@ def main(xargs, exp_yaml): | ||||
|  | ||||
|     with open(exp_yaml) as fp: | ||||
|         config = yaml.safe_load(fp) | ||||
|     config = update_market(config, xargs.market) | ||||
|     config = update_gpu(config, xargs.gpu) | ||||
|  | ||||
|     qlib.init(**config.get("qlib_init")) | ||||
| @@ -77,7 +79,7 @@ def main(xargs, exp_yaml): | ||||
|  | ||||
|     for irun in range(xargs.times): | ||||
|         run_exp( | ||||
|             config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir | ||||
|             config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), '{:}-{:}'.format(xargs.save_dir, xargs.market) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @@ -87,6 +89,7 @@ if __name__ == "__main__": | ||||
|  | ||||
|     parser = argparse.ArgumentParser("Baselines") | ||||
|     parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.") | ||||
|     parser.add_argument("--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator.") | ||||
|     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | ||||
|     parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") | ||||
|     parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.") | ||||
|   | ||||
| @@ -105,7 +105,7 @@ def filter_finished(recorders): | ||||
|  | ||||
|  | ||||
| def query_info(save_dir, verbose): | ||||
|     R.reset_default_uri(save_dir) | ||||
|     R.set_uri(save_dir) | ||||
|     experiments = R.list_experiments() | ||||
|  | ||||
|     key_map = { | ||||
|   | ||||
| @@ -3,15 +3,38 @@ | ||||
| ##################################################### | ||||
|  | ||||
| import inspect | ||||
| import os | ||||
| import logging | ||||
|  | ||||
| import qlib | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
| from qlib.log import set_log_basic_config | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
|  | ||||
| def set_log_basic_config(filename=None, format=None, level=None): | ||||
|     """ | ||||
|     Set the basic configuration for the logging system. | ||||
|     See details at https://docs.python.org/3/library/logging.html#logging.basicConfig | ||||
|     :param filename: str or None | ||||
|         The path to save the logs. | ||||
|     :param format: the logging format | ||||
|     :param level: int | ||||
|     :return: Logger | ||||
|         Logger object. | ||||
|     """ | ||||
|     from qlib.config import C | ||||
|  | ||||
|     if level is None: | ||||
|         level = C.logging_level | ||||
|  | ||||
|     if format is None: | ||||
|         format = C.logging_config["formatters"]["logger_format"]["format"] | ||||
|  | ||||
|     logging.basicConfig(filename=filename, format=format, level=level) | ||||
|  | ||||
|  | ||||
| def update_gpu(config, gpu): | ||||
|     config = config.copy() | ||||
|     if "task" in config and "model" in config["task"]: | ||||
| @@ -46,8 +69,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||
|     # Let's start the experiment. | ||||
|     with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri): | ||||
|         # Setup log | ||||
|         recorder_root_dir = R.get_recorder().root_uri | ||||
|         log_file = recorder_root_dir / "{:}.log".format(experiment_name) | ||||
|         recorder_root_dir = R.get_recorder().get_local_dir() | ||||
|         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name)) | ||||
|         set_log_basic_config(log_file) | ||||
|         logger = get_module_logger("q.run_exp") | ||||
|         logger.info("task_config={:}".format(task_config)) | ||||
| @@ -56,8 +79,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||
|  | ||||
|         # Train model | ||||
|         R.log_params(**flatten_dict(task_config)) | ||||
|         if 'save_path' in inspect.getfullargspec(model.fit).args: | ||||
|           model_fit_kwargs['save_path'] = str(recorder_root_dir / 'model-ckps') | ||||
|         if "save_path" in inspect.getfullargspec(model.fit).args: | ||||
|             model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model-ckps") | ||||
|         model.fit(**model_fit_kwargs) | ||||
|         # Get the recorder | ||||
|         recorder = R.get_recorder() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user