Update to accommodate the latest updates of qlib
This commit is contained in:
		 Submodule .latent-data/qlib updated: d13c9ae018...0ef7c8e0e6
									
								
							
							
								
								
									
										74
									
								
								configs/qlib/workflow_config_TabNet_Alpha360.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								configs/qlib/workflow_config_TabNet_Alpha360.yaml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,74 @@ | |||||||
|  | qlib_init: | ||||||
|  |     provider_uri: "~/.qlib/qlib_data/cn_data" | ||||||
|  |     region: cn | ||||||
|  | market: &market all | ||||||
|  | benchmark: &benchmark SH000300 | ||||||
|  | data_handler_config: &data_handler_config | ||||||
|  |     start_time: 2008-01-01 | ||||||
|  |     end_time: 2020-08-01 | ||||||
|  |     fit_start_time: 2008-01-01 | ||||||
|  |     fit_end_time: 2014-12-31 | ||||||
|  |     instruments: *market | ||||||
|  |     infer_processors: | ||||||
|  |         - class: RobustZScoreNorm | ||||||
|  |           kwargs: | ||||||
|  |               fields_group: feature | ||||||
|  |               clip_outlier: true | ||||||
|  |         - class: Fillna | ||||||
|  |           kwargs: | ||||||
|  |               fields_group: feature | ||||||
|  |     learn_processors: | ||||||
|  |         - class: DropnaLabel | ||||||
|  |         - class: CSRankNorm | ||||||
|  |           kwargs: | ||||||
|  |               fields_group: label | ||||||
|  |     label: ["Ref($close, -2) / Ref($close, -1) - 1"] | ||||||
|  | port_analysis_config: &port_analysis_config | ||||||
|  |     strategy: | ||||||
|  |         class: TopkDropoutStrategy | ||||||
|  |         module_path: qlib.contrib.strategy.strategy | ||||||
|  |         kwargs: | ||||||
|  |             topk: 50 | ||||||
|  |             n_drop: 5 | ||||||
|  |     backtest: | ||||||
|  |         verbose: False | ||||||
|  |         limit_threshold: 0.095 | ||||||
|  |         account: 100000000 | ||||||
|  |         benchmark: *benchmark | ||||||
|  |         deal_price: close | ||||||
|  |         open_cost: 0.0005 | ||||||
|  |         close_cost: 0.0015 | ||||||
|  |         min_cost: 5 | ||||||
|  | task: | ||||||
|  |     model: | ||||||
|  |         class: TabnetModel | ||||||
|  |         module_path: qlib.contrib.model.pytorch_tabnet | ||||||
|  |         kwargs: | ||||||
|  |             pretrain: True | ||||||
|  |     dataset: | ||||||
|  |         class: DatasetH | ||||||
|  |         module_path: qlib.data.dataset | ||||||
|  |         kwargs: | ||||||
|  |             handler: | ||||||
|  |                 class: Alpha360 | ||||||
|  |                 module_path: qlib.contrib.data.handler | ||||||
|  |                 kwargs: *data_handler_config | ||||||
|  |             segments: | ||||||
|  |                 pretrain: [2008-01-01, 2014-12-31] | ||||||
|  |                 pretrain_validation: [2015-01-01, 2020-08-01] | ||||||
|  |                 train: [2008-01-01, 2014-12-31] | ||||||
|  |                 valid: [2015-01-01, 2016-12-31] | ||||||
|  |                 test: [2017-01-01, 2020-08-01] | ||||||
|  |     record:  | ||||||
|  |         - class: SignalRecord | ||||||
|  |           module_path: qlib.workflow.record_temp | ||||||
|  |           kwargs: {} | ||||||
|  |         - class: SigAnaRecord | ||||||
|  |           module_path: qlib.workflow.record_temp | ||||||
|  |           kwargs:  | ||||||
|  |             ana_long_short: False | ||||||
|  |             ann_scaler: 252 | ||||||
|  |         - class: PortAnaRecord | ||||||
|  |           module_path: qlib.workflow.record_temp | ||||||
|  |           kwargs:  | ||||||
|  |             config: *port_analysis_config | ||||||
| @@ -22,13 +22,13 @@ if str(lib_dir) not in sys.path: | |||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
| from procedures.q_exps import update_gpu | from procedures.q_exps import update_gpu | ||||||
|  | from procedures.q_exps import update_market | ||||||
| from procedures.q_exps import run_exp | from procedures.q_exps import run_exp | ||||||
|  |  | ||||||
| import qlib | import qlib | ||||||
| from qlib.utils import init_instance_by_config | from qlib.utils import init_instance_by_config | ||||||
| from qlib.workflow import R | from qlib.workflow import R | ||||||
| from qlib.utils import flatten_dict | from qlib.utils import flatten_dict | ||||||
| from qlib.log import set_log_basic_config |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def retrieve_configs(): | def retrieve_configs(): | ||||||
| @@ -49,6 +49,7 @@ def retrieve_configs(): | |||||||
|     alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml" |     alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml" | ||||||
|     # DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf |     # DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf | ||||||
|     alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml" |     alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml" | ||||||
|  |     alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml" | ||||||
|  |  | ||||||
|     # find the yaml paths |     # find the yaml paths | ||||||
|     alg2paths = OrderedDict() |     alg2paths = OrderedDict() | ||||||
| @@ -66,6 +67,7 @@ def main(xargs, exp_yaml): | |||||||
|  |  | ||||||
|     with open(exp_yaml) as fp: |     with open(exp_yaml) as fp: | ||||||
|         config = yaml.safe_load(fp) |         config = yaml.safe_load(fp) | ||||||
|  |     config = update_market(config, xargs.market) | ||||||
|     config = update_gpu(config, xargs.gpu) |     config = update_gpu(config, xargs.gpu) | ||||||
|  |  | ||||||
|     qlib.init(**config.get("qlib_init")) |     qlib.init(**config.get("qlib_init")) | ||||||
| @@ -77,7 +79,7 @@ def main(xargs, exp_yaml): | |||||||
|  |  | ||||||
|     for irun in range(xargs.times): |     for irun in range(xargs.times): | ||||||
|         run_exp( |         run_exp( | ||||||
|             config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir |             config.get("task"), dataset, xargs.alg, "recorder-{:02d}-{:02d}".format(irun, xargs.times), '{:}-{:}'.format(xargs.save_dir, xargs.market) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -87,6 +89,7 @@ if __name__ == "__main__": | |||||||
|  |  | ||||||
|     parser = argparse.ArgumentParser("Baselines") |     parser = argparse.ArgumentParser("Baselines") | ||||||
|     parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.") |     parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.") | ||||||
|  |     parser.add_argument("--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator.") | ||||||
|     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") |     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | ||||||
|     parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") |     parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") | ||||||
|     parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.") |     parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.") | ||||||
|   | |||||||
| @@ -105,7 +105,7 @@ def filter_finished(recorders): | |||||||
|  |  | ||||||
|  |  | ||||||
| def query_info(save_dir, verbose): | def query_info(save_dir, verbose): | ||||||
|     R.reset_default_uri(save_dir) |     R.set_uri(save_dir) | ||||||
|     experiments = R.list_experiments() |     experiments = R.list_experiments() | ||||||
|  |  | ||||||
|     key_map = { |     key_map = { | ||||||
|   | |||||||
| @@ -3,15 +3,38 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
|  |  | ||||||
| import inspect | import inspect | ||||||
|  | import os | ||||||
|  | import logging | ||||||
|  |  | ||||||
| import qlib | import qlib | ||||||
| from qlib.utils import init_instance_by_config | from qlib.utils import init_instance_by_config | ||||||
| from qlib.workflow import R | from qlib.workflow import R | ||||||
| from qlib.utils import flatten_dict | from qlib.utils import flatten_dict | ||||||
| from qlib.log import set_log_basic_config |  | ||||||
| from qlib.log import get_module_logger | from qlib.log import get_module_logger | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def set_log_basic_config(filename=None, format=None, level=None): | ||||||
|  |     """ | ||||||
|  |     Set the basic configuration for the logging system. | ||||||
|  |     See details at https://docs.python.org/3/library/logging.html#logging.basicConfig | ||||||
|  |     :param filename: str or None | ||||||
|  |         The path to save the logs. | ||||||
|  |     :param format: the logging format | ||||||
|  |     :param level: int | ||||||
|  |     :return: Logger | ||||||
|  |         Logger object. | ||||||
|  |     """ | ||||||
|  |     from qlib.config import C | ||||||
|  |  | ||||||
|  |     if level is None: | ||||||
|  |         level = C.logging_level | ||||||
|  |  | ||||||
|  |     if format is None: | ||||||
|  |         format = C.logging_config["formatters"]["logger_format"]["format"] | ||||||
|  |  | ||||||
|  |     logging.basicConfig(filename=filename, format=format, level=level) | ||||||
|  |  | ||||||
|  |  | ||||||
| def update_gpu(config, gpu): | def update_gpu(config, gpu): | ||||||
|     config = config.copy() |     config = config.copy() | ||||||
|     if "task" in config and "model" in config["task"]: |     if "task" in config and "model" in config["task"]: | ||||||
| @@ -46,8 +69,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | |||||||
|     # Let's start the experiment. |     # Let's start the experiment. | ||||||
|     with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri): |     with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri): | ||||||
|         # Setup log |         # Setup log | ||||||
|         recorder_root_dir = R.get_recorder().root_uri |         recorder_root_dir = R.get_recorder().get_local_dir() | ||||||
|         log_file = recorder_root_dir / "{:}.log".format(experiment_name) |         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name)) | ||||||
|         set_log_basic_config(log_file) |         set_log_basic_config(log_file) | ||||||
|         logger = get_module_logger("q.run_exp") |         logger = get_module_logger("q.run_exp") | ||||||
|         logger.info("task_config={:}".format(task_config)) |         logger.info("task_config={:}".format(task_config)) | ||||||
| @@ -56,8 +79,8 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | |||||||
|  |  | ||||||
|         # Train model |         # Train model | ||||||
|         R.log_params(**flatten_dict(task_config)) |         R.log_params(**flatten_dict(task_config)) | ||||||
|         if 'save_path' in inspect.getfullargspec(model.fit).args: |         if "save_path" in inspect.getfullargspec(model.fit).args: | ||||||
|           model_fit_kwargs['save_path'] = str(recorder_root_dir / 'model-ckps') |             model_fit_kwargs["save_path"] = os.path.join(recorder_root_dir, "model-ckps") | ||||||
|         model.fit(**model_fit_kwargs) |         model.fit(**model_fit_kwargs) | ||||||
|         # Get the recorder |         # Get the recorder | ||||||
|         recorder = R.get_recorder() |         recorder = R.get_recorder() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user