Compare commits

2 commits: 0c3cfb195a...d44900c8ba

| SHA1 |
|---|
| d44900c8ba |
| 73324083ce |
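At a glance, the diffs below appear to make four kinds of changes: the experiment config gains a `gpu_number` setting and dials down `final_model_samples_to_generate` (10000 → 100) and `n_epochs` (10000 → 5000); the `test()` entry point pins the Lightning `Trainer` to that GPU via `devices=[cfg.general.gpu_number]`; `SamplingGraphMetrics` switches from `compute_molecular_metrics` to `compute_graph_metrics` and serializes sampled graphs as node/edge-type strings before logging them; and `TaskModel` gets a new `__call__` that scores (ops, adjacency) graph pairs rather than SMILES strings.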
```diff
@@ -2,6 +2,7 @@ general:
     name: 'graph_dit'
     wandb: 'disabled'
     gpus: 1
+    gpu_number: 3
     resume: null
     test_only: null
     sample_every_val: 2500
@@ -10,7 +11,7 @@ general:
     chains_to_save: 1
     log_every_steps: 50
     number_chain_steps: 8
-    final_model_samples_to_generate: 10000
+    final_model_samples_to_generate: 100
     final_model_samples_to_save: 20
     final_model_chains_to_save: 1
     enable_progress_bar: False
@@ -30,7 +31,7 @@ model:
     lambda_train: [1, 10]  # node and edge training weight
     ensure_connected: True
 train:
-    n_epochs: 10000
+    n_epochs: 5000
     batch_size: 1200
     lr: 0.0002
     clip_grad: null
```
```diff
@@ -175,6 +175,7 @@ def test(cfg: DictConfig):
     elif cfg.general.resume is not None:
         cfg, _ = get_resume_adaptive(cfg, model_kwargs)
         os.chdir(cfg.general.resume.split("checkpoints")[0])
+    # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
     model = Graph_DiT(cfg=cfg, **model_kwargs)
     trainer = Trainer(
         gradient_clip_val=cfg.train.clip_grad,
@@ -182,7 +183,7 @@ def test(cfg: DictConfig):
         accelerator="gpu"
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else "cpu",
-        devices=cfg.general.gpus
+        devices=[cfg.general.gpu_number]
         if torch.cuda.is_available() and cfg.general.gpus > 0
         else None,
         max_epochs=cfg.train.n_epochs,
```
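The `devices` change above switches from requesting a count of GPUs to selecting one GPU by index. A minimal standalone sketch of the same pattern (assuming PyTorch Lightning and a `gpu_number` of 3, mirroring the config diff above):

```python
import torch
from pytorch_lightning import Trainer

gpu_number = 3  # assumed value, taken from general.gpu_number in the config diff above

# Passing a list of indices pins training to that specific GPU,
# whereas an integer only requests "any N devices".
trainer = Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=[gpu_number] if torch.cuda.is_available() else 1,
    max_epochs=5000,
)
```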
```diff
@@ -1,5 +1,6 @@
 ### packages for visualization
 from analysis.rdkit_functions import compute_molecular_metrics
+from analysis.rdkit_functions import compute_graph_metrics
 from mini_moses.metrics.metrics import compute_intermediate_statistics
 from metrics.property_metric import TaskModel
 
@@ -49,8 +50,8 @@ class SamplingGraphMetrics(nn.Module):
 
         self.task_evaluator = {
             'meta_taskname': dataset_infos.task,
-            'sas': None,
-            'scs': None
+            # 'sas': None,
+            # 'scs': None
         }
 
         for cur_task in dataset_infos.task.split("-")[:]:
@@ -62,13 +63,14 @@ class SamplingGraphMetrics(nn.Module):
             self.task_evaluator[cur_task] = evaluator
 
     def forward(self, graphs, targets, name, current_epoch, val_counter, test=False):
+        test = True
         if isinstance(targets, list):
             targets_cat = torch.cat(targets, dim=0)
             targets_np = targets_cat.detach().cpu().numpy()
         else:
             targets_np = targets.detach().cpu().numpy()
 
-        unique_graphs, all_graphs, all_graphs, targets_log = compute_molecular_metrics(
+        unique_graphs, all_graphs, all_metrics, targets_log = compute_graph_metrics(
             graphs,
             targets_np,
             self.train_graphs,
@@ -77,6 +79,22 @@ class SamplingGraphMetrics(nn.Module):
             self.task_evaluator,
             self.compute_config,
         )
+        print(f"all graphs: {all_graphs}")
+        print(f"all graphs[0]: {all_graphs[0]}")
+        tmp_graphs = all_graphs.copy()
+        str_graphs = []
+        for graph in tmp_graphs:
+            node_types = graph[0]
+            edge_types = graph[1]
+            node_str = " ".join([str(node) for node in node_types])
+            edge_str_list = []
+            for i in range(len(node_types)):
+                for j in range(len(node_types)):
+                    edge_str_list.append(str(edge_types[i][j]))
+                edge_str_list.append("/n")
+            edge_str = " ".join(edge_str_list)
+            str_graphs.append(f"nodes: {node_str} /n edges: /n{edge_str}")
+
 
         if test:
             file_name = "final_graphs.txt"
@@ -88,7 +106,7 @@ class SamplingGraphMetrics(nn.Module):
 
                 all_tasks_str = "graph, " + ", ".join([f"input_{task}" for task in all_tasks_name] + [f"output_{task}" for task in all_tasks_name])
                 fp.write(all_tasks_str + "\n")
-                for i, graph in enumerate(all_graphs):
+                for i, graph in enumerate(str_graphs):
                     if targets_log is not None:
                         all_result_str = f"{graph}, " + ", ".join([f"{targets_log['input_'+task][i]}" for task in all_tasks_name] + [f"{targets_log['output_'+task][i]}" for task in all_tasks_name])
                         fp.write(all_result_str + "\n")
@@ -107,7 +125,7 @@ class SamplingGraphMetrics(nn.Module):
                 textfile.write(graph + "\n")
             textfile.close()
 
-        all_logs = all_graphs
+        all_logs = all_metrics
        if test:
             all_logs["log_name"] = "test"
         else:
@@ -116,7 +134,7 @@ class SamplingGraphMetrics(nn.Module):
             )
 
         result_to_csv("output.csv", all_logs)
-        return all_graphs
+        return str_graphs
 
     def reset(self):
         pass
```
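For clarity, here is how the new serialization in `forward()` behaves on a hypothetical 2-node graph (note that `/n` is a literal two-character string in the committed code, presumably intended as a line break `\n`):

```python
# Hypothetical 2-node graph: node types [1, 2], edge-type matrix [[0, 1], [1, 0]].
node_types = [1, 2]
edge_types = [[0, 1], [1, 0]]

node_str = " ".join(str(node) for node in node_types)      # "1 2"
edge_str_list = []
for i in range(len(node_types)):
    for j in range(len(node_types)):
        edge_str_list.append(str(edge_types[i][j]))
    edge_str_list.append("/n")                              # literal "/n", as committed
edge_str = " ".join(edge_str_list)                          # "0 1 /n 1 0 /n"
print(f"nodes: {node_str} /n edges: /n{edge_str}")
# -> nodes: 1 2 /n edges: /n0 1 /n 1 0 /n
```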
```diff
@@ -102,6 +102,7 @@ class TaskModel():
         mask = ~np.isnan(labels)
         labels = labels[mask]
         features = features[mask]
+        # features = str(features)
         self.model.fit(features, labels)
         y_pred = self.model.predict(features)
         perf = self.metric_func(labels, y_pred)
@@ -136,7 +137,7 @@ class TaskModel():
         print(f'{self.task_name} performance: {perf}')
         return perf
 
-    def __call__(self, smiles_list):
+    def __call(self, smiles_list):
         fps = []
         mask = []
         for i,smiles in enumerate(smiles_list):
@@ -153,6 +154,54 @@ class TaskModel():
         scores = scores * np.array(mask)
         return np.float32(scores)
 
+    def __call__(self, graph_list):
+        # def read_adj_ops_from_json(filename):
+        #     with open(filename, 'r') as json_file:
+        #         data = json.load(json_file)
+
+        #     adj_ops_pairs = []
+        #     for item in data:
+        #         adj_matrix = np.array(item['adj_matrix'])
+        #         ops = item['ops']
+        #         acc = item['train'][0]['accuracy']
+        #         adj_ops_pairs.append((adj_matrix, ops, acc))
+
+        #     return adj_ops_pairs
+        def feature_from_adj_and_ops(ops, adj):
+            return np.concatenate([adj.flatten(), ops])
+        # filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
+        # graphs = read_adj_ops_from_json(filename)
+        # adjs = []
+        # opss = []
+        # accs = []
+        # features = []
+        # for graph in graphs:
+        #     adj, ops, acc=graph
+        #     op_code = [op_type[op] for op in ops]
+        #     adjs.append(adj)
+        #     opss.append(op_code)
+        #     accs.append(acc)
+        features = []
+        print(f"graphlist: {graph_list[0]}")
+        print(f"len graphlist: {len(graph_list)}")
+        for op_code, adj in graph_list:
+            features.append(feature_from_adj_and_ops(op_code, adj))
+        print(f"len features: {len(features)}")
+        # print(f"features: {features[0].shape}")
+        features = np.stack(features)
+        features = features.astype(np.float32)
+        print(f"features shape: {features.shape}")
+
+
+        fps = features
+        if 'classification' in self.task_type:
+            scores = self.model.predict_proba(fps)[:, 1]
+        else:
+            scores = self.model.predict(fps)
+        # scores = scores * np.array(mask)
+        return np.float32(scores)
+
+
     @classmethod
     def fingerprints_from_mol(cls, mol):  # use ECFP4
         features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
```
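For reference, a sketch of how the new `__call__` presumably gets its input and builds features (the 4-node example and the `task_model` name are assumptions for illustration, not taken from the diff):

```python
import numpy as np

# Each entry of graph_list is an (op_code, adjacency) pair, matching the
# "for op_code, adj in graph_list" loop in the new __call__ above.
adj = np.array([[0, 1, 1, 0],
                [0, 0, 0, 1],
                [0, 0, 0, 1],
                [0, 0, 0, 0]])
op_code = [0, 2, 3, 1]  # integer-encoded operations, one per node

# Same layout as feature_from_adj_and_ops: flattened adjacency followed by op codes.
feature = np.concatenate([adj.flatten(), op_code])
print(feature.shape)  # (20,): 16 adjacency entries + 4 op codes

# scores = task_model([(op_code, adj)])  # task_model: a fitted TaskModel instance (assumed)
```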