#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
# python exps/trading/baselines.py --alg MLP        #
# python exps/trading/baselines.py --alg GRU        #
# python exps/trading/baselines.py --alg LSTM       #
# python exps/trading/baselines.py --alg ALSTM      #
# python exps/trading/baselines.py --alg NAIVE-V1   #
# python exps/trading/baselines.py --alg NAIVE-V2   #
#                                                   #
# python exps/trading/baselines.py --alg SFM        #
# python exps/trading/baselines.py --alg XGBoost    #
# python exps/trading/baselines.py --alg LightGBM   #
# python exps/trading/baselines.py --alg DoubleE    #
# python exps/trading/baselines.py --alg TabNet     #
#                                                   #
# python exps/trading/baselines.py --alg Transformer#
# python exps/trading/baselines.py --alg TSF        #
# python exps/trading/baselines.py --alg TSF-4x64-drop0_0
#####################################################
2021-03-07 04:27:05 +01:00
import sys
2021-03-25 14:51:45 +01:00
import copy
2021-03-07 04:27:05 +01:00
import argparse
2021-03-05 14:11:26 +01:00
from collections import OrderedDict
from pathlib import Path
from pprint import pprint
import ruamel . yaml as yaml
# Locate the `lib` directory two levels above this file's directory and put it
# at the front of sys.path so the project-local imports that follow resolve.
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
2021-03-30 11:17:05 +02:00
from config_utils import arg_str2bool
2021-03-07 04:09:47 +01:00
from procedures . q_exps import update_gpu
2021-03-11 04:09:55 +01:00
from procedures . q_exps import update_market
2021-03-07 04:09:47 +01:00
from procedures . q_exps import run_exp
2021-03-05 14:11:26 +01:00
import qlib
from qlib . utils import init_instance_by_config
from qlib . workflow import R
from qlib . utils import flatten_dict
2021-03-30 11:02:41 +02:00
def to_drop(config, pos_drop, other_drop):
    """Return a deep copy of ``config`` with the two dropout entries replaced.

    Args:
        config: nested mapping that contains
            ``config["task"]["model"]["kwargs"]["net_config"]``.
        pos_drop: value stored under ``net_config["pos_drop"]``.
        other_drop: value stored under ``net_config["other_drop"]``.

    Returns:
        The modified copy; the ``config`` passed in is left untouched.

    Raises:
        KeyError: if any key on the ``net_config`` path is missing.
    """
    # Work on a private copy so that variants derived from one base
    # configuration never share (and overwrite) nested dictionaries.
    config = copy.deepcopy(config)
    net = config["task"]["model"]["kwargs"]["net_config"]
    net["pos_drop"] = pos_drop
    net["other_drop"] = other_drop
    return config
def to_layer(config, embed_dim, depth):
    """Return a deep copy of ``config`` resized to ``depth`` layers.

    Args:
        config: nested mapping that contains
            ``config["task"]["model"]["kwargs"]["net_config"]``.
        embed_dim: value stored under ``net_config["embed_dim"]``.
        depth: number of entries written to the per-layer lists.

    Returns:
        The modified copy, in which ``net_config["num_heads"]`` and
        ``net_config["mlp_hidden_multipliers"]`` are both ``[4] * depth``.
        The ``config`` passed in is left untouched.

    Raises:
        KeyError: if any key on the ``net_config`` path is missing.
    """
    config = copy.deepcopy(config)
    net = config["task"]["model"]["kwargs"]["net_config"]
    net["embed_dim"] = embed_dim
    # One entry per layer; two separate list objects are created on purpose.
    net["num_heads"] = [4] * depth
    net["mlp_hidden_multipliers"] = [4] * depth
    return config
def extend_transformer_settings(alg2configs, name):
    """Register a grid of transformer variants derived from ``alg2configs[name]``.

    For every combination of depth (1..6), embedding dimension
    (6, 12, 24, 32, 48, 64), ``pos_drop`` (0, 0.1, 0.2) and ``other_drop``
    (0, 0.1) a new entry is stored under the key
    ``"<name>-<depth>x<embed_dim>-drop<pos_drop>_<other_drop>"``,
    e.g. ``"TSF-4x64-drop0_0"``. That is 6 * 6 * 3 * 2 = 216 new entries.

    Args:
        alg2configs: mapping from algorithm name to configuration; it is
            extended in place.
        name: key of the base configuration inside ``alg2configs``.

    Returns:
        The same ``alg2configs`` object, now including the variants.

    Raises:
        KeyError: if ``name`` is not present in ``alg2configs``.
    """
    # Snapshot the base configuration once; every variant is built from it.
    base = copy.deepcopy(alg2configs[name])
    for depth in range(1, 7):
        for embed_dim in (6, 12, 24, 32, 48, 64):
            for pos_drop in (0, 0.1, 0.2):
                for other_drop in (0, 0.1):
                    # No padding inside the name: it has to match what the
                    # user types after `--alg` on the command line.
                    key = name + "-{:}x{:}-drop{:}_{:}".format(
                        depth, embed_dim, pos_drop, other_drop
                    )
                    alg2configs[key] = to_layer(
                        to_drop(base, pos_drop, other_drop), embed_dim, depth
                    )
    return alg2configs
2021-03-29 07:04:24 +02:00
def refresh_record(alg2configs):
    """Return a deep copy of ``alg2configs`` with each record list rewritten.

    For every configuration, entries of ``config["task"]["record"]`` whose
    ``"class"`` is ``"PortAnaRecord"`` or ``"SignalMseRecord"`` are removed and
    one ``"MultiSegRecord"`` entry is appended at the end.

    Args:
        alg2configs: mapping from algorithm name to a configuration that
            contains ``config["task"]["record"]``, a list of dicts each
            carrying a ``"class"`` key.

    Returns:
        The rewritten copy; the mapping passed in is left untouched.

    Raises:
        KeyError: if a configuration lacks ``"task"``/``"record"`` or a record
            entry lacks ``"class"``.
    """
    removed_classes = ("PortAnaRecord", "SignalMseRecord")
    alg2configs = copy.deepcopy(alg2configs)
    for config in alg2configs.values():
        task = config["task"]
        # remove PortAnaRecord and SignalMseRecord
        records = [
            record
            for record in task["record"]
            if record["class"] not in removed_classes
        ]
        # add MultiSegRecord
        records.append(
            {
                "class": "MultiSegRecord",
                "module_path": "qlib.contrib.workflow",
                "generate_kwargs": {
                    "segments": {"train": "train", "valid": "valid", "test": "test"},
                    "save": True,
                },
            }
        )
        task["record"] = records
    return alg2configs
2021-03-05 14:11:26 +01:00
def retrieve_configs():
    """Load the YAML configuration of every baseline algorithm.

    Each file is read from ``<lib_dir>/../configs/qlib``. After loading, the
    ``"TSF"`` entry is expanded by ``extend_transformer_settings`` and all
    entries are passed through ``refresh_record``. Progress is printed to
    stdout.

    Returns:
        An ordered mapping from algorithm name to its loaded configuration.

    Raises:
        AssertionError: if one of the expected YAML files does not exist.
    """
    # https://github.com/microsoft/qlib/blob/main/examples/benchmarks/
    config_dir = (lib_dir / ".." / "configs" / "qlib").resolve()
    # algorithm to file names
    alg2names = OrderedDict()
    alg2names["GRU"] = "workflow_config_gru_Alpha360.yaml"
    alg2names["LSTM"] = "workflow_config_lstm_Alpha360.yaml"
    alg2names["MLP"] = "workflow_config_mlp_Alpha360.yaml"
    # A dual-stage attention-based recurrent neural network for time series prediction, IJCAI-2017
    alg2names["ALSTM"] = "workflow_config_alstm_Alpha360.yaml"
    # XGBoost: A Scalable Tree Boosting System, KDD-2016
    alg2names["XGBoost"] = "workflow_config_xgboost_Alpha360.yaml"
    # LightGBM: A Highly Efficient Gradient Boosting Decision Tree, NeurIPS-2017
    alg2names["LightGBM"] = "workflow_config_lightgbm_Alpha360.yaml"
    # State Frequency Memory (SFM): Stock Price Prediction via Discovering Multi-Frequency Trading Patterns, KDD-2017
    alg2names["SFM"] = "workflow_config_sfm_Alpha360.yaml"
    # DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis, https://arxiv.org/pdf/2010.01265.pdf
    alg2names["DoubleE"] = "workflow_config_doubleensemble_Alpha360.yaml"
    alg2names["TabNet"] = "workflow_config_TabNet_Alpha360.yaml"
    alg2names["NAIVE-V1"] = "workflow_config_naive_v1_Alpha360.yaml"
    alg2names["NAIVE-V2"] = "workflow_config_naive_v2_Alpha360.yaml"
    alg2names["Transformer"] = "workflow_config_transformer_Alpha360.yaml"
    alg2names["TSF"] = "workflow_config_transformer_basic_Alpha360.yaml"
    # find the yaml paths
    alg2configs = OrderedDict()
    print("Start retrieving the algorithm configurations")
    # The total is taken from alg2names: alg2configs is still being filled
    # inside the loop, so its length would change on every iteration.
    total = len(alg2names)
    for idx, (alg, name) in enumerate(alg2names.items()):
        path = config_dir / name
        assert path.exists(), "{:} does not exist.".format(path)
        with open(path, encoding="utf-8") as fp:
            alg2configs[alg] = yaml.safe_load(fp)
        print(
            "The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format(
                idx, total, alg, path
            )
        )
    alg2configs = extend_transformer_settings(alg2configs, "TSF")
    alg2configs = refresh_record(alg2configs)
    print(
        "There are {:} algorithms : {:}".format(
            len(alg2configs), list(alg2configs.keys())
        )
    )
    return alg2configs
2021-03-05 14:11:26 +01:00
2021-03-25 14:51:45 +01:00
def main(xargs, config):
    """Run one baseline configuration ``xargs.times`` times.

    The configuration is passed through ``update_market`` and ``update_gpu``,
    qlib is initialised from ``config["qlib_init"]``, the dataset is built once
    from ``config["task"]["dataset"]`` and then reused for every repetition of
    ``run_exp``.

    Args:
        xargs: parsed command-line namespace; the attributes read here are
            ``alg``, ``market``, ``gpu``, ``times`` and ``save_dir``.
        config: configuration mapping of the selected algorithm.
    """
    # `--alg` is declared with nargs="+", so a single selection arrives as a
    # one-element list; unwrap it so the name is used as plain text.
    # NOTE(review): run_exp presumably expects a string experiment name -- confirm.
    alg_name = xargs.alg
    if isinstance(alg_name, (list, tuple)) and len(alg_name) == 1:
        alg_name = alg_name[0]

    pprint("Run {:}".format(alg_name))
    config = update_market(config, xargs.market)
    config = update_gpu(config, xargs.gpu)

    qlib.init(**config.get("qlib_init"))
    dataset_config = config.get("task").get("dataset")
    dataset = init_instance_by_config(dataset_config)
    pprint("args: {:}".format(xargs))
    pprint(dataset_config)
    pprint(dataset)

    for irun in range(xargs.times):
        run_exp(
            config.get("task"),
            dataset,
            alg_name,
            "recorder-{:02d}-{:02d}".format(irun, xargs.times),
            "{:}-{:}".format(xargs.save_dir, xargs.market),
        )
2021-03-05 14:11:26 +01:00
if __name__ == "__main__":
    # The configurations are loaded first because the valid values of `--alg`
    # are exactly the keys of the returned mapping.
    alg2configs = retrieve_configs()

    parser = argparse.ArgumentParser("Baselines")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs/qlib-baselines",
        help="The checkpoint directory.",
    )
    parser.add_argument(
        "--market",
        type=str,
        default="all",
        choices=["csi100", "csi300", "all"],
        help="The market indicator.",
    )
    parser.add_argument("--times", type=int, default=5, help="The repeated run times.")
    # NOTE(review): this flag is parsed but not read anywhere in the visible code.
    parser.add_argument(
        "--shared_dataset",
        type=arg_str2bool,
        default=False,
        help="Whether to share the dataset for all algorithms?",
    )
    parser.add_argument(
        "--gpu", type=int, default=0, help="The GPU ID used for train / test."
    )
    parser.add_argument(
        "--alg",
        type=str,
        choices=list(alg2configs.keys()),
        nargs="+",
        required=True,
        help="The algorithm name(s).",
    )
    args = parser.parse_args()

    if len(args.alg) == 1:
        main(args, alg2configs[args.alg[0]])
    else:
        # TODO: several algorithms are accepted by the parser but nothing is
        # executed for them yet; only this placeholder is printed.
        print("-")