rewrite to graph metrics
This commit is contained in:
parent
a7f7010da7
commit
222470a43c
@ -35,7 +35,13 @@ class CEPerClass(Metric):
|
||||
|
||||
def compute(self):
|
||||
return self.total_ce / self.total_samples
|
||||
class NodeCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
class EdgeCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
class AtomCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
@ -65,6 +71,12 @@ class AromaticCE(CEPerClass):
|
||||
def __init__(self, i):
|
||||
super().__init__(i)
|
||||
|
||||
class NodeMetricsCE(MetricCollection):
|
||||
def __init__(self, active_nodes):
|
||||
metrics_list = []
|
||||
|
||||
for i, node_type in enumerate(active_nodes) :
|
||||
metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i))
|
||||
|
||||
class AtomMetricsCE(MetricCollection):
|
||||
def __init__(self, active_atoms):
|
||||
@ -84,7 +96,12 @@ class BondMetricsCE(MetricCollection):
|
||||
ce_TR = TripleCE(3)
|
||||
super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
|
||||
|
||||
#
|
||||
#
|
||||
|
||||
class TrainGraphMetricsDiscrete(nn.Module):
|
||||
def __init__(self, dataset_infos):
|
||||
super().__init__()
|
||||
|
||||
class TrainMolecularMetricsDiscrete(nn.Module):
|
||||
def __init__(self, dataset_infos):
|
||||
super().__init__()
|
||||
|
Loading…
Reference in New Issue
Block a user