diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py
index 7e94028..c8fbc86 100755
--- a/lib/trade_models/quant_transformer.py
+++ b/lib/trade_models/quant_transformer.py
@@ -4,8 +4,7 @@
 from __future__ import division
 from __future__ import print_function
 
-import os
-import math
+import os, math, random
 from collections import OrderedDict
 import numpy as np
 import pandas as pd
@@ -37,7 +36,7 @@
 from qlib.data.dataset import DatasetH
 from qlib.data.dataset.handler import DataHandlerLP
 
-default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1)
+default_net_config = dict(d_feat=6, embed_dim=48, depth=5, num_heads=4, mlp_ratio=4.0, qkv_bias=True, pos_drop=0.1)
 
 default_opt_config = dict(
     epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam", num_workers=4
@@ -50,7 +49,7 @@ class QuantTransformer(Model):
     def __init__(self, net_config=None, opt_config=None, metric="", GPU=0, seed=None, **kwargs):
         # Set logger.
         self.logger = get_module_logger("QuantTransformer")
-        self.logger.info("QuantTransformer pytorch version...")
+        self.logger.info("QuantTransformer PyTorch version...")
 
         # set hyper-parameters.
         self.net_config = net_config or default_net_config
@@ -75,12 +74,16 @@ class QuantTransformer(Model):
         )
 
         if self.seed is not None:
+            random.seed(self.seed)
             np.random.seed(self.seed)
             torch.manual_seed(self.seed)
+            if self.use_gpu:
+                torch.cuda.manual_seed(self.seed)
+                torch.cuda.manual_seed_all(self.seed)
 
         self.model = TransformerModel(
             d_feat=self.net_config["d_feat"],
-            embed_dim=self.net_config["hidden_size"],
+            embed_dim=self.net_config["embed_dim"],
             depth=self.net_config["depth"],
             pos_drop=self.net_config["pos_drop"],
         )
@@ -99,7 +102,7 @@ class QuantTransformer(Model):
 
     @property
     def use_gpu(self):
-        self.device == torch.device("cpu")
+        return self.device != torch.device("cpu")
 
     def loss_fn(self, pred, label):
         mask = ~torch.isnan(label)
@@ -176,7 +179,7 @@ class QuantTransformer(Model):
             _prepare_loader(test_dataset, False),
         )
 
-        save_path = get_or_create_path(save_path)
+        save_path = get_or_create_path(save_path, return_dir=True)
         self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_path))
 
         def _internal_test(ckp_epoch=None, results_dict=None):
@@ -286,11 +289,11 @@
 
 
 class Attention(nn.Module):
+
     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
         super(Attention, self).__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
         self.scale = qk_scale or math.sqrt(head_dim)
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
@@ -314,6 +317,7 @@ class Attention(nn.Module):
 
 
 class Block(nn.Module):
+
     def __init__(
         self,
         dim,
@@ -345,6 +349,7 @@ class Block(nn.Module):
 
 
 class SimpleEmbed(nn.Module):
+
     def __init__(self, d_feat, embed_dim):
         super(SimpleEmbed, self).__init__()
         self.d_feat = d_feat
@@ -361,18 +366,19 @@ class SimpleEmbed(nn.Module):
 class TransformerModel(nn.Module):
     def __init__(
         self,
-        d_feat: int,
+        d_feat: int = 6,
         embed_dim: int = 64,
         depth: int = 4,
         num_heads: int = 4,
         mlp_ratio: float = 4.0,
         qkv_bias: bool = True,
         qk_scale: Optional[float] = None,
-        pos_drop=0.0,
-        mlp_drop_rate=0.0,
-        attn_drop_rate=0.0,
-        drop_path_rate=0.0,
-        norm_layer=None,
+        pos_drop: float = 0.0,
+        mlp_drop_rate: float = 0.0,
+        attn_drop_rate: float = 0.0,
+        drop_path_rate: float = 0.0,
+        norm_layer: Optional[nn.Module] = None,
+        max_seq_len: int = 65,
     ):
         """
         Args:
@@ -397,7 +403,7 @@ class TransformerModel(nn.Module):
 
         self.input_embed = SimpleEmbed(d_feat, embed_dim=embed_dim)
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = xlayers.PositionalEncoder(d_model=embed_dim, max_seq_len=65, dropout=pos_drop)
+        self.pos_embed = xlayers.PositionalEncoder(d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop)
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList(
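
Below is a minimal usage sketch, not part of the patch, showing how the renamed config keys and the extended seeding logic introduced above fit together. The import path is an assumption based on the file location, seed_everything is a hypothetical helper that mirrors the seeding block added to QuantTransformer.__init__, and the seed value is arbitrary.

import random

import numpy as np
import torch

# Assumed import path, derived from lib/trade_models/quant_transformer.py.
from lib.trade_models.quant_transformer import TransformerModel, default_net_config


def seed_everything(seed: int, use_gpu: bool) -> None:
    # Hypothetical helper: mirrors the seeding block this patch adds to
    # QuantTransformer.__init__ (Python, NumPy, PyTorch, and CUDA RNGs).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_gpu:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


seed_everything(42, use_gpu=torch.cuda.is_available())

# The renamed keys (embed_dim instead of hidden_size, plus num_heads,
# mlp_ratio, qkv_bias) now line up with TransformerModel's keyword
# arguments; QuantTransformer itself forwards only d_feat, embed_dim,
# depth and pos_drop, but for a quick standalone check the whole dict
# can be unpacked. max_seq_len keeps its default of 65.
model = TransformerModel(**default_net_config)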