From 53e1441c8da45e29db97952d90092a7c8ac037c9 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sat, 6 Mar 2021 19:27:05 -0800 Subject: [PATCH] Fix bugs in update_gpu in procedures --- exps/trading/baselines.py | 19 +++++++++---------- lib/procedures/q_exps.py | 16 ++++++++++++---- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index daccf43..b91d085 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -1,15 +1,16 @@ ##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # ##################################################### -# python exps/trading/baselines.py --alg GRU -# python exps/trading/baselines.py --alg LSTM -# python exps/trading/baselines.py --alg ALSTM -# python exps/trading/baselines.py --alg MLP -# python exps/trading/baselines.py --alg SFM -# python exps/trading/baselines.py --alg XGBoost -# python exps/trading/baselines.py --alg LightGBM +# python exps/trading/baselines.py --alg GRU # +# python exps/trading/baselines.py --alg LSTM # +# python exps/trading/baselines.py --alg ALSTM # +# python exps/trading/baselines.py --alg MLP # +# python exps/trading/baselines.py --alg SFM # +# python exps/trading/baselines.py --alg XGBoost # +# python exps/trading/baselines.py --alg LightGBM # ##################################################### -import sys, argparse +import sys +import argparse from collections import OrderedDict from pathlib import Path from pprint import pprint @@ -20,7 +21,6 @@ if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from procedures.q_exps import update_gpu -from procedures.q_exps import update_market from procedures.q_exps import run_exp import qlib @@ -64,7 +64,6 @@ def main(xargs, exp_yaml): with open(exp_yaml) as fp: config = yaml.safe_load(fp) config = update_gpu(config, xargs.gpu) - # config = update_market(config, 'csi300') qlib.init(**config.get("qlib_init")) dataset_config = config.get("task").get("dataset") diff --git a/lib/procedures/q_exps.py b/lib/procedures/q_exps.py index 3b62167..7a9f74b 100644 --- a/lib/procedures/q_exps.py +++ b/lib/procedures/q_exps.py @@ -12,10 +12,18 @@ from qlib.log import get_module_logger def update_gpu(config, gpu): config = config.copy() - if "task" in config and "moodel" in config["task"] and "GPU" in config["task"]["model"]: - config["task"]["model"]["GPU"] = gpu - elif "model" in config and "GPU" in config["model"]: - config["model"]["GPU"] = gpu + if "task" in config and "model" in config["task"]: + if "GPU" in config["task"]["model"]: + config["task"]["model"]["GPU"] = gpu + elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]: + config["task"]["model"]["kwargs"]["GPU"] = gpu + elif "model" in config: + if "GPU" in config["model"]: + config["model"]["GPU"] = gpu + elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]: + config["model"]["kwargs"]["GPU"] = gpu + elif "kwargs" in config and "GPU" in config["kwargs"]: + config["kwargs"]["GPU"] = gpu elif "GPU" in config: config["GPU"] = gpu return config