diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py index 9962896..55a92ec 100755 --- a/lib/trade_models/quant_transformer.py +++ b/lib/trade_models/quant_transformer.py @@ -37,7 +37,9 @@ from qlib.data.dataset.handler import DataHandlerLP default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1) -default_opt_config = dict(epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam") +default_opt_config = dict( + epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam", num_workers=4 +) class QuantTransformer(Model): @@ -159,6 +161,16 @@ class QuantTransformer(Model): torch.from_numpy(df_data["label"].values).squeeze().float(), ) + def _prepare_loader(dataset, shuffle): + return th_data.DataLoader( + dataset, + batch_size=self.opt_config["batch_size"], + drop_last=False, + pin_memory=True, + num_workers=self.opt_config["num_workers"], + shuffle=shuffle, + ) + df_train, df_valid, df_test = dataset.prepare( ["train", "valid", "test"], col_set=["feature", "label"], @@ -169,15 +181,10 @@ class QuantTransformer(Model): _prepare_dataset(df_valid), _prepare_dataset(df_test), ) - - train_loader = th_data.DataLoader( - train_dataset, batch_size=self.opt_config["batch_size"], shuffle=True, drop_last=False, pin_memory=True - ) - valid_loader = th_data.DataLoader( - valid_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True - ) - test_loader = th_data.DataLoader( - test_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True + train_loader, valid_loader, test_loader = ( + _prepare_loader(train_dataset, True), + _prepare_loader(valid_dataset, False), + _prepare_loader(test_dataset, False), ) if save_path == None: