Simplify Q-T
This commit is contained in:
parent
c0481a2357
commit
d17f73394f
@ -37,7 +37,9 @@ from qlib.data.dataset.handler import DataHandlerLP
|
|||||||
|
|
||||||
default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1)
|
default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1)
|
||||||
|
|
||||||
default_opt_config = dict(epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam")
|
default_opt_config = dict(
|
||||||
|
epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam", num_workers=4
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QuantTransformer(Model):
|
class QuantTransformer(Model):
|
||||||
@ -159,6 +161,16 @@ class QuantTransformer(Model):
|
|||||||
torch.from_numpy(df_data["label"].values).squeeze().float(),
|
torch.from_numpy(df_data["label"].values).squeeze().float(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _prepare_loader(dataset, shuffle):
|
||||||
|
return th_data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.opt_config["batch_size"],
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=self.opt_config["num_workers"],
|
||||||
|
shuffle=shuffle,
|
||||||
|
)
|
||||||
|
|
||||||
df_train, df_valid, df_test = dataset.prepare(
|
df_train, df_valid, df_test = dataset.prepare(
|
||||||
["train", "valid", "test"],
|
["train", "valid", "test"],
|
||||||
col_set=["feature", "label"],
|
col_set=["feature", "label"],
|
||||||
@ -169,15 +181,10 @@ class QuantTransformer(Model):
|
|||||||
_prepare_dataset(df_valid),
|
_prepare_dataset(df_valid),
|
||||||
_prepare_dataset(df_test),
|
_prepare_dataset(df_test),
|
||||||
)
|
)
|
||||||
|
train_loader, valid_loader, test_loader = (
|
||||||
train_loader = th_data.DataLoader(
|
_prepare_loader(train_dataset, True),
|
||||||
train_dataset, batch_size=self.opt_config["batch_size"], shuffle=True, drop_last=False, pin_memory=True
|
_prepare_loader(valid_dataset, False),
|
||||||
)
|
_prepare_loader(test_dataset, False),
|
||||||
valid_loader = th_data.DataLoader(
|
|
||||||
valid_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
|
|
||||||
)
|
|
||||||
test_loader = th_data.DataLoader(
|
|
||||||
test_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if save_path == None:
|
if save_path == None:
|
||||||
|
Loading…
Reference in New Issue
Block a user