update print and output json statements
This commit is contained in:
		| @@ -1,5 +1,6 @@ | |||||||
| ### packages for visualization | ### packages for visualization | ||||||
| from analysis.rdkit_functions import compute_molecular_metrics | from analysis.rdkit_functions import compute_molecular_metrics | ||||||
|  | from analysis.rdkit_functions import compute_graph_metrics | ||||||
| from mini_moses.metrics.metrics import compute_intermediate_statistics | from mini_moses.metrics.metrics import compute_intermediate_statistics | ||||||
| from metrics.property_metric import TaskModel | from metrics.property_metric import TaskModel | ||||||
|  |  | ||||||
| @@ -49,8 +50,8 @@ class SamplingGraphMetrics(nn.Module): | |||||||
|  |  | ||||||
|         self.task_evaluator = { |         self.task_evaluator = { | ||||||
|             'meta_taskname': dataset_infos.task, |             'meta_taskname': dataset_infos.task, | ||||||
|             'sas': None, |             # 'sas': None, | ||||||
|             'scs': None |             # 'scs': None | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         for cur_task in dataset_infos.task.split("-")[:]: |         for cur_task in dataset_infos.task.split("-")[:]: | ||||||
| @@ -62,13 +63,14 @@ class SamplingGraphMetrics(nn.Module): | |||||||
|             self.task_evaluator[cur_task] = evaluator |             self.task_evaluator[cur_task] = evaluator | ||||||
|  |  | ||||||
|     def forward(self, graphs, targets, name, current_epoch, val_counter, test=False): |     def forward(self, graphs, targets, name, current_epoch, val_counter, test=False): | ||||||
|  |         test = True | ||||||
|         if isinstance(targets, list): |         if isinstance(targets, list): | ||||||
|             targets_cat = torch.cat(targets, dim=0) |             targets_cat = torch.cat(targets, dim=0) | ||||||
|             targets_np = targets_cat.detach().cpu().numpy() |             targets_np = targets_cat.detach().cpu().numpy() | ||||||
|         else: |         else: | ||||||
|             targets_np = targets.detach().cpu().numpy() |             targets_np = targets.detach().cpu().numpy() | ||||||
|  |  | ||||||
|         unique_graphs, all_graphs, all_graphs, targets_log = compute_molecular_metrics( |         unique_graphs, all_graphs, all_metrics, targets_log = compute_graph_metrics( | ||||||
|             graphs, |             graphs, | ||||||
|             targets_np, |             targets_np, | ||||||
|             self.train_graphs, |             self.train_graphs, | ||||||
| @@ -77,6 +79,22 @@ class SamplingGraphMetrics(nn.Module): | |||||||
|             self.task_evaluator, |             self.task_evaluator, | ||||||
|             self.compute_config, |             self.compute_config, | ||||||
|         ) |         ) | ||||||
|  |         print(f"all graphs: {all_graphs}") | ||||||
|  |         print(f"all graphs[0]: {all_graphs[0]}") | ||||||
|  |         tmp_graphs = all_graphs.copy() | ||||||
|  |         str_graphs = [] | ||||||
|  |         for graph in tmp_graphs: | ||||||
|  |             node_types = graph[0] | ||||||
|  |             edge_types = graph[1] | ||||||
|  |             node_str = " ".join([str(node) for node in node_types]) | ||||||
|  |             edge_str_list = [] | ||||||
|  |             for i in range(len(node_types)): | ||||||
|  |                 for j in range(len(node_types)): | ||||||
|  |                     edge_str_list.append(str(edge_types[i][j])) | ||||||
|  |                 edge_str_list.append("/n") | ||||||
|  |             edge_str = " ".join(edge_str_list) | ||||||
|  |             str_graphs.append(f"nodes: {node_str} /n edges: /n{edge_str}") | ||||||
|  |  | ||||||
|  |  | ||||||
|         if test: |         if test: | ||||||
|             file_name = "final_graphs.txt" |             file_name = "final_graphs.txt" | ||||||
| @@ -88,7 +106,7 @@ class SamplingGraphMetrics(nn.Module): | |||||||
|  |  | ||||||
|                 all_tasks_str = "graph, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name]) |                 all_tasks_str = "graph, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name]) | ||||||
|                 fp.write(all_tasks_str + "\n") |                 fp.write(all_tasks_str + "\n") | ||||||
|                 for i, graph in enumerate(all_graphs): |                 for i, graph in enumerate(str_graphs): | ||||||
|                     if targets_log is not None: |                     if targets_log is not None: | ||||||
|                         all_result_str = f"{graph}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name]) |                         all_result_str = f"{graph}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name]) | ||||||
|                         fp.write(all_result_str + "\n") |                         fp.write(all_result_str + "\n") | ||||||
| @@ -107,7 +125,7 @@ class SamplingGraphMetrics(nn.Module): | |||||||
|                 textfile.write(graph + "\n") |                 textfile.write(graph + "\n") | ||||||
|             textfile.close() |             textfile.close() | ||||||
|          |          | ||||||
|         all_logs = all_graphs |         all_logs = all_metrics | ||||||
|         if test: |         if test: | ||||||
|             all_logs["log_name"] = "test" |             all_logs["log_name"] = "test" | ||||||
|         else: |         else: | ||||||
| @@ -116,7 +134,7 @@ class SamplingGraphMetrics(nn.Module): | |||||||
|             ) |             ) | ||||||
|          |          | ||||||
|         result_to_csv("output.csv", all_logs) |         result_to_csv("output.csv", all_logs) | ||||||
|         return all_graphs |         return str_graphs | ||||||
|      |      | ||||||
|     def reset(self): |     def reset(self): | ||||||
|         pass |         pass | ||||||
|   | |||||||
| @@ -102,6 +102,7 @@ class TaskModel(): | |||||||
|         mask = ~np.isnan(labels) |         mask = ~np.isnan(labels) | ||||||
|         labels = labels[mask] |         labels = labels[mask] | ||||||
|         features = features[mask] |         features = features[mask] | ||||||
|  |         # features = str(features) | ||||||
|         self.model.fit(features, labels) |         self.model.fit(features, labels) | ||||||
|         y_pred = self.model.predict(features) |         y_pred = self.model.predict(features) | ||||||
|         perf = self.metric_func(labels, y_pred) |         perf = self.metric_func(labels, y_pred) | ||||||
| @@ -136,7 +137,7 @@ class TaskModel(): | |||||||
|         print(f'{self.task_name} performance: {perf}') |         print(f'{self.task_name} performance: {perf}') | ||||||
|         return perf |         return perf | ||||||
|  |  | ||||||
|     def __call__(self, smiles_list): |     def __call(self, smiles_list): | ||||||
|         fps = [] |         fps = [] | ||||||
|         mask = [] |         mask = [] | ||||||
|         for i,smiles in enumerate(smiles_list): |         for i,smiles in enumerate(smiles_list): | ||||||
| @@ -153,6 +154,54 @@ class TaskModel(): | |||||||
|         scores = scores * np.array(mask) |         scores = scores * np.array(mask) | ||||||
|         return np.float32(scores) |         return np.float32(scores) | ||||||
|  |  | ||||||
|  |     def __call__(self, graph_list): | ||||||
|  |         # def read_adj_ops_from_json(filename): | ||||||
|  |         #     with open(filename, 'r') as json_file: | ||||||
|  |         #         data = json.load(json_file) | ||||||
|  |  | ||||||
|  |         #     adj_ops_pairs = [] | ||||||
|  |         #     for item in data: | ||||||
|  |         #         adj_matrix = np.array(item['adj_matrix']) | ||||||
|  |         #         ops = item['ops'] | ||||||
|  |         #         acc = item['train'][0]['accuracy'] | ||||||
|  |         #         adj_ops_pairs.append((adj_matrix, ops, acc)) | ||||||
|  |              | ||||||
|  |         #     return adj_ops_pairs | ||||||
|  |         def feature_from_adj_and_ops(ops, adj): | ||||||
|  |             return np.concatenate([adj.flatten(), ops]) | ||||||
|  |         # filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json' | ||||||
|  |         # graphs = read_adj_ops_from_json(filename) | ||||||
|  |         # adjs = [] | ||||||
|  |         # opss = [] | ||||||
|  |         # accs = [] | ||||||
|  |         # features = [] | ||||||
|  |         # for graph in graphs: | ||||||
|  |         #     adj, ops, acc=graph | ||||||
|  |         #     op_code = [op_type[op] for op in ops] | ||||||
|  |         #     adjs.append(adj) | ||||||
|  |         #     opss.append(op_code) | ||||||
|  |         #     accs.append(acc) | ||||||
|  |         features = [] | ||||||
|  |         print(f"graphlist: {graph_list[0]}") | ||||||
|  |         print(f"len graphlist: {len(graph_list)}")  | ||||||
|  |         for op_code, adj in graph_list: | ||||||
|  |             features.append(feature_from_adj_and_ops(op_code, adj)) | ||||||
|  |         print(f"len features: {len(features)}") | ||||||
|  |         # print(f"features: {features[0].shape}") | ||||||
|  |         features = np.stack(features) | ||||||
|  |         features = features.astype(np.float32) | ||||||
|  |         print(f"features shape: {features.shape}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         fps = features | ||||||
|  |         if 'classification' in self.task_type: | ||||||
|  |             scores = self.model.predict_proba(fps)[:, 1] | ||||||
|  |         else: | ||||||
|  |             scores = self.model.predict(fps) | ||||||
|  |         # scores = scores * np.array(mask) | ||||||
|  |         return np.float32(scores) | ||||||
|  |  | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def fingerprints_from_mol(cls, mol):  # use ECFP4 |     def fingerprints_from_mol(cls, mol):  # use ECFP4 | ||||||
|         features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) |         features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user