Fix bugs in update_gpu in procedures

This commit is contained in:
D-X-Y 2021-03-06 19:27:05 -08:00
parent 55c9734c31
commit 53e1441c8d
2 changed files with 21 additions and 14 deletions

View File

@ -1,15 +1,16 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
##################################################### #####################################################
# python exps/trading/baselines.py --alg GRU # python exps/trading/baselines.py --alg GRU #
# python exps/trading/baselines.py --alg LSTM # python exps/trading/baselines.py --alg LSTM #
# python exps/trading/baselines.py --alg ALSTM # python exps/trading/baselines.py --alg ALSTM #
# python exps/trading/baselines.py --alg MLP # python exps/trading/baselines.py --alg MLP #
# python exps/trading/baselines.py --alg SFM # python exps/trading/baselines.py --alg SFM #
# python exps/trading/baselines.py --alg XGBoost # python exps/trading/baselines.py --alg XGBoost #
# python exps/trading/baselines.py --alg LightGBM # python exps/trading/baselines.py --alg LightGBM #
##################################################### #####################################################
import sys, argparse import sys
import argparse
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
from pprint import pprint from pprint import pprint
@ -20,7 +21,6 @@ if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir)) sys.path.insert(0, str(lib_dir))
from procedures.q_exps import update_gpu from procedures.q_exps import update_gpu
from procedures.q_exps import update_market
from procedures.q_exps import run_exp from procedures.q_exps import run_exp
import qlib import qlib
@ -64,7 +64,6 @@ def main(xargs, exp_yaml):
with open(exp_yaml) as fp: with open(exp_yaml) as fp:
config = yaml.safe_load(fp) config = yaml.safe_load(fp)
config = update_gpu(config, xargs.gpu) config = update_gpu(config, xargs.gpu)
# config = update_market(config, 'csi300')
qlib.init(**config.get("qlib_init")) qlib.init(**config.get("qlib_init"))
dataset_config = config.get("task").get("dataset") dataset_config = config.get("task").get("dataset")

View File

@ -12,10 +12,18 @@ from qlib.log import get_module_logger
def update_gpu(config, gpu): def update_gpu(config, gpu):
config = config.copy() config = config.copy()
if "task" in config and "moodel" in config["task"] and "GPU" in config["task"]["model"]: if "task" in config and "model" in config["task"]:
if "GPU" in config["task"]["model"]:
config["task"]["model"]["GPU"] = gpu config["task"]["model"]["GPU"] = gpu
elif "model" in config and "GPU" in config["model"]: elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]:
config["task"]["model"]["kwargs"]["GPU"] = gpu
elif "model" in config:
if "GPU" in config["model"]:
config["model"]["GPU"] = gpu config["model"]["GPU"] = gpu
elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]:
config["model"]["kwargs"]["GPU"] = gpu
elif "kwargs" in config and "GPU" in config["kwargs"]:
config["kwargs"]["GPU"] = gpu
elif "GPU" in config: elif "GPU" in config:
config["GPU"] = gpu config["GPU"] = gpu
return config return config