Fix bugs in update_gpu in procedures
This commit is contained in:
parent
55c9734c31
commit
53e1441c8d
@ -1,15 +1,16 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
||||||
#####################################################
|
#####################################################
|
||||||
# python exps/trading/baselines.py --alg GRU
|
# python exps/trading/baselines.py --alg GRU #
|
||||||
# python exps/trading/baselines.py --alg LSTM
|
# python exps/trading/baselines.py --alg LSTM #
|
||||||
# python exps/trading/baselines.py --alg ALSTM
|
# python exps/trading/baselines.py --alg ALSTM #
|
||||||
# python exps/trading/baselines.py --alg MLP
|
# python exps/trading/baselines.py --alg MLP #
|
||||||
# python exps/trading/baselines.py --alg SFM
|
# python exps/trading/baselines.py --alg SFM #
|
||||||
# python exps/trading/baselines.py --alg XGBoost
|
# python exps/trading/baselines.py --alg XGBoost #
|
||||||
# python exps/trading/baselines.py --alg LightGBM
|
# python exps/trading/baselines.py --alg LightGBM #
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys, argparse
|
import sys
|
||||||
|
import argparse
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
@ -20,7 +21,6 @@ if str(lib_dir) not in sys.path:
|
|||||||
sys.path.insert(0, str(lib_dir))
|
sys.path.insert(0, str(lib_dir))
|
||||||
|
|
||||||
from procedures.q_exps import update_gpu
|
from procedures.q_exps import update_gpu
|
||||||
from procedures.q_exps import update_market
|
|
||||||
from procedures.q_exps import run_exp
|
from procedures.q_exps import run_exp
|
||||||
|
|
||||||
import qlib
|
import qlib
|
||||||
@ -64,7 +64,6 @@ def main(xargs, exp_yaml):
|
|||||||
with open(exp_yaml) as fp:
|
with open(exp_yaml) as fp:
|
||||||
config = yaml.safe_load(fp)
|
config = yaml.safe_load(fp)
|
||||||
config = update_gpu(config, xargs.gpu)
|
config = update_gpu(config, xargs.gpu)
|
||||||
# config = update_market(config, 'csi300')
|
|
||||||
|
|
||||||
qlib.init(**config.get("qlib_init"))
|
qlib.init(**config.get("qlib_init"))
|
||||||
dataset_config = config.get("task").get("dataset")
|
dataset_config = config.get("task").get("dataset")
|
||||||
|
@ -12,10 +12,18 @@ from qlib.log import get_module_logger
|
|||||||
|
|
||||||
def update_gpu(config, gpu):
|
def update_gpu(config, gpu):
|
||||||
config = config.copy()
|
config = config.copy()
|
||||||
if "task" in config and "moodel" in config["task"] and "GPU" in config["task"]["model"]:
|
if "task" in config and "model" in config["task"]:
|
||||||
|
if "GPU" in config["task"]["model"]:
|
||||||
config["task"]["model"]["GPU"] = gpu
|
config["task"]["model"]["GPU"] = gpu
|
||||||
elif "model" in config and "GPU" in config["model"]:
|
elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]:
|
||||||
|
config["task"]["model"]["kwargs"]["GPU"] = gpu
|
||||||
|
elif "model" in config:
|
||||||
|
if "GPU" in config["model"]:
|
||||||
config["model"]["GPU"] = gpu
|
config["model"]["GPU"] = gpu
|
||||||
|
elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]:
|
||||||
|
config["model"]["kwargs"]["GPU"] = gpu
|
||||||
|
elif "kwargs" in config and "GPU" in config["kwargs"]:
|
||||||
|
config["kwargs"]["GPU"] = gpu
|
||||||
elif "GPU" in config:
|
elif "GPU" in config:
|
||||||
config["GPU"] = gpu
|
config["GPU"] = gpu
|
||||||
return config
|
return config
|
||||||
|
Loading…
Reference in New Issue
Block a user