Simplify Q-T

This commit is contained in:
D-X-Y 2021-03-06 21:44:59 -08:00
parent c0481a2357
commit d17f73394f

View File

@ -37,7 +37,9 @@ from qlib.data.dataset.handler import DataHandlerLP
default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1)
default_opt_config = dict(epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam")
default_opt_config = dict(
epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam", num_workers=4
)
class QuantTransformer(Model):
@ -159,6 +161,16 @@ class QuantTransformer(Model):
torch.from_numpy(df_data["label"].values).squeeze().float(),
)
def _prepare_loader(dataset, shuffle):
return th_data.DataLoader(
dataset,
batch_size=self.opt_config["batch_size"],
drop_last=False,
pin_memory=True,
num_workers=self.opt_config["num_workers"],
shuffle=shuffle,
)
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
@ -169,15 +181,10 @@ class QuantTransformer(Model):
_prepare_dataset(df_valid),
_prepare_dataset(df_test),
)
train_loader = th_data.DataLoader(
train_dataset, batch_size=self.opt_config["batch_size"], shuffle=True, drop_last=False, pin_memory=True
)
valid_loader = th_data.DataLoader(
valid_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
)
test_loader = th_data.DataLoader(
test_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
train_loader, valid_loader, test_loader = (
_prepare_loader(train_dataset, True),
_prepare_loader(valid_dataset, False),
_prepare_loader(test_dataset, False),
)
if save_path == None: