update EdgeMetricsCE class

parent d57575586d
commit 0fc6f6e686
@@ -23,7 +23,104 @@ def result_to_csv(path, dict_data):
        writer.writeheader()
        writer.writerow(dict_data)


class SamplingGraphMetrics(nn.Module):
    def __init__(
        self,
        dataset_infos,
        train_graphs,
        reference_graphs,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_nodes = dataset_infos.active_nodes
        self.train_graphs = train_graphs

        self.stat_ref = None

        self.compute_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }

        self.task_evaluator = {
            'meta_taskname': dataset_infos.task,
            'sas': None,
            'scs': None,
        }

        # one TaskModel evaluator per dash-separated sub-task of the meta task
        for cur_task in dataset_infos.task.split("-"):
            model_path = os.path.join(
                dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
            )
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            evaluator = TaskModel(model_path, cur_task)
            self.task_evaluator[cur_task] = evaluator

    def forward(self, graphs, targets, name, current_epoch, val_counter, test=False):
        if isinstance(targets, list):
            targets_cat = torch.cat(targets, dim=0)
            targets_np = targets_cat.detach().cpu().numpy()
        else:
            targets_np = targets.detach().cpu().numpy()

        unique_graphs, all_graphs, all_metrics, targets_log = compute_molecular_metrics(
            graphs,
            targets_np,
            self.train_graphs,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.compute_config,
        )

        if test:
            # at test time, dump every generated graph together with its
            # input/output value for each evaluated task
            file_name = "final_graphs.txt"
            with open(file_name, "w") as fp:
                all_tasks_name = list(self.task_evaluator.keys())
                if 'meta_taskname' in all_tasks_name:
                    all_tasks_name.remove('meta_taskname')

                all_tasks_str = "graph, " + ", ".join(
                    [f"input_{task}" for task in all_tasks_name]
                    + [f"output_{task}" for task in all_tasks_name]
                )
                fp.write(all_tasks_str + "\n")
                for i, graph in enumerate(all_graphs):
                    if targets_log is not None:
                        all_result_str = f"{graph}, " + ", ".join(
                            [f"{targets_log['input_' + task][i]}" for task in all_tasks_name]
                            + [f"{targets_log['output_' + task][i]}" for task in all_tasks_name]
                        )
                        fp.write(all_result_str + "\n")
                    else:
                        fp.write("%s\n" % graph)
                print("All graphs saved")
        else:
            # during validation, only write out the valid & unique graphs
            result_path = os.path.join(os.getcwd(), f"graphs/{name}")
            os.makedirs(result_path, exist_ok=True)
            text_path = os.path.join(
                result_path,
                f"valid_unique_graphs_e{current_epoch}_b{val_counter}.txt",
            )
            with open(text_path, "w") as textfile:
                for graph in unique_graphs:
                    textfile.write(graph + "\n")

        all_logs = all_metrics
        if test:
            all_logs["log_name"] = "test"
        else:
            all_logs["log_name"] = (
                "epoch" + str(current_epoch) + "_batch" + str(val_counter)
            )

        result_to_csv("output.csv", all_logs)
        return all_graphs

    def reset(self):
        pass

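# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the commit): how SamplingGraphMetrics
# would typically be driven from a sampling/validation loop. `dataset_infos`,
# `train_graphs`, `generated_graphs` and `targets` are placeholders for objects
# the surrounding code is assumed to provide already.
#
#   sampling_metrics = SamplingGraphMetrics(dataset_infos, train_graphs,
#                                           reference_graphs=None)
#   sampling_metrics(generated_graphs, targets, name="run_name",
#                    current_epoch=epoch, val_counter=val_idx, test=False)
#
# With test=True every generated graph (plus its input/output task values) is
# written to final_graphs.txt; otherwise only the valid & unique graphs go to
# graphs/<name>/valid_unique_graphs_e<epoch>_b<val_counter>.txt.
# ---------------------------------------------------------------------------
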
class SamplingMolecularMetrics(nn.Module):
    def __init__(
        self,
@@ -40,21 +137,21 @@ class SamplingMolecularMetrics(nn.Module):
        self.active_atoms = dataset_infos.active_atoms
        self.train_smiles = train_smiles

        if reference_smiles is not None:
            print(
                f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
            )
            start_time = time.time()
            self.stat_ref = compute_intermediate_statistics(
                reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
            )
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(
                f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
            )
        else:
            self.stat_ref = None
        # if reference_smiles is not None:
        #     print(
        #         f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
        #     )
        #     start_time = time.time()
        #     self.stat_ref = compute_intermediate_statistics(
        #         reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
        #     )
        #     end_time = time.time()
        #     elapsed_time = end_time - start_time
        #     print(
        #         f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
        #     )
        # else:
        self.stat_ref = None

        self.comput_config = {
            "n_jobs": n_jobs,
@@ -77,6 +77,15 @@ class NodeMetricsCE(MetricCollection):

        for i, node_type in enumerate(active_nodes):
            metrics_list.append(type(f'{node_type}_CE', (NodeCE,), {})(i))
        super().__init__(metrics_list)


class EdgeMetricsCE(MetricCollection):
    def __init__(self):
        ce_no_bond = NoBondCE(0)
        ce_SI = SingleCE(1)
        ce_DO = DoubleCE(2)
        ce_TR = TripleCE(3)
        super().__init__([ce_no_bond, ce_SI, ce_DO, ce_TR])

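# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the commit): EdgeMetricsCE is a
# MetricCollection of per-bond-type cross-entropy metrics
# (no-bond / single / double / triple), so it is updated with the same
# (masked_pred_E, true_E) pair that TrainGraphMetricsDiscrete passes further
# down in this diff, and read out with .compute():
#
#   edge_metrics = EdgeMetricsCE()
#   edge_metrics(masked_pred_E, true_E)    # accumulate per-class CE
#   per_class_ce = edge_metrics.compute()  # dict keyed by metric name
#   edge_metrics.reset()                   # clear state between epochs
# ---------------------------------------------------------------------------
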
class AtomMetricsCE(MetricCollection):
    def __init__(self, active_atoms):
@@ -101,6 +110,41 @@ class BondMetricsCE(MetricCollection):
class TrainGraphMetricsDiscrete(nn.Module):
    def __init__(self, dataset_infos):
        super().__init__()
        active_nodes = dataset_infos.active_nodes
        self.train_node_metrics = NodeMetricsCE(active_nodes=active_nodes)
        self.train_edge_metrics = EdgeMetricsCE()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        self.train_node_metrics(masked_pred_X, true_X)
        self.train_edge_metrics(masked_pred_E, true_E)
        if log:
            to_log = {}
            for key, val in self.train_node_metrics.compute().items():
                to_log['train/' + key] = val.item()
            for key, val in self.train_edge_metrics.compute().items():
                to_log['train/' + key] = val.item()

    def reset(self):
        for metric in [self.train_node_metrics, self.train_edge_metrics]:
            metric.reset()

    def log_epoch_metrics(self, current_epoch, log=True):
        epoch_node_metrics = self.train_node_metrics.compute()
        epoch_edge_metrics = self.train_edge_metrics.compute()

        to_log = {}
        for key, val in epoch_node_metrics.items():
            to_log['train_epoch/' + key] = val.item()
        for key, val in epoch_edge_metrics.items():
            to_log['train_epoch/' + key] = val.item()

        for key, val in epoch_node_metrics.items():
            epoch_node_metrics[key] = round(val.item(), 4)
        for key, val in epoch_edge_metrics.items():
            epoch_edge_metrics[key] = round(val.item(), 4)

        if log:
            print(f"Epoch {current_epoch}: {epoch_node_metrics} -- {epoch_edge_metrics}")

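# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the commit): TrainGraphMetricsDiscrete
# is meant to be called once per training step and summarised once per epoch;
# masked_pred_X/E and true_X/E are whatever masked prediction / one-hot ground
# truth tensors the training loop already produces.
#
#   train_metrics = TrainGraphMetricsDiscrete(dataset_infos)
#   train_metrics(masked_pred_X, masked_pred_E, true_X, true_E, log=False)
#   ...
#   train_metrics.log_epoch_metrics(current_epoch)   # print rounded epoch CEs
#   train_metrics.reset()                            # start the next epoch clean
# ---------------------------------------------------------------------------
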
class TrainMolecularMetricsDiscrete(nn.Module):
    def __init__(self, dataset_infos):