rewrite to graph metrics

This commit is contained in:
mhz 2024-06-27 20:44:04 +02:00
parent a7f7010da7
commit 222470a43c

View File

@ -35,7 +35,13 @@ class CEPerClass(Metric):
def compute(self):
return self.total_ce / self.total_samples
class NodeCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class EdgeCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class AtomCE(CEPerClass):
def __init__(self, i):
@ -65,6 +71,12 @@ class AromaticCE(CEPerClass):
def __init__(self, i):
super().__init__(i)
class NodeMetricsCE(MetricCollection):
def __init__(self, active_nodes):
metrics_list = []
for i, node_type in enumerate(active_nodes) :
metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i))
class AtomMetricsCE(MetricCollection):
def __init__(self, active_atoms):
@ -84,7 +96,12 @@ class BondMetricsCE(MetricCollection):
ce_TR = TripleCE(3)
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
#
#
class TrainGraphMetricsDiscrete(nn.Module):
def __init__(self, dataset_infos):
super().__init__()
class TrainMolecularMetricsDiscrete(nn.Module):
def __init__(self, dataset_infos):
super().__init__()