Fix bugs
This commit is contained in:
		| @@ -24,38 +24,38 @@ from qlib.model.base import Model | ||||
| from qlib.data.dataset import DatasetH | ||||
| from qlib.data.dataset.handler import DataHandlerLP | ||||
|  | ||||
| qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN) | ||||
| qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN) | ||||
|  | ||||
| dataset_config = { | ||||
|             "class": "DatasetH", | ||||
|             "module_path": "qlib.data.dataset", | ||||
|     "class": "DatasetH", | ||||
|     "module_path": "qlib.data.dataset", | ||||
|     "kwargs": { | ||||
|         "handler": { | ||||
|             "class": "Alpha360", | ||||
|             "module_path": "qlib.contrib.data.handler", | ||||
|             "kwargs": { | ||||
|                 "handler": { | ||||
|                     "class": "Alpha360", | ||||
|                     "module_path": "qlib.contrib.data.handler", | ||||
|                     "kwargs": { | ||||
|                         "start_time": "2008-01-01", | ||||
|                         "end_time": "2020-08-01", | ||||
|                         "fit_start_time": "2008-01-01", | ||||
|                         "fit_end_time": "2014-12-31", | ||||
|                         "instruments": "csi100", | ||||
|                     }, | ||||
|                 }, | ||||
|                 "segments": { | ||||
|                     "train": ("2008-01-01", "2014-12-31"), | ||||
|                     "valid": ("2015-01-01", "2016-12-31"), | ||||
|                     "test": ("2017-01-01", "2020-08-01"), | ||||
|                 }, | ||||
|                 "start_time": "2008-01-01", | ||||
|                 "end_time": "2020-08-01", | ||||
|                 "fit_start_time": "2008-01-01", | ||||
|                 "fit_end_time": "2014-12-31", | ||||
|                 "instruments": "csi100", | ||||
|             }, | ||||
|         } | ||||
|         }, | ||||
|         "segments": { | ||||
|             "train": ("2008-01-01", "2014-12-31"), | ||||
|             "valid": ("2015-01-01", "2016-12-31"), | ||||
|             "test": ("2017-01-01", "2020-08-01"), | ||||
|         }, | ||||
|     }, | ||||
| } | ||||
| pprint.pprint(dataset_config) | ||||
| dataset = init_instance_by_config(dataset_config) | ||||
|  | ||||
| df_train, df_valid, df_test = dataset.prepare( | ||||
|             ["train", "valid", "test"], | ||||
|             col_set=["feature", "label"], | ||||
|             data_key=DataHandlerLP.DK_L, | ||||
|         ) | ||||
|     ["train", "valid", "test"], | ||||
|     col_set=["feature", "label"], | ||||
|     data_key=DataHandlerLP.DK_L, | ||||
| ) | ||||
| model = get_transformer(None) | ||||
| print(model) | ||||
|  | ||||
| @@ -72,4 +72,5 @@ label = labels[batch][mask] | ||||
| loss = torch.nn.functional.mse_loss(pred, label) | ||||
|  | ||||
| from sklearn.metrics import mean_squared_error | ||||
|  | ||||
| mse_loss = mean_squared_error(pred.numpy(), label.numpy()) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user