Update q-config and black for procedures/utils
This commit is contained in:
		| @@ -4,6 +4,8 @@ | ||||
| # python exps/trading/baselines.py --alg GRU | ||||
| # python exps/trading/baselines.py --alg LSTM | ||||
| # python exps/trading/baselines.py --alg ALSTM | ||||
| # python exps/trading/baselines.py --alg MLP | ||||
| # python exps/trading/baselines.py --alg SFM | ||||
| # python exps/trading/baselines.py --alg XGBoost | ||||
| # python exps/trading/baselines.py --alg LightGBM | ||||
| ##################################################### | ||||
| @@ -17,6 +19,10 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from procedures.q_exps import update_gpu | ||||
| from procedures.q_exps import update_market | ||||
| from procedures.q_exps import run_exp | ||||
|  | ||||
| import qlib | ||||
| from qlib.utils import init_instance_by_config | ||||
| from qlib.workflow import R | ||||
| @@ -31,15 +37,19 @@ def retrieve_configs(): | ||||
|     alg2names = OrderedDict() | ||||
|     alg2names["GRU"] = "workflow_config_gru_Alpha360.yaml" | ||||
|     alg2names["LSTM"] = "workflow_config_lstm_Alpha360.yaml" | ||||
|     alg2names["MLP"] = "workflow_config_mlp_Alpha360.yaml" | ||||
|     # A dual-stage attention-based recurrent neural network for time series prediction, IJCAI-2017 | ||||
|     alg2names["ALSTM"] = "workflow_config_alstm_Alpha360.yaml" | ||||
|     # XGBoost: A Scalable Tree Boosting System, KDD-2016 | ||||
|     alg2names["XGBoost"] = "workflow_config_xgboost_Alpha360.yaml" | ||||
|     # LightGBM: A Highly Efficient Gradient Boosting Decision Tree, NeurIPS-2017 | ||||
|     alg2names["LightGBM"] = "workflow_config_lightgbm_Alpha360.yaml" | ||||
|     # State Frequency Memory (SFM): Stock Price Prediction via Discovering Multi-Frequency Trading Patterns, KDD-2017 | ||||
|     alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml" | ||||
|  | ||||
|     # find the yaml paths | ||||
|     alg2paths = OrderedDict() | ||||
|     print("Start retrieving the algorithm configurations") | ||||
|     for idx, (alg, name) in enumerate(alg2names.items()): | ||||
|         path = config_dir / name | ||||
|         assert path.exists(), "{:} does not exist.".format(path) | ||||
| @@ -48,56 +58,6 @@ def retrieve_configs(): | ||||
|     return alg2paths | ||||
|  | ||||
|  | ||||
| def update_gpu(config, gpu): | ||||
|     config = config.copy() | ||||
|     if "GPU" in config["task"]["model"]: | ||||
|         config["task"]["model"]["GPU"] = gpu | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def update_market(config, market): | ||||
|     config = config.copy() | ||||
|     config["market"] = market | ||||
|     config["data_handler_config"]["instruments"] = market | ||||
|     return config | ||||
|  | ||||
|  | ||||
| def run_exp(task_config, dataset, experiment_name, recorder_name, uri): | ||||
|  | ||||
|     # model initiaiton | ||||
|     print("") | ||||
|     print("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri)) | ||||
|     print("dataset={:}".format(dataset)) | ||||
|  | ||||
|     model = init_instance_by_config(task_config["model"]) | ||||
|  | ||||
|     # start exp | ||||
|     with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri): | ||||
|  | ||||
|         log_file = R.get_recorder().root_uri / "{:}.log".format(experiment_name) | ||||
|         set_log_basic_config(log_file) | ||||
|  | ||||
|         # train model | ||||
|         R.log_params(**flatten_dict(task_config)) | ||||
|         model.fit(dataset) | ||||
|         recorder = R.get_recorder() | ||||
|         R.save_objects(**{"model.pkl": model}) | ||||
|  | ||||
|         # generate records: prediction, backtest, and analysis | ||||
|         for record in task_config["record"]: | ||||
|             record = record.copy() | ||||
|             if record["class"] == "SignalRecord": | ||||
|                 srconf = {"model": model, "dataset": dataset, "recorder": recorder} | ||||
|                 record["kwargs"].update(srconf) | ||||
|                 sr = init_instance_by_config(record) | ||||
|                 sr.generate() | ||||
|             else: | ||||
|                 rconf = {"recorder": recorder} | ||||
|                 record["kwargs"].update(rconf) | ||||
|                 ar = init_instance_by_config(record) | ||||
|                 ar.generate() | ||||
|  | ||||
|  | ||||
| def main(xargs, exp_yaml): | ||||
|     assert Path(exp_yaml).exists(), "{:} does not exist.".format(exp_yaml) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user