update EdgeMetricsCE class

mhz 2024-06-30 17:37:18 +02:00
parent d57575586d
commit 0fc6f6e686
2 changed files with 156 additions and 15 deletions

View File

@@ -23,7 +23,104 @@ def result_to_csv(path, dict_data):
            writer.writeheader()
        writer.writerow(dict_data)
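
Only the tail of result_to_csv appears in this hunk. A minimal sketch consistent with the two lines above, and with the way the class below calls it once per evaluation, could look like the following; the append-mode file handling is an assumption, not code from this commit:

# Sketch, not part of this commit: append one row per call, writing the
# header only when the CSV file does not exist yet.
import csv
import os

def result_to_csv(path, dict_data):
    file_exists = os.path.isfile(path)
    with open(path, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=list(dict_data.keys()))
        if not file_exists:
            writer.writeheader()
        writer.writerow(dict_data)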

class SamplingGraphMetrics(nn.Module):
    def __init__(
        self,
        dataset_infos,
        train_graphs,
        reference_graphs,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_nodes = dataset_infos.active_nodes
        self.train_graphs = train_graphs
        self.stat_ref = None
        self.compute_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }
        self.task_evaluator = {
            "meta_taskname": dataset_infos.task,
            "sas": None,
            "scs": None,
        }
        for cur_task in dataset_infos.task.split("-"):
            model_path = os.path.join(
                dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
            )
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            evaluator = TaskModel(model_path, cur_task)
            self.task_evaluator[cur_task] = evaluator

    def forward(self, graphs, targets, name, current_epoch, val_counter, test=False):
        if isinstance(targets, list):
            targets_cat = torch.cat(targets, dim=0)
            targets_np = targets_cat.detach().cpu().numpy()
        else:
            targets_np = targets.detach().cpu().numpy()

        # compute_molecular_metrics returns the unique graphs, all generated
        # graphs, the aggregate metrics dict, and the per-graph target log.
        unique_graphs, all_graphs, all_metrics, targets_log = compute_molecular_metrics(
            graphs,
            targets_np,
            self.train_graphs,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.compute_config,
        )

        if test:
            file_name = "final_graphs.txt"
            with open(file_name, "w") as fp:
                all_tasks_name = list(self.task_evaluator.keys())
                if "meta_taskname" in all_tasks_name:
                    all_tasks_name.remove("meta_taskname")
                all_tasks_str = "graph, " + ", ".join(
                    [f"input_{task}" for task in all_tasks_name]
                    + [f"output_{task}" for task in all_tasks_name]
                )
                fp.write(all_tasks_str + "\n")
                for i, graph in enumerate(all_graphs):
                    if targets_log is not None:
                        all_result_str = f"{graph}, " + ", ".join(
                            [f"{targets_log['input_' + task][i]}" for task in all_tasks_name]
                            + [f"{targets_log['output_' + task][i]}" for task in all_tasks_name]
                        )
                        fp.write(all_result_str + "\n")
                    else:
                        fp.write("%s\n" % graph)
                print("All graphs saved")
        else:
            result_path = os.path.join(os.getcwd(), f"graphs/{name}")
            os.makedirs(result_path, exist_ok=True)
            text_path = os.path.join(
                result_path,
                f"valid_unique_graphs_e{current_epoch}_b{val_counter}.txt",
            )
            with open(text_path, "w") as textfile:
                for graph in unique_graphs:
                    textfile.write(graph + "\n")

        # tag the metrics dict with the log name and append it to output.csv
        all_logs = all_metrics
        if test:
            all_logs["log_name"] = "test"
        else:
            all_logs["log_name"] = (
                "epoch" + str(current_epoch) + "_batch" + str(val_counter)
            )
        result_to_csv("output.csv", all_logs)
        return all_metrics

    def reset(self):
        pass
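
For orientation, a hedged sketch of how SamplingGraphMetrics is meant to be called during validation follows; the surrounding variables (sampled_graphs, target_properties, epoch) are placeholders, not names from this repository:

# Illustrative sketch, not part of this commit.
sampling_metrics = SamplingGraphMetrics(
    dataset_infos=dataset_infos,
    train_graphs=train_graphs,
    reference_graphs=None,
    n_jobs=4,
)
metrics = sampling_metrics(
    sampled_graphs,          # graphs generated by the model
    target_properties,       # conditioning targets, tensor or list of tensors
    name="debug_run",
    current_epoch=epoch,
    val_counter=0,
    test=False,
)
# `metrics` is the aggregate dict that is also appended to output.csv
# with log_name "epoch{epoch}_batch0".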

class SamplingMolecularMetrics(nn.Module):
    def __init__(
        self,
@@ -40,21 +137,21 @@ class SamplingMolecularMetrics(nn.Module):
        self.active_atoms = dataset_infos.active_atoms
        self.train_smiles = train_smiles
        if reference_smiles is not None:
            print(
                f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
            )
            start_time = time.time()
            self.stat_ref = compute_intermediate_statistics(
                reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
            )
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(
                f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
            )
        else:
            self.stat_ref = None
        # if reference_smiles is not None:
        #     print(
        #         f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
        #     )
        #     start_time = time.time()
        #     self.stat_ref = compute_intermediate_statistics(
        #         reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
        #     )
        #     end_time = time.time()
        #     elapsed_time = end_time - start_time
        #     print(
        #         f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
        #     )
        # else:
        self.stat_ref = None
        self.comput_config = {
            "n_jobs": n_jobs,

View File

@@ -77,6 +77,15 @@ class NodeMetricsCE(MetricCollection):
        for i, node_type in enumerate(active_nodes):
            metrics_list.append(type(f"{node_type}_CE", (NodeCE,), {})(i))
        super().__init__(metrics_list)

class EdgeMetricsCE(MetricCollection):
    def __init__(self):
        ce_no_bond = NoBondCE(0)
        ce_SI = SingleCE(1)
        ce_DO = DoubleCE(2)
        ce_TR = TripleCE(3)
        super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])
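
Since this commit's headline change is EdgeMetricsCE, a rough usage sketch may help. It assumes NoBondCE, SingleCE, DoubleCE and TripleCE are per-class cross-entropy metrics that take the full edge prediction/target tensors and select their own class index (as the integer constructor arguments suggest); the tensor names are placeholders:

# Illustrative sketch, not part of this commit.
edge_metrics = EdgeMetricsCE()
edge_metrics(masked_pred_E, true_E)   # accumulate cross-entropy per bond type
per_bond_ce = edge_metrics.compute()  # dict keyed by metric name, e.g. "NoBondCE"
edge_metrics.reset()                  # clear accumulated state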

class AtomMetricsCE(MetricCollection):
    def __init__(self, active_atoms):
@@ -101,6 +110,41 @@ class BondMetricsCE(MetricCollection):
class TrainGraphMetricsDiscrete(nn.Module):
    def __init__(self, dataset_infos):
        super().__init__()
        active_nodes = dataset_infos.active_nodes
        self.train_node_metrics = NodeMetricsCE(active_nodes=active_nodes)
        self.train_edge_metrics = EdgeMetricsCE()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        self.train_node_metrics(masked_pred_X, true_X)
        self.train_edge_metrics(masked_pred_E, true_E)
        if log:
            to_log = {}
            for key, val in self.train_node_metrics.compute().items():
                to_log["train/" + key] = val.item()
            for key, val in self.train_edge_metrics.compute().items():
                to_log["train/" + key] = val.item()
            # hand the batch-level metrics back so the caller can log them
            return to_log

    def reset(self):
        for metric in [self.train_node_metrics, self.train_edge_metrics]:
            metric.reset()

    def log_epoch_metrics(self, current_epoch, log=True):
        epoch_node_metrics = self.train_node_metrics.compute()
        epoch_edge_metrics = self.train_edge_metrics.compute()

        to_log = {}
        for key, val in epoch_node_metrics.items():
            to_log["train_epoch/" + key] = val.item()
        for key, val in epoch_edge_metrics.items():
            to_log["train_epoch/" + key] = val.item()

        # round the per-type values for the console summary
        for key, val in epoch_node_metrics.items():
            epoch_node_metrics[key] = round(val.item(), 4)
        for key, val in epoch_edge_metrics.items():
            epoch_edge_metrics[key] = round(val.item(), 4)

        if log:
            print(f"Epoch {current_epoch}: {epoch_node_metrics} -- {epoch_edge_metrics}")

        return to_log

class TrainMolecularMetricsDiscrete(nn.Module):
    def __init__(self, dataset_infos):