update analysis code from diffsize branch
@@ -37,6 +37,144 @@ def selectivity_evaluation(gas1, gas2, prop_name):
    y = np.log10(np.array(gas1) / np.array(gas2))
    upper = (y - (a_dict[prop_name] * x + b_dict[prop_name])) > 0
    return upper
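
# Only the tail of selectivity_evaluation appears in this hunk (x is defined in the
# lines above it, outside the diff). The helper below is a hypothetical usage sketch:
# gas1/gas2 are permeability lists for a gas pair, and the 'O2-N2' key plus the numbers
# are illustrative only; a_dict/b_dict come from the original module.
def _selectivity_usage_example():
    o2_perm = [10.0, 250.0, 3.5]
    n2_perm = [2.0, 40.0, 0.9]
    above_bound = selectivity_evaluation(o2_perm, n2_perm, 'O2-N2')  # boolean array, one flag per sample
    print(int(above_bound.sum()), 'samples above the upper-bound line')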


class BasicGraphMetrics(object):
    def __init__(self, graph_decoder, train_graphs=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
        self.dataset_graphs_list = train_graphs
        self.graph_decoder = graph_decoder
        self.n_jobs = n_jobs
        self.device = device
        self.batch_size = batch_size
        self.stat_ref = stat_ref
        self.task_evaluator = task_evaluator

    def compute_relaxed_validity(self, generated, ensure_connected):
        # Under the relaxed criterion every generated graph counts as valid;
        # `ensure_connected` is kept for interface compatibility but is not used here.
        valid = []
        num_components = []
        all_graphs = []
        covered_nodes = set()
        direct_valid_count = 0
        print(f"generated number: {len(generated)}")
        for graph in generated:
            node_types, edge_types = graph
            direct_valid_count += 1
            valid.append(graph)
            num_components.append(1)
            covered_nodes.update(set(node_types))
            all_graphs.append(graph)
        return valid, len(valid) / len(generated), direct_valid_count / len(generated), np.array(num_components), all_graphs, covered_nodes

    def evaluate(self, generated, targets, ensure_connected, active_atoms=None):
        valid, validity, nc_validity, num_components, all_graphs, covered_nodes = self.compute_relaxed_validity(generated, ensure_connected=ensure_connected)
        nc_mu = num_components.mean() if len(num_components) > 0 else 0
        nc_min = num_components.min() if len(num_components) > 0 else 0
        nc_max = num_components.max() if len(num_components) > 0 else 0

        len_active = len(active_atoms) if active_atoms is not None else 1

        cover_str = f"Cover {len(covered_nodes)} ({len(covered_nodes)/len_active * 100:.2f}%) nodes: {covered_nodes}"
        print(f"Validity over {len(generated)} graphs: {validity * 100:.2f}% (w/o correction: {nc_validity * 100:.2f}%), cover {len(covered_nodes)} ({len(covered_nodes)/len_active * 100:.2f}%) nodes: {covered_nodes}")
        print(f"Number of connected components of {len(generated)} graphs: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")

        if validity > 0:
            dist_metrics = {'cover_str': cover_str, 'validity': validity, 'validity_nc': nc_validity}
            unique = valid
            close_pool = False
            if self.n_jobs != 1:
                pool = Pool(self.n_jobs)
                close_pool = True
            else:
                pool = 1  # serial fallback: passed as n_jobs to the metric helpers below
            # valid_graphs = mapper(pool)(get_mol, valid)
            valid_graphs = valid
            # Internal diversity (currently disabled) would be computed as
            #     1/|A|^2 * sum_{x, y in A x A} (1 - tanimoto(x, y))
            # (see the sketch after this class)
            # dist_metrics['interval_diversity'] = internal_diversity(valid_graphs, pool, device=self.device)

            start_time = time.time()
            if self.stat_ref is not None:
                kwargs = {'n_jobs': pool, 'device': self.device, 'batch_size': self.batch_size}
                kwargs_fcd = {'n_jobs': self.n_jobs, 'device': self.device, 'batch_size': self.batch_size}
                try:
                    dist_metrics['sim/Frag'] = FragMetric(**kwargs)(gen=valid_graphs, pref=self.stat_ref['Frag'])
                except Exception as err:
                    print('FragMetric failed:', err, 'pool:', pool)
                    print('valid_graphs: ', valid_graphs)
                dist_metrics['dist/FCD'] = FCDMetric(**kwargs_fcd)(gen=valid, pref=self.stat_ref['FCD'])

            if self.task_evaluator is not None:
                evaluation_list = list(self.task_evaluator.keys())
                print('evaluation_list: ', evaluation_list)

                assert 'meta_taskname' in evaluation_list
                meta_taskname = self.task_evaluator['meta_taskname']
                evaluation_list.remove('meta_taskname')
                # meta_split = meta_taskname.split('-')

                valid_index = np.array([bool(graph) for graph in all_graphs])
                targets_log = {}
                for i, name in enumerate(evaluation_list):
                    targets_log[f'input_{name}'] = targets[:, i]

                targets = targets[valid_index]
                # if len(meta_split) == 2:
                #     cached_perm = {meta_split[0]: None, meta_split[1]: None}

                for i, name in enumerate(evaluation_list):
                    # if name == 'scs':
                    #     continue
                    # elif name == 'sas':
                    #     scores = calculateSAS(valid)
                    # else:
                    # scores = self.task_evaluator[name](valid)
                    # Placeholder scores: random values stand in for the task evaluator output.
                    scores = np.random.rand(len(valid_index))
                    targets_log[f'output_{name}'] = np.array([float('nan')] * len(valid_index))
                    targets_log[f'output_{name}'][valid_index] = scores
                    # if name in ['O2', 'N2', 'CO2']:
                    #     if len(meta_split) == 2:
                    #         cached_perm[name] = scores
                    #     scores, cur_targets = np.log10(scores), np.log10(targets[:, i])
                    #     dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - cur_targets))
                    # elif name == 'sas':
                    #     dist_metrics[f'{name}/mae'] = np.mean(np.abs(scores - targets[:, i]))
                    # else:
                    true_y = targets[:, i]
                    # keep only the rows retained by valid_index so shapes match the filtered targets
                    predicted_labels = (scores[valid_index] >= 0.5).astype(int)
                    acc = (predicted_labels == true_y).sum() / len(true_y)
                    dist_metrics[f'{name}/acc'] = acc

                # if len(meta_split) == 2:
                #     if cached_perm[meta_split[0]] is not None and cached_perm[meta_split[1]] is not None:
                #         task_name = self.task_evaluator['meta_taskname']
                #         upper = selectivity_evaluation(cached_perm[meta_split[0]], cached_perm[meta_split[1]], task_name)
                #         dist_metrics[f'selectivity/{task_name}'] = np.sum(upper)

            end_time = time.time()
            elapsed_time = end_time - start_time
            max_key_length = max(len(key) for key in dist_metrics)
            print(f'Details over {len(valid)} ({len(generated)}) valid (total) graphs, calculating metrics using {elapsed_time:.2f} s:')
            strs = ''
            for i, (key, value) in enumerate(dist_metrics.items()):
                if isinstance(value, (int, float, np.floating, np.integer)):
                    strs = strs + f'{key:>{max_key_length}}:{value:<7.4f}\t'
                if i % 4 == 3:
                    strs = strs + '\n'
            print(strs)

            if close_pool:
                pool.close()
                pool.join()
        else:
            unique = []
            dist_metrics = {}
            targets_log = None
        return unique, dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_graphs, dist_metrics, targets_log

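
# The internal-diversity docstring in evaluate() quotes the formula
#     1/|A|^2 * sum_{x, y in A x A} (1 - tanimoto(x, y)),
# matching the commented-out internal_diversity call. The helper below is a minimal
# sketch of that formula, assuming RDKit molecules and Morgan fingerprints; it is
# not the project's own implementation.
from rdkit import DataStructs
from rdkit.Chem import AllChem


def internal_diversity_sketch(mols, radius=2, n_bits=2048):
    # pairwise Tanimoto similarities over all ordered pairs, including x == y
    fps = [AllChem.GetMorganFingerprintAsBitVect(m, radius, nBits=n_bits) for m in mols]
    sims = np.array([DataStructs.BulkTanimotoSimilarity(fp, fps) for fp in fps])
    # average dissimilarity = 1/|A|^2 * sum of (1 - tanimoto)
    return float((1.0 - sims).mean())
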
class BasicMolecularMetrics(object):
    def __init__(self, atom_decoder, train_smiles=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
@@ -388,6 +526,18 @@ def connect_fragments(mol):
    return combined_mol

#### connect fragments
def compute_graph_metrics(graph_list, targets, train_graphs, stat_ref, dataset_info, task_evaluator, comput_config):
    """ graph_list: list of (node_types, edge_types) graphs """
    node_decoder = dataset_info.node_decoder
    active_nodes = dataset_info.active_nodes
    ensure_connected = dataset_info.ensure_connected
    metrics = BasicGraphMetrics(node_decoder, train_graphs, stat_ref, task_evaluator, **comput_config)
    # evaluate() returns: unique graphs, component stats, all graphs, metric dict, targets log
    unique_graphs, _, all_graphs, all_metrics, targets_log = metrics.evaluate(graph_list, targets, ensure_connected, active_nodes)
    return unique_graphs, all_graphs, all_metrics, targets_log
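
# Hypothetical wiring of compute_graph_metrics. All arguments below are placeholders
# for objects built elsewhere in the project; only the comput_config keys (which mirror
# BasicGraphMetrics' keyword arguments), the stat_ref keys, and the 'meta_taskname'
# requirement come from the code above.
def _compute_graph_metrics_usage_example(sampled_graphs, target_array, train_graphs, stat_ref, dataset_info, task_evaluator):
    comput_config = dict(n_jobs=8, device='cpu', batch_size=512)
    unique_graphs, all_graphs, all_metrics, targets_log = compute_graph_metrics(
        graph_list=sampled_graphs,      # list of (node_types, edge_types) pairs
        targets=target_array,           # shape: (num_samples, num_tasks)
        train_graphs=train_graphs,
        stat_ref=stat_ref,              # e.g. {'Frag': ..., 'FCD': ...}
        dataset_info=dataset_info,      # provides node_decoder / active_nodes / ensure_connected
        task_evaluator=task_evaluator,  # dict containing a 'meta_taskname' entry
        comput_config=comput_config,
    )
    return unique_graphs, all_graphs, all_metrics, targets_log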

def compute_molecular_metrics(molecule_list, targets, train_smiles, stat_ref, dataset_info, task_evaluator, comput_config):
    """ molecule_list: (dict) """

@@ -10,7 +10,41 @@ import numpy as np
import rdkit.Chem
import matplotlib.pyplot as plt

class GraphVisualization:
    def __init__(self, dataset_infos):
        self.dataset_infos = dataset_infos

    def graph_from_graphs(self, node_list, adjacency_matrix):
        """
        Convert one generated graph to a networkx graph.
        node_list: node types of a single graph (n,); entries equal to -1 are skipped
        adjacency_matrix: edge-type matrix of the graph (n x n)
        """
        graph = nx.Graph()

        for i in range(len(node_list)):
            if node_list[i] == -1:
                continue
            graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])

        rows, cols = np.where(adjacency_matrix >= 1)
        edges = zip(rows.tolist(), cols.tolist())
        for edge in edges:
            edge_type = adjacency_matrix[edge[0]][edge[1]]
            graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)

        return graph

    def visualize(self, path: str, graphs: list, num_graphs_to_visualize: int, log='graph'):
        # create the directory that will hold the figures
        if not os.path.exists(path):
            os.makedirs(path)

        # visualize the final graphs
        for i in range(num_graphs_to_visualize):
            file_path = os.path.join(path, 'graph_{}.png'.format(i))
            graph = self.graph_from_graphs(graphs[i][0].numpy(), graphs[i][1].numpy())
            self.visualize_graph(graph=graph, pos=None, path=file_path)
            im = plt.imread(file_path)  # reload the rendered figure (e.g. for logging)
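
    # visualize() above calls self.visualize_graph, which is not part of this hunk.
    # The method below is a minimal, hypothetical sketch of such a helper, assuming
    # a networkx spring layout and matplotlib; it is not the project's implementation.
    def visualize_graph(self, graph, pos=None, path=None):
        if pos is None:
            pos = nx.spring_layout(graph, seed=0)
        node_colors = [data['color_val'] for _, data in graph.nodes(data=True)]
        edge_colors = [data['color'] for _, _, data in graph.edges(data=True)]
        nx.draw(graph, pos, node_color=node_colors, edge_color=edge_colors,
                cmap=plt.cm.viridis, edge_cmap=plt.cm.viridis, with_labels=False)
        plt.savefig(path)
        plt.close()
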
class MolecularVisualization:
    def __init__(self, dataset_infos):
        self.dataset_infos = dataset_infos