Update baselines
This commit is contained in:
		| @@ -8,16 +8,12 @@ import os, math, random | ||||
| from collections import OrderedDict | ||||
| import numpy as np | ||||
| import pandas as pd | ||||
| from typing import Text, Union | ||||
| import copy | ||||
| from functools import partial | ||||
| from typing import Optional, Text | ||||
|  | ||||
| from qlib.utils import ( | ||||
|     unpack_archive_with_buffer, | ||||
|     save_multiple_parts_file, | ||||
|     get_or_create_path, | ||||
|     drop_nan_by_y_index, | ||||
| ) | ||||
| from qlib.utils import get_or_create_path | ||||
| from qlib.log import get_module_logger | ||||
|  | ||||
| import torch | ||||
| @@ -308,10 +304,10 @@ class QuantTransformer(Model): | ||||
|             torch.cuda.empty_cache() | ||||
|         self.fitted = True | ||||
|  | ||||
|     def predict(self, dataset, segment="test"): | ||||
|     def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): | ||||
|         if not self.fitted: | ||||
|             raise ValueError("The model is not fitted yet!") | ||||
|         x_test = dataset.prepare(segment, col_set="feature") | ||||
|         x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) | ||||
|         index = x_test.index | ||||
|  | ||||
|         self.model.eval() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user