102 lines
2.9 KiB
Python
102 lines
2.9 KiB
Python
|
import numpy as np
|
||
|
from typing import List, Text
|
||
|
from collections import defaultdict, OrderedDict
|
||
|
|
||
|
|
||
|
class QResult:
    """A class to maintain the results of a qlib experiment.

    Metric values are accumulated per metric name (one list per name), so
    repeated runs can be summarised by their mean and standard deviation.
    Paths registered via ``append_path`` are stored alongside the metrics.
    """

    def __init__(self, name):
        # metric name -> list of values appended so far, in append order
        self._result = defaultdict(list)
        self._name = name
        self._recorder_paths = []

    def append(self, key, value):
        """Append one ``value`` to the list stored under ``key``."""
        self._result[key].append(value)

    def append_path(self, xpath):
        """Record ``xpath`` in the list exposed by the ``paths`` property."""
        self._recorder_paths.append(xpath)

    @property
    def name(self):
        """The name given to this collection at construction time."""
        return self._name

    @property
    def paths(self):
        """The internal list of recorded paths (returned directly, not a copy)."""
        return self._recorder_paths

    @property
    def result(self):
        """The underlying ``defaultdict(list)`` mapping metric name -> values."""
        return self._result

    @property
    def keys(self):
        """A new list holding the metric names recorded so far."""
        return list(self._result.keys())

    def __len__(self):
        """Return the number of distinct metric names recorded."""
        return len(self._result)

    def __repr__(self):
        return "{name}({xname}, {num} metrics)".format(
            name=self.__class__.__name__, xname=self.name, num=len(self.result)
        )

    def __getitem__(self, key):
        """Return the mean of all values recorded under ``key`` as a float.

        Raises:
            ValueError: if ``key`` was never recorded. Membership is checked
                explicitly so a failed lookup does not insert an empty list
                into the underlying ``defaultdict``.
        """
        if key not in self._result:
            raise ValueError(
                "Invalid key {:}, please use one of {:}".format(key, self.keys)
            )
        return float(np.mean(self._result[key]))

    def update(self, metrics, filter_keys=None):
        """Append every ``(name, value)`` pair from the mapping ``metrics``.

        Args:
            metrics: mapping of metric name -> single value to append.
            filter_keys: optional mapping used both as an allow-list and as a
                rename table. When given, only names present in it are kept,
                and each is stored under ``filter_keys[name]``. Other names
                are skipped silently.
        """
        for key, value in metrics.items():
            if filter_keys is not None:
                if key not in filter_keys:
                    continue
                key = filter_keys[key]
            self.append(key, value)

    @staticmethod
    def full_str(xstr, space):
        """Return ``str(xstr)`` left-aligned and padded to at least ``space`` chars."""
        xformat = "{:" + str(space) + "s}"
        return xformat.format(str(xstr))

    @staticmethod
    def merge_dict(dict_list):
        """Concatenate the per-key value lists of several result dicts.

        The keys are taken from the first dict, and every other dict must
        contain them too (a missing key raises ``KeyError``). An empty
        ``dict_list`` yields an empty dict instead of raising ``IndexError``.
        """
        if not dict_list:
            return dict()
        return {
            xkey: [x for xdict in dict_list for x in xdict[xkey]]
            for xkey in dict_list[0].keys()
        }

    def info(
        self,
        keys: List[Text],
        separate: Text = "& ",
        space: int = 20,
        verbose: bool = True,
        precision: int = 2,
    ):
        """Format the mean and std of the selected metrics as two aligned rows.

        Args:
            keys: metric names to report. Names never recorded are announced
                with a printed message and skipped.
            separate: separator placed between columns.
            space: minimum width each column is padded to.
            verbose: if True, also print both rows.
            precision: number of decimals for mean and std (default 2).

        Returns:
            A ``(head_str, value_str)`` tuple: the row of metric names and the
            row of ``"<mean> $\\pm$ <std>"`` cells, both joined by ``separate``.
        """
        available_keys = []
        for key in keys:
            if key not in self.result:
                print("There are invalid key [{:}].".format(key))
            else:
                available_keys.append(key)
        head_str = separate.join([self.full_str(x, space) for x in available_keys])

        # Raw string: "\p" is not a valid escape sequence, so a plain literal
        # triggers a DeprecationWarning/SyntaxWarning at compile time.
        cell_format = r"{:.{p}f} $\pm$ {:.{p}f}"
        values = []
        for key in available_keys:
            current_values = self._result[key]
            if "IR" in key:
                # NOTE(review): names containing "IR" are scaled by 100 before
                # summarising — presumably to show them in percent; confirm.
                current_values = [x * 100 for x in current_values]
            mean = np.mean(current_values)
            std = np.std(current_values)
            values.append(cell_format.format(mean, std, p=precision))
        value_str = separate.join([self.full_str(x, space) for x in values])
        if verbose:
            print(head_str)
            print(value_str)
        return head_str, value_str
|