Simplify Q-T
This commit is contained in:
		| @@ -37,7 +37,9 @@ from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
| default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1) | ||||
|  | ||||
| default_opt_config = dict(epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam") | ||||
| default_opt_config = dict( | ||||
|     epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam", num_workers=4 | ||||
| ) | ||||
|  | ||||
|  | ||||
| class QuantTransformer(Model): | ||||
| @@ -159,6 +161,16 @@ class QuantTransformer(Model): | ||||
|                 torch.from_numpy(df_data["label"].values).squeeze().float(), | ||||
|             ) | ||||
|  | ||||
|         def _prepare_loader(dataset, shuffle): | ||||
|             return th_data.DataLoader( | ||||
|                 dataset, | ||||
|                 batch_size=self.opt_config["batch_size"], | ||||
|                 drop_last=False, | ||||
|                 pin_memory=True, | ||||
|                 num_workers=self.opt_config["num_workers"], | ||||
|                 shuffle=shuffle, | ||||
|             ) | ||||
|  | ||||
|         df_train, df_valid, df_test = dataset.prepare( | ||||
|             ["train", "valid", "test"], | ||||
|             col_set=["feature", "label"], | ||||
| @@ -169,15 +181,10 @@ class QuantTransformer(Model): | ||||
|             _prepare_dataset(df_valid), | ||||
|             _prepare_dataset(df_test), | ||||
|         ) | ||||
|  | ||||
|         train_loader = th_data.DataLoader( | ||||
|             train_dataset, batch_size=self.opt_config["batch_size"], shuffle=True, drop_last=False, pin_memory=True | ||||
|         ) | ||||
|         valid_loader = th_data.DataLoader( | ||||
|             valid_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True | ||||
|         ) | ||||
|         test_loader = th_data.DataLoader( | ||||
|             test_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True | ||||
|         train_loader, valid_loader, test_loader = ( | ||||
|             _prepare_loader(train_dataset, True), | ||||
|             _prepare_loader(valid_dataset, False), | ||||
|             _prepare_loader(test_dataset, False), | ||||
|         ) | ||||
|  | ||||
|         if save_path == None: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user