138 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			138 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | ### packages for visualization | ||
|  | from analysis.rdkit_functions import compute_molecular_metrics | ||
|  | from mini_moses.metrics.metrics import compute_intermediate_statistics | ||
|  | from metrics.property_metric import TaskModel | ||
|  | 
 | ||
|  | import torch | ||
|  | import torch.nn as nn | ||
|  | 
 | ||
|  | import os | ||
|  | import csv | ||
|  | import time | ||
|  | 
 | ||
def result_to_csv(path, dict_data):
    """Append one metrics row to a CSV file, writing a header first if needed.

    Args:
        path: CSV file path; created on first use, appended to afterwards.
        dict_data: mapping of metric name -> value. MUST contain a
            'log_name' key, which becomes the first column.

    Raises:
        ValueError: if 'log_name' is missing from dict_data.

    Note: dict_data is mutated ('log_name' is popped and re-inserted); the
    key set is assumed to match the existing file's columns across calls.
    """
    log_name = dict_data.pop("log_name", None)
    if log_name is None:
        raise ValueError("The provided dictionary must contain a 'log_name' key.")
    # 'log_name' must lead the column order; re-insert it so DictWriter
    # can resolve it when writing the row.
    field_names = ["log_name"] + list(dict_data.keys())
    dict_data["log_name"] = log_name
    # Bug fix: the original only tested existence, so an existing *empty*
    # file (e.g. left behind by an interrupted run) never received a header.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=field_names)
        if needs_header:
            writer.writeheader()
        writer.writerow(dict_data)
|  | 
 | ||
|  | 
 | ||
class SamplingMolecularMetrics(nn.Module):
    """Evaluates sampled molecules against reference statistics and task models.

    Wraps ``compute_molecular_metrics`` and logs results to ``output.csv``;
    also dumps generated SMILES to text files (``final_smiles.txt`` in test
    mode, per-epoch files under ``graphs/<name>/`` otherwise).
    """

    def __init__(
        self,
        dataset_infos,
        train_smiles,
        reference_smiles,
        n_jobs=1,
        device="cpu",
        batch_size=512,
    ):
        """
        Args:
            dataset_infos: dataset descriptor; must expose ``task``,
                ``active_atoms`` and ``base_path`` (project type).
            train_smiles: training-set SMILES used for novelty checks.
            reference_smiles: SMILES for reference statistics, or None to
                skip the (expensive) precomputation.
            n_jobs / device / batch_size: forwarded to the metric backends.
        """
        super().__init__()
        self.task_name = dataset_infos.task
        self.dataset_infos = dataset_infos
        self.active_atoms = dataset_infos.active_atoms
        self.train_smiles = train_smiles

        if reference_smiles is not None:
            # Precompute reference distribution statistics once; this can
            # take a while, hence the timing printout.
            print(
                f"--- Computing intermediate statistics for training for #{len(reference_smiles)} smiles ---"
            )
            start_time = time.time()
            self.stat_ref = compute_intermediate_statistics(
                reference_smiles, n_jobs=n_jobs, device=device, batch_size=batch_size
            )
            elapsed_time = time.time() - start_time
            print(
                f"--- End computing intermediate statistics: using {elapsed_time:.2f}s ---"
            )
        else:
            self.stat_ref = None

        # NOTE: attribute name 'comput_config' (sic) kept for compatibility
        # with external callers.
        self.comput_config = {
            "n_jobs": n_jobs,
            "device": device,
            "batch_size": batch_size,
        }

        # One TaskModel evaluator per sub-task (task string is '-'-separated).
        # 'meta_taskname' stores the raw task string; 'sas'/'scs' start as
        # placeholders and are only replaced if present in the task split.
        self.task_evaluator = {'meta_taskname': dataset_infos.task, 'sas': None, 'scs': None}
        for cur_task in dataset_infos.task.split("-"):
            model_path = os.path.join(
                dataset_infos.base_path, "data/evaluator", f"{cur_task}.joblib"
            )
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            self.task_evaluator[cur_task] = TaskModel(model_path, cur_task)

    def forward(self, molecules, targets, name, current_epoch, val_counter, test=False):
        """Score sampled molecules, persist SMILES and append metrics to CSV.

        Args:
            molecules: generated molecules (project representation).
            targets: conditioning targets, a tensor or list of tensors.
            name: run name used for the output directory in validation mode.
            current_epoch / val_counter: used to label validation outputs.
            test: if True, write the final SMILES dump instead of the
                per-epoch validation file.

        Returns:
            The list of all generated SMILES.
        """
        if isinstance(targets, list):
            targets_np = torch.cat(targets, dim=0).detach().cpu().numpy()
        else:
            targets_np = targets.detach().cpu().numpy()

        unique_smiles, all_smiles, all_metrics, targets_log = compute_molecular_metrics(
            molecules,
            targets_np,
            self.train_smiles,
            self.stat_ref,
            self.dataset_infos,
            self.task_evaluator,
            self.comput_config,
        )

        if test:
            self._write_final_smiles(all_smiles, targets_log)
        else:
            result_path = os.path.join(os.getcwd(), f"graphs/{name}")
            os.makedirs(result_path, exist_ok=True)
            text_path = os.path.join(
                result_path,
                f"valid_unique_molecules_e{current_epoch}_b{val_counter}.txt",
            )
            # Context manager guarantees the handle is closed even if a
            # write fails (the original left the file open on exceptions).
            with open(text_path, "w") as textfile:
                for smiles in unique_smiles:
                    textfile.write(smiles + "\n")

        all_logs = all_metrics
        if test:
            all_logs["log_name"] = "test"
        else:
            all_logs["log_name"] = (
                "epoch" + str(current_epoch) + "_batch" + str(val_counter)
            )

        result_to_csv("output.csv", all_logs)
        return all_smiles

    def _write_final_smiles(self, all_smiles, targets_log):
        """Write 'final_smiles.txt': one row per molecule with the input and
        output property values for every evaluated task (header excludes the
        'meta_taskname' and 'scs' bookkeeping entries)."""
        with open("final_smiles.txt", "w") as fp:
            all_tasks_name = [
                t for t in self.task_evaluator if t not in ("meta_taskname", "scs")
            ]
            all_tasks_str = "smiles, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name])
            fp.write(all_tasks_str + "\n")
            for i, smiles in enumerate(all_smiles):
                if targets_log is not None:
                    all_result_str = f"{smiles}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name])
                    fp.write(all_result_str + "\n")
                else:
                    fp.write("%s\n" % smiles)
            print("All smiles saved")

    def reset(self):
        """No-op: this metric keeps no accumulating state between calls."""
        pass
|  | 
 | ||
# Importing this module has no side effects; there is no CLI entry point yet.
if __name__ == "__main__":
    pass