Simplify Q-T
This commit is contained in:
parent
c0481a2357
commit
d17f73394f
@ -37,7 +37,9 @@ from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1)
|
||||
|
||||
default_opt_config = dict(epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam")
|
||||
default_opt_config = dict(
|
||||
epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam", num_workers=4
|
||||
)
|
||||
|
||||
|
||||
class QuantTransformer(Model):
|
||||
@ -159,6 +161,16 @@ class QuantTransformer(Model):
|
||||
torch.from_numpy(df_data["label"].values).squeeze().float(),
|
||||
)
|
||||
|
||||
def _prepare_loader(dataset, shuffle):
|
||||
return th_data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.opt_config["batch_size"],
|
||||
drop_last=False,
|
||||
pin_memory=True,
|
||||
num_workers=self.opt_config["num_workers"],
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
@ -169,15 +181,10 @@ class QuantTransformer(Model):
|
||||
_prepare_dataset(df_valid),
|
||||
_prepare_dataset(df_test),
|
||||
)
|
||||
|
||||
train_loader = th_data.DataLoader(
|
||||
train_dataset, batch_size=self.opt_config["batch_size"], shuffle=True, drop_last=False, pin_memory=True
|
||||
)
|
||||
valid_loader = th_data.DataLoader(
|
||||
valid_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
|
||||
)
|
||||
test_loader = th_data.DataLoader(
|
||||
test_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
|
||||
train_loader, valid_loader, test_loader = (
|
||||
_prepare_loader(train_dataset, True),
|
||||
_prepare_loader(valid_dataset, False),
|
||||
_prepare_loader(test_dataset, False),
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
|
Loading…
Reference in New Issue
Block a user