Fix bugs in update_gpu in procedures

This commit is contained in:
D-X-Y 2021-03-06 19:27:05 -08:00
parent 55c9734c31
commit 53e1441c8d
2 changed files with 21 additions and 14 deletions

View File

@ -1,15 +1,16 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
# python exps/trading/baselines.py --alg GRU
# python exps/trading/baselines.py --alg LSTM
# python exps/trading/baselines.py --alg ALSTM
# python exps/trading/baselines.py --alg MLP
# python exps/trading/baselines.py --alg SFM
# python exps/trading/baselines.py --alg XGBoost
# python exps/trading/baselines.py --alg LightGBM
# python exps/trading/baselines.py --alg GRU #
# python exps/trading/baselines.py --alg LSTM #
# python exps/trading/baselines.py --alg ALSTM #
# python exps/trading/baselines.py --alg MLP #
# python exps/trading/baselines.py --alg SFM #
# python exps/trading/baselines.py --alg XGBoost #
# python exps/trading/baselines.py --alg LightGBM #
#####################################################
import sys, argparse
import sys
import argparse
from collections import OrderedDict
from pathlib import Path
from pprint import pprint
@ -20,7 +21,6 @@ if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from procedures.q_exps import update_gpu
from procedures.q_exps import update_market
from procedures.q_exps import run_exp
import qlib
@ -64,7 +64,6 @@ def main(xargs, exp_yaml):
with open(exp_yaml) as fp:
config = yaml.safe_load(fp)
config = update_gpu(config, xargs.gpu)
# config = update_market(config, 'csi300')
qlib.init(**config.get("qlib_init"))
dataset_config = config.get("task").get("dataset")

View File

@ -12,10 +12,18 @@ from qlib.log import get_module_logger
def update_gpu(config, gpu):
config = config.copy()
if "task" in config and "moodel" in config["task"] and "GPU" in config["task"]["model"]:
config["task"]["model"]["GPU"] = gpu
elif "model" in config and "GPU" in config["model"]:
config["model"]["GPU"] = gpu
if "task" in config and "model" in config["task"]:
if "GPU" in config["task"]["model"]:
config["task"]["model"]["GPU"] = gpu
elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]:
config["task"]["model"]["kwargs"]["GPU"] = gpu
elif "model" in config:
if "GPU" in config["model"]:
config["model"]["GPU"] = gpu
elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]:
config["model"]["kwargs"]["GPU"] = gpu
elif "kwargs" in config and "GPU" in config["kwargs"]:
config["kwargs"]["GPU"] = gpu
elif "GPU" in config:
config["GPU"] = gpu
return config