| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  | ##################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # | 
					
						
							|  |  |  | ##################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 01:44:26 -08:00
										 |  |  | import inspect | 
					
						
							| 
									
										
										
										
											2021-03-11 03:09:55 +00:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2021-03-15 04:56:39 +00:00
										 |  |  | import pprint | 
					
						
							| 
									
										
										
										
											2021-03-11 03:09:55 +00:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2021-03-23 11:13:51 +00:00
										 |  |  | from copy import deepcopy | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  | import qlib | 
					
						
							|  |  |  | from qlib.utils import init_instance_by_config | 
					
						
							|  |  |  | from qlib.workflow import R | 
					
						
							|  |  |  | from qlib.utils import flatten_dict | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  | from qlib.log import get_module_logger | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-11 03:09:55 +00:00
										 |  |  | def set_log_basic_config(filename=None, format=None, level=None): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Set the basic configuration for the logging system. | 
					
						
							|  |  |  |     See details at https://docs.python.org/3/library/logging.html#logging.basicConfig | 
					
						
							|  |  |  |     :param filename: str or None | 
					
						
							|  |  |  |         The path to save the logs. | 
					
						
							|  |  |  |     :param format: the logging format | 
					
						
							|  |  |  |     :param level: int | 
					
						
							|  |  |  |     :return: Logger | 
					
						
							|  |  |  |         Logger object. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     from qlib.config import C | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if level is None: | 
					
						
							|  |  |  |         level = C.logging_level | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if format is None: | 
					
						
							|  |  |  |         format = C.logging_config["formatters"]["logger_format"]["format"] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-23 11:13:51 +00:00
										 |  |  |     # Remove all handlers associated with the root logger object. | 
					
						
							|  |  |  |     for handler in logging.root.handlers[:]: | 
					
						
							|  |  |  |         logging.root.removeHandler(handler) | 
					
						
							| 
									
										
										
										
											2021-03-11 03:09:55 +00:00
										 |  |  |     logging.basicConfig(filename=filename, format=format, level=level) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  | def update_gpu(config, gpu): | 
					
						
							| 
									
										
										
										
											2021-03-23 11:13:51 +00:00
										 |  |  |     config = deepcopy(config) | 
					
						
							| 
									
										
										
										
											2021-03-06 19:27:05 -08:00
										 |  |  |     if "task" in config and "model" in config["task"]: | 
					
						
							|  |  |  |         if "GPU" in config["task"]["model"]: | 
					
						
							|  |  |  |             config["task"]["model"]["GPU"] = gpu | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |         elif ( | 
					
						
							|  |  |  |             "kwargs" in config["task"]["model"] | 
					
						
							|  |  |  |             and "GPU" in config["task"]["model"]["kwargs"] | 
					
						
							|  |  |  |         ): | 
					
						
							| 
									
										
										
										
											2021-03-06 19:27:05 -08:00
										 |  |  |             config["task"]["model"]["kwargs"]["GPU"] = gpu | 
					
						
							|  |  |  |     elif "model" in config: | 
					
						
							|  |  |  |         if "GPU" in config["model"]: | 
					
						
							|  |  |  |             config["model"]["GPU"] = gpu | 
					
						
							|  |  |  |         elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]: | 
					
						
							|  |  |  |             config["model"]["kwargs"]["GPU"] = gpu | 
					
						
							|  |  |  |     elif "kwargs" in config and "GPU" in config["kwargs"]: | 
					
						
							|  |  |  |         config["kwargs"]["GPU"] = gpu | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  |     elif "GPU" in config: | 
					
						
							|  |  |  |         config["GPU"] = gpu | 
					
						
							|  |  |  |     return config | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def update_market(config, market): | 
					
						
							| 
									
										
										
										
											2021-03-23 11:13:51 +00:00
										 |  |  |     config = deepcopy(config.copy()) | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  |     config["market"] = market | 
					
						
							|  |  |  |     config["data_handler_config"]["instruments"] = market | 
					
						
							|  |  |  |     return config | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-23 11:13:51 +00:00
										 |  |  | def run_exp( | 
					
						
							|  |  |  |     task_config, | 
					
						
							|  |  |  |     dataset, | 
					
						
							|  |  |  |     experiment_name, | 
					
						
							|  |  |  |     recorder_name, | 
					
						
							|  |  |  |     uri, | 
					
						
							|  |  |  |     model_obj_name="model.pkl", | 
					
						
							|  |  |  | ): | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     model = init_instance_by_config(task_config["model"]) | 
					
						
							| 
									
										
										
										
											2021-03-07 01:44:26 -08:00
										 |  |  |     model_fit_kwargs = dict(dataset=dataset) | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 01:44:26 -08:00
										 |  |  |     # Let's start the experiment. | 
					
						
							| 
									
										
										
										
											2021-03-19 23:57:23 +08:00
										 |  |  |     with R.start( | 
					
						
							|  |  |  |         experiment_name=experiment_name, | 
					
						
							|  |  |  |         recorder_name=recorder_name, | 
					
						
							|  |  |  |         uri=uri, | 
					
						
							|  |  |  |         resume=True, | 
					
						
							|  |  |  |     ): | 
					
						
							| 
									
										
										
										
											2021-03-07 01:44:26 -08:00
										 |  |  |         # Setup log | 
					
						
							| 
									
										
										
										
											2021-03-11 03:09:55 +00:00
										 |  |  |         recorder_root_dir = R.get_recorder().get_local_dir() | 
					
						
							|  |  |  |         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name)) | 
					
						
							| 
									
										
										
										
											2021-03-23 11:13:51 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  |         set_log_basic_config(log_file) | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         logger = get_module_logger("q.run_exp") | 
					
						
							| 
									
										
										
										
											2021-03-15 04:56:39 +00:00
										 |  |  |         logger.info("task_config::\n{:}".format(pprint.pformat(task_config, indent=2))) | 
					
						
							| 
									
										
										
										
											2021-03-07 03:09:47 +00:00
										 |  |  |         logger.info("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri)) | 
					
						
							|  |  |  |         logger.info("dataset={:}".format(dataset)) | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 01:44:26 -08:00
										 |  |  |         # Train model | 
					
						
							| 
									
										
										
										
											2021-03-23 11:13:51 +00:00
										 |  |  |         try: | 
					
						
							|  |  |  |             model = R.load_object(model_obj_name) | 
					
						
							|  |  |  |             logger.info("[Find existing object from {:}]".format(model_obj_name)) | 
					
						
							|  |  |  |         except OSError: | 
					
						
							|  |  |  |             R.log_params(**flatten_dict(task_config)) | 
					
						
							|  |  |  |             if "save_path" in inspect.getfullargspec(model.fit).args: | 
					
						
							|  |  |  |                 model_fit_kwargs["save_path"] = os.path.join( | 
					
						
							|  |  |  |                     recorder_root_dir, "model.ckp" | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             elif "save_dir" in inspect.getfullargspec(model.fit).args: | 
					
						
							|  |  |  |                 model_fit_kwargs["save_dir"] = os.path.join( | 
					
						
							|  |  |  |                     recorder_root_dir, "model-ckps" | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             model.fit(**model_fit_kwargs) | 
					
						
							|  |  |  |             R.save_objects(**{model_obj_name: model}) | 
					
						
							|  |  |  |         except: | 
					
						
							|  |  |  |             raise ValueError("Something wrong.") | 
					
						
							| 
									
										
										
										
											2021-03-07 01:44:26 -08:00
										 |  |  |         # Get the recorder | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  |         recorder = R.get_recorder() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-07 01:44:26 -08:00
										 |  |  |         # Generate records: prediction, backtest, and analysis | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  |         for record in task_config["record"]: | 
					
						
							| 
									
										
										
										
											2021-03-23 11:13:51 +00:00
										 |  |  |             record = deepcopy(record) | 
					
						
							| 
									
										
										
										
											2021-03-06 06:38:34 -08:00
										 |  |  |             if record["class"] == "SignalRecord": | 
					
						
							|  |  |  |                 srconf = {"model": model, "dataset": dataset, "recorder": recorder} | 
					
						
							|  |  |  |                 record["kwargs"].update(srconf) | 
					
						
							|  |  |  |                 sr = init_instance_by_config(record) | 
					
						
							|  |  |  |                 sr.generate() | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 rconf = {"recorder": recorder} | 
					
						
							|  |  |  |                 record["kwargs"].update(rconf) | 
					
						
							|  |  |  |                 ar = init_instance_by_config(record) | 
					
						
							|  |  |  |                 ar.generate() |