Simplify Q-T

This commit is contained in:
D-X-Y 2021-03-06 21:44:59 -08:00
parent c0481a2357
commit d17f73394f

View File

@ -37,7 +37,9 @@ from qlib.data.dataset.handler import DataHandlerLP
default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1) default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1)
default_opt_config = dict(epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam") default_opt_config = dict(
epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam", num_workers=4
)
class QuantTransformer(Model): class QuantTransformer(Model):
@ -159,6 +161,16 @@ class QuantTransformer(Model):
torch.from_numpy(df_data["label"].values).squeeze().float(), torch.from_numpy(df_data["label"].values).squeeze().float(),
) )
def _prepare_loader(dataset, shuffle):
return th_data.DataLoader(
dataset,
batch_size=self.opt_config["batch_size"],
drop_last=False,
pin_memory=True,
num_workers=self.opt_config["num_workers"],
shuffle=shuffle,
)
df_train, df_valid, df_test = dataset.prepare( df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"], ["train", "valid", "test"],
col_set=["feature", "label"], col_set=["feature", "label"],
@ -169,15 +181,10 @@ class QuantTransformer(Model):
_prepare_dataset(df_valid), _prepare_dataset(df_valid),
_prepare_dataset(df_test), _prepare_dataset(df_test),
) )
train_loader, valid_loader, test_loader = (
train_loader = th_data.DataLoader( _prepare_loader(train_dataset, True),
train_dataset, batch_size=self.opt_config["batch_size"], shuffle=True, drop_last=False, pin_memory=True _prepare_loader(valid_dataset, False),
) _prepare_loader(test_dataset, False),
valid_loader = th_data.DataLoader(
valid_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
)
test_loader = th_data.DataLoader(
test_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
) )
if save_path == None: if save_path == None: