77 lines
2.1 KiB
Python
77 lines
2.1 KiB
Python
import os
|
|
import sys
|
|
import qlib
|
|
import pprint
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from pathlib import Path
|
|
import torch
|
|
|
|
__file__ = os.path.dirname(os.path.realpath("__file__"))
|
|
|
|
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
|
|
print("library path: {:}".format(lib_dir))
|
|
assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
|
|
if str(lib_dir) not in sys.path:
|
|
sys.path.insert(0, str(lib_dir))
|
|
|
|
from trade_models import get_transformer
|
|
|
|
from qlib import config as qconfig
|
|
from qlib.utils import init_instance_by_config
|
|
from qlib.model.base import Model
|
|
from qlib.data.dataset import DatasetH
|
|
from qlib.data.dataset.handler import DataHandlerLP
|
|
|
|
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
|
|
|
|
dataset_config = {
|
|
"class": "DatasetH",
|
|
"module_path": "qlib.data.dataset",
|
|
"kwargs": {
|
|
"handler": {
|
|
"class": "Alpha360",
|
|
"module_path": "qlib.contrib.data.handler",
|
|
"kwargs": {
|
|
"start_time": "2008-01-01",
|
|
"end_time": "2020-08-01",
|
|
"fit_start_time": "2008-01-01",
|
|
"fit_end_time": "2014-12-31",
|
|
"instruments": "csi100",
|
|
},
|
|
},
|
|
"segments": {
|
|
"train": ("2008-01-01", "2014-12-31"),
|
|
"valid": ("2015-01-01", "2016-12-31"),
|
|
"test": ("2017-01-01", "2020-08-01"),
|
|
},
|
|
},
|
|
}
|
|
pprint.pprint(dataset_config)
|
|
dataset = init_instance_by_config(dataset_config)
|
|
|
|
df_train, df_valid, df_test = dataset.prepare(
|
|
["train", "valid", "test"],
|
|
col_set=["feature", "label"],
|
|
data_key=DataHandlerLP.DK_L,
|
|
)
|
|
model = get_transformer(None)
|
|
print(model)
|
|
|
|
features = torch.from_numpy(df_train["feature"].values).float()
|
|
labels = torch.from_numpy(df_train["label"].values).squeeze().float()
|
|
|
|
batch = list(range(2000))
|
|
predicts = model(features[batch])
|
|
mask = ~torch.isnan(labels[batch])
|
|
|
|
pred = predicts[mask]
|
|
label = labels[batch][mask]
|
|
|
|
loss = torch.nn.functional.mse_loss(pred, label)
|
|
|
|
from sklearn.metrics import mean_squared_error
|
|
|
|
mse_loss = mean_squared_error(pred.numpy(), label.numpy())
|