Compare commits

23 Commits: 01c5c277be ... nasbench
| SHA1 |
|---|
| 91d4e3c7ad |
| c867aef5a6 |
| 1ad520d248 |
| 74a629fdcc |
| 94fe13756f |
| 2ac17caa3c |
| 0c60171c71 |
| 97fbdf91c7 |
| 297261d666 |
| 5dccf590e7 |
| 0c4b597dd2 |
| 11d9697e06 |
| 244b159c26 |
| 63ca6c716e |
| d36e1d1077 |
| 82183d3df7 |
| c86db9b6ba |
| a0473008a1 |
| 05ee34e355 |
| 6d9db64a48 |
| 3950a8438d |
| 1fa2d49c11 |
| 3c92e754d3 |
24 README.md
							| @@ -1,14 +1,34 @@ | ||||
| Graph Diffusion Transformer for Multi-Conditional Molecular Generation | ||||
| ================================================================ | ||||
|  | ||||
| ## Initial Setup | ||||
|  | ||||
| Please download the NAS-Bench-201 dataset (NAS-Bench-201-v1_1-096897.pth) from | ||||
| https://drive.google.com/file/d/16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_/view | ||||
|  | ||||
| and put it in the `/path/to/repo/graph_dit` folder. | ||||
|  | ||||
| ## Running the code | ||||
|  | ||||
| Start command: | ||||
| ``` bash | ||||
| python main.py --config-name=config.yaml \ | ||||
| model.ensure_connected=True \ | ||||
| dataset.task_name='nasbench201' \ | ||||
| dataset.guidance_target='regression' | ||||
| ``` | ||||
|  | ||||
| This repository contains the code for the paper "Inverse Molecular Design with Multi-Conditional Diffusion Guidance" by Gang Liu, Jiaxin Xu, Tengfei Luo, and Meng Jiang. | ||||
|  | ||||
|  | ||||
| Paper: https://arxiv.org/abs/2401.13858 | ||||
|  | ||||
| This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like: | ||||
| <!-- This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like: | ||||
|  | ||||
| <div style="display: flex;" markdown="1"> | ||||
|       <img src="asset/reverse.png" style="width: 45%;" alt="Description of the first image"> | ||||
|       <img src="asset/arch.png" style="width: 45%;" alt="Description of the second image"> | ||||
| </div> | ||||
| </div> --> | ||||
|  | ||||
|  | ||||
| ## Requirements | ||||
|   | ||||
| @@ -32,7 +32,7 @@ model: | ||||
|     ensure_connected: True | ||||
| train: | ||||
|     # n_epochs: 5000 | ||||
|     n_epochs: 10 | ||||
|     n_epochs: 500 | ||||
|     batch_size: 1200 | ||||
|     lr: 0.0002 | ||||
|     clip_grad: null | ||||
|   | ||||
| @@ -25,7 +25,9 @@ from sklearn.model_selection import train_test_split | ||||
| import utils as utils | ||||
| from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule | ||||
| from diffusion.distributions import DistributionNodes | ||||
| # from naswot.score_networks import get_nasbench201_idx_score | ||||
| from naswot.score_networks import get_nasbench201_idx_score | ||||
| from naswot import nasspace | ||||
| from naswot import datasets as dt | ||||
|  | ||||
| import networkx as nx | ||||
|  | ||||
| @@ -682,7 +684,7 @@ class Dataset(InMemoryDataset): | ||||
|  | ||||
|         data_list = [] | ||||
|         # len_data = len(self.api) | ||||
|         len_data = 1000 | ||||
|         len_data = 15625 | ||||
|         def check_valid_graph(nodes, edges): | ||||
|             if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]: | ||||
|                 return False | ||||
| @@ -745,11 +747,9 @@ class Dataset(InMemoryDataset): | ||||
|             print(f'edges size: {edges.shape}, nodes size: {len(nodes)}') | ||||
|             return  edges,nodes | ||||
|          | ||||
|         def get_nasbench_201_val(idx): | ||||
|             pass | ||||
|  | ||||
|         # def graph_to_graph_data(graph, idx): | ||||
|         def graph_to_graph_data(graph): | ||||
|         def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device): | ||||
|         # def graph_to_graph_data(graph): | ||||
|             ops = graph[1] | ||||
|             adj = graph[0] | ||||
|             nodes = [] | ||||
| @@ -770,12 +770,58 @@ class Dataset(InMemoryDataset): | ||||
|             edge_index = torch.tensor(edges_list, dtype=torch.long).t() | ||||
|             edge_type = torch.tensor(edge_type, dtype=torch.long) | ||||
|             edge_attr = edge_type | ||||
|             y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) | ||||
|             # y = get_nasbench_201_val(idx) | ||||
|             data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) | ||||
|             # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) | ||||
|             # y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device) | ||||
|             y = self.swap_scores[idx] | ||||
|             print(y, idx) | ||||
|             if y > 60000: | ||||
|                 print(f'idx={idx}, y={y}') | ||||
|                 y = torch.tensor([1, 1], dtype=torch.float).view(1, -1) | ||||
|                 data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) | ||||
|             else: | ||||
|                 print(f'idx={idx}, y={y}') | ||||
|                 y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) | ||||
|                 data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) | ||||
|                 # return None | ||||
|             return data | ||||
|         graph_list = [] | ||||
|  | ||||
|         class Args: | ||||
|             pass | ||||
|         args = Args() | ||||
|         args.trainval = True | ||||
|         args.augtype = 'none' | ||||
|         args.repeat = 1 | ||||
|         args.score = 'hook_logdet' | ||||
|         args.sigma = 0.05 | ||||
|         args.nasspace = 'nasbench201' | ||||
|         args.batch_size = 128 | ||||
|         args.GPU = '0' | ||||
|         args.dataset = 'cifar10' | ||||
|         args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
|         args.data_loc = '../cifardata/' | ||||
|         args.seed = 777 | ||||
|         args.init = '' | ||||
|         args.save_loc = 'results' | ||||
|         args.save_string = 'naswot' | ||||
|         args.dropout = False | ||||
|         args.maxofn = 1 | ||||
|         args.n_samples = 100 | ||||
|         args.n_runs = 500 | ||||
|         args.stem_out_channels = 16 | ||||
|         args.num_stacks = 3 | ||||
|         args.num_modules_per_stack = 3 | ||||
|         args.num_labels = 1 | ||||
|         searchspace = nasspace.get_search_space(args) | ||||
|         train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
|         self.swap_scores = [] | ||||
|         import csv | ||||
|         # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: | ||||
|         with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f: | ||||
|             reader = csv.reader(f) | ||||
|             header = next(reader) | ||||
|             data = [row for row in reader] | ||||
|             self.swap_scores = [float(row[0]) for row in data] | ||||
|         device = torch.device('cuda:2') | ||||
|         with tqdm(total = len_data) as pbar: | ||||
|             active_nodes = set() | ||||
|             file_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json' | ||||
| @@ -785,25 +831,17 @@ class Dataset(InMemoryDataset): | ||||
|             flex_graph_list = [] | ||||
|             flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json' | ||||
|             for graph in graph_list: | ||||
|                 # arch_info = self.api.query_meta_info_by_index(i) | ||||
|                 # results = self.api.query_by_index(i, 'cifar100') | ||||
|                 print(f'iterate every graph in graph_list, here is {i}') | ||||
|                 arch_info = graph['arch_str'] | ||||
|                 # results =  | ||||
|                 # nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||
|                 # ops, adj_matrix = parse_architecture_string(arch_info.arch_str, padding=4) | ||||
|                 ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(arch_info, padding=4) | ||||
|                 # adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) | ||||
|                 for op in ops: | ||||
|                     if op not in active_nodes: | ||||
|                         active_nodes.add(op) | ||||
|                  | ||||
|                 data = graph_to_graph_data((adj_matrix, ops))  | ||||
|                 # with open(flex_graph_path, 'a') as f: | ||||
|                 #     flex_graph = { | ||||
|                 #         'adj_matrix': adj_matrix, | ||||
|                 #         'ops': ops, | ||||
|                 #     } | ||||
|                 #     json.dump(flex_graph, f) | ||||
|                 data = graph_to_graph_data((adj_matrix, ops),idx=i, train_loader=train_loader, searchspace=searchspace, args=args, device=device)  | ||||
|                 i += 1 | ||||
|                 if data is None: | ||||
|                     pbar.update(1) | ||||
|                     continue | ||||
|                 flex_graph_list.append({ | ||||
|                     'adj_matrix':adj_matrix, | ||||
|                     'ops': ops, | ||||
| @@ -816,18 +854,12 @@ class Dataset(InMemoryDataset): | ||||
|                         f.write(str(data.edge_attr)) | ||||
|                 data_list.append(data) | ||||
|  | ||||
|                 new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=9,  random_ratio=0.5) | ||||
|                 flex_graph_list.append({ | ||||
|                     'adj_matrix':new_adj.tolist(), | ||||
|                     'ops': new_ops, | ||||
|                 }) | ||||
|                 # with open(flex_graph_path, 'w') as f: | ||||
|                 #     flex_graph = { | ||||
|                 #         'adj_matrix': new_adj.tolist(), | ||||
|                 #         'ops': new_ops, | ||||
|                 #     } | ||||
|                 #     json.dump(flex_graph, f) | ||||
|                 data_list.append(graph_to_graph_data((new_adj, new_ops))) | ||||
|                 # new_adj, new_ops = generate_flex_adj_mat(ori_nodes=ori_nodes, ori_edges=ori_adj, max_nodes=12, min_nodes=9,  random_ratio=0.5) | ||||
|                 # flex_graph_list.append({ | ||||
|                 #     'adj_matrix':new_adj.tolist(), | ||||
|                 #     'ops': new_ops, | ||||
|                 # }) | ||||
|                 # data_list.append(graph_to_graph_data((new_adj, new_ops))) | ||||
|                 | ||||
|                 # graph_list.append({ | ||||
|                 #     "adj_matrix": adj_matrix, | ||||
| @@ -859,6 +891,7 @@ class Dataset(InMemoryDataset): | ||||
|                 #         "seed": seed, | ||||
|                 #     }for seed, result in results.items()] | ||||
|                 # }) | ||||
|                 # i += 1 | ||||
|                 pbar.update(1) | ||||
|          | ||||
|         for graph in graph_list: | ||||
| @@ -872,8 +905,8 @@ class Dataset(InMemoryDataset): | ||||
|                 graph['ops'] = ops | ||||
|         with open(f'nasbench-201-graph.json', 'w') as f: | ||||
|             json.dump(graph_list, f) | ||||
|         with open(flex_graph_path, 'w') as f: | ||||
|             json.dump(flex_graph_list, f) | ||||
|         # with open(flex_graph_path, 'w') as f: | ||||
|             # json.dump(flex_graph_list, f) | ||||
|              | ||||
|         torch.save(self.collate(data_list), self.processed_paths[0]) | ||||
|  | ||||
| @@ -1148,7 +1181,8 @@ class DataInfos(AbstractDatasetInfos): | ||||
|             #         ops_type[op] = len(ops_type) | ||||
|             # len_ops.add(len(ops)) | ||||
|             # graphs.append((adj_matrix, ops)) | ||||
|         graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json') | ||||
|         # graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json') | ||||
|         graphs = read_adj_ops_from_json(f'/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json') | ||||
|  | ||||
|         # check first five graphs | ||||
|         for i in range(5): | ||||
|   | ||||
| @@ -195,15 +195,18 @@ class Graph_DiT(pl.LightningModule): | ||||
|         # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim)) | ||||
|  | ||||
|     def on_train_epoch_start(self) -> None: | ||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs)) | ||||
|         # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|         if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             # print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs)) | ||||
|             print("Starting train epoch {}/{}...".format(self.current_epoch, self.cfg.train.n_epochs)) | ||||
|         self.start_epoch_time = time.time() | ||||
|         self.train_loss.reset() | ||||
|         self.train_metrics.reset() | ||||
|  | ||||
|     def on_train_epoch_end(self) -> None: | ||||
|  | ||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|         # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|         if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             log = True | ||||
|         else: | ||||
|             log = False | ||||
| @@ -239,8 +242,8 @@ class Graph_DiT(pl.LightningModule): | ||||
|          | ||||
|                    self.val_X_logp.compute(), self.val_E_logp.compute()] | ||||
|          | ||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ", | ||||
|         # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|         print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ", | ||||
|                 f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll)) | ||||
|         with open("validation-metrics.csv", "a") as f: | ||||
|             # save the metrics as csv file | ||||
| @@ -286,7 +289,7 @@ class Graph_DiT(pl.LightningModule): | ||||
|                 samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, | ||||
|                                                 save_final=to_save, | ||||
|                                                 keep_chain=chains_save, | ||||
|                                                 number_chain_steps=self.number_chain_steps)) | ||||
|                                                 number_chain_steps=self.number_chain_steps)[0]) | ||||
|                 ident += to_generate | ||||
|                 start_index += to_generate | ||||
|  | ||||
| @@ -356,10 +359,11 @@ class Graph_DiT(pl.LightningModule): | ||||
|             to_generate = min(samples_left_to_generate, bs) | ||||
|             to_save = min(samples_left_to_save, bs) | ||||
|             chains_save = min(chains_left_to_save, bs) | ||||
|             batch_y = test_y_collection[batch_id : batch_id + to_generate] | ||||
|             # batch_y = test_y_collection[batch_id : batch_id + to_generate] | ||||
|             batch_y = torch.ones(to_generate, self.ydim_output, device=self.device) | ||||
|  | ||||
|             cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||
|                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps) | ||||
|                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps)[0] | ||||
|             samples = samples + cur_sample | ||||
|              | ||||
|             all_ys.append(batch_y) | ||||
| @@ -600,6 +604,9 @@ class Graph_DiT(pl.LightningModule): | ||||
|  | ||||
|         assert (E == torch.transpose(E, 1, 2)).all() | ||||
|  | ||||
|         total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device) | ||||
|         # total_log_probs = torch.zeros([self.cfg.general.samples_to_generate,10], device=self.device) | ||||
|  | ||||
|         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. | ||||
|         for s_int in reversed(range(0, self.T)): | ||||
|             s_array = s_int * torch.ones((batch_size, 1)).type_as(y) | ||||
| @@ -608,21 +615,24 @@ class Graph_DiT(pl.LightningModule): | ||||
|             t_norm = t_array / self.T | ||||
|  | ||||
|             # Sample z_s | ||||
|             sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) | ||||
|             sampled_s, discrete_sampled_s, log_probs= self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) | ||||
|             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||
|             print(f'sampled_s.X shape: {sampled_s.X.shape}, sampled_s.E shape: {sampled_s.E.shape}') | ||||
|             print(f'log_probs shape: {log_probs.shape}') | ||||
|             total_log_probs += log_probs | ||||
|  | ||||
|         # Sample | ||||
|         sampled_s = sampled_s.mask(node_mask, collapse=True) | ||||
|         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||
|          | ||||
|         molecule_list = [] | ||||
|         graph_list = [] | ||||
|         for i in range(batch_size): | ||||
|             n = n_nodes[i] | ||||
|             atom_types = X[i, :n].cpu() | ||||
|             node_types = X[i, :n].cpu() | ||||
|             edge_types = E[i, :n, :n].cpu() | ||||
|             molecule_list.append([atom_types, edge_types]) | ||||
|             graph_list.append([node_types, edge_types]) | ||||
|          | ||||
|         return molecule_list | ||||
|         return graph_list, total_log_probs | ||||
|  | ||||
|     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask): | ||||
|         """Samples from zs ~ p(zs | zt). Only used during sampling. | ||||
| @@ -634,6 +644,7 @@ class Graph_DiT(pl.LightningModule): | ||||
|  | ||||
|         # Neural net predictions | ||||
|         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask} | ||||
|         print(f"sample p zs given zt X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}, node_mask shape: {node_mask.shape}") | ||||
|          | ||||
|         def get_prob(noisy_data, unconditioned=False): | ||||
|             pred = self.forward(noisy_data, unconditioned=unconditioned) | ||||
| @@ -673,7 +684,19 @@ class Graph_DiT(pl.LightningModule): | ||||
|         # with condition = P_t(G_{t-1} |G_t, C) | ||||
|         # with condition = P_t(A_{t-1} |A_t, y) | ||||
|         prob_X, prob_E, pred = get_prob(noisy_data) | ||||
|         print(f'prob_X shape: {prob_X.shape}, prob_E shape: {prob_E.shape}') | ||||
|         print(f'X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}') | ||||
|         print(f'X_t: {X_t}') | ||||
|         log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1))  # bs, n | ||||
|         log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1))  # bs, n, n | ||||
|  | ||||
|         # Sum the log_prob across dimensions for total log_prob | ||||
|         log_prob_X = log_prob_X.sum(dim=-1) | ||||
|         log_prob_E = log_prob_E.sum(dim=(1, 2)) | ||||
|         print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}') | ||||
|         # log_probs = log_prob_E + log_prob_X | ||||
|         log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1)  # (batch_size, 2) | ||||
|         print(f'log_probs shape: {log_probs.shape}') | ||||
|         ### Guidance | ||||
|         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: | ||||
|             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True) | ||||
| @@ -809,4 +832,4 @@ class Graph_DiT(pl.LightningModule): | ||||
|         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||
|         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||
|  | ||||
|         return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t) | ||||
|         return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs | ||||
|   | ||||
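The changes above to `sample_batch` and `sample_p_zs_given_zt` make sampling return per-step log-probabilities alongside the generated graphs. A minimal sketch of the gather-and-accumulate pattern involved, with illustrative names and shapes (the diff concatenates the node and edge terms instead of summing them, but the idea is the same):

``` python
import torch

# Sketch only, not the repository's exact code. prob_X / prob_E are the
# model's predicted categorical distributions at one reverse-diffusion step;
# X_idx / E_idx are the sampled category indices (long tensors).
def step_log_prob(prob_X, prob_E, X_idx, E_idx):
    # prob_X: (bs, n, num_node_types); X_idx: (bs, n, 1)
    log_prob_X = torch.log(torch.gather(prob_X, -1, X_idx).squeeze(-1))  # (bs, n)
    # prob_E: (bs, n, n, num_edge_types); E_idx: (bs, n, n, 1)
    log_prob_E = torch.log(torch.gather(prob_E, -1, E_idx).squeeze(-1))  # (bs, n, n)
    # One scalar per graph: sum over nodes plus sum over all edge slots.
    return log_prob_X.sum(dim=-1) + log_prob_E.sum(dim=(1, 2))  # (bs,)
```

Accumulated over the reverse chain t = T, ..., 1, these terms add up to a per-sample log-probability of the whole trajectory, which is what the PPO-style objective in the `main.py` diff below consumes.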
BIN graph_dit/exp_201/barplog.png (Normal file; binary file not shown; after: 30 KiB)
85 graph_dit/exp_201/main.py (Normal file)
							| @@ -0,0 +1,85 @@ | ||||
|  | ||||
| import matplotlib.pyplot as plt | ||||
| import pandas as pd | ||||
| from nas_201_api import NASBench201API as API | ||||
| # from naswot.score_networks import get_nasbench201_idx_score | ||||
| # from naswot import datasets as dt | ||||
| # from naswot import nasspace | ||||
|  | ||||
| # class Args(): | ||||
| #     pass | ||||
| # args = Args() | ||||
| # args.trainval = True | ||||
| # args.augtype = 'none' | ||||
| # args.repeat = 1 | ||||
| # args.score = 'hook_logdet' | ||||
| # args.sigma = 0.05 | ||||
| # args.nasspace = 'nasbench201' | ||||
| # args.batch_size = 128 | ||||
| # args.GPU = '0' | ||||
| # args.dataset = 'cifar10' | ||||
| # args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| # args.data_loc = '../cifardata/' | ||||
| # args.seed = 777 | ||||
| # args.init = '' | ||||
| # args.save_loc = 'results' | ||||
| # args.save_string = 'naswot' | ||||
| # args.dropout = False | ||||
| # args.maxofn = 1 | ||||
| # args.n_samples = 100 | ||||
| # args.n_runs = 500 | ||||
| # args.stem_out_channels = 16 | ||||
| # args.num_stacks = 3 | ||||
| # args.num_modules_per_stack = 3 | ||||
| # args.num_labels = 1 | ||||
| # searchspace = nasspace.get_search_space(args) | ||||
| # train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
| # device = torch.device('cuda:2') | ||||
|  | ||||
|  | ||||
| source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| api = API(source) | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| # Example list of percentages, accurate to two decimal places | ||||
| # percentages = [5.12, 15.78, 25.43, 35.22, 45.99, 55.34, 65.12, 75.68, 85.99, 95.25, 23.45, 12.34, 37.89, 58.67, 64.23, 72.15, 81.76, 99.99, 42.11, 61.58, 77.34, 14.56] | ||||
| percentages = [] | ||||
|  | ||||
| len_201 = 15625 | ||||
|  | ||||
| for i in range(len_201): | ||||
|     # percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) | ||||
|     results = api.query_by_index(i, 'cifar10') | ||||
|     result = results[111].get_eval('ori-test') | ||||
|     percentages.append(result) | ||||
|  | ||||
| # Define 10% intervals | ||||
| bins = [i for i in range(0, 101, 10)] | ||||
|  | ||||
| # Bin the data and count how many entries fall in each interval | ||||
| hist, bin_edges = pd.cut(percentages, bins=bins, right=False, retbins=True, include_lowest=True) | ||||
| bin_counts = hist.value_counts().sort_index() | ||||
|  | ||||
| total_counts = len(percentages) | ||||
| percentages_in_bins = (bin_counts / total_counts) * 100 | ||||
|  | ||||
| # Draw the bar chart | ||||
| plt.figure(figsize=(10, 6)) | ||||
| bars = plt.bar(bin_counts.index.astype(str), bin_counts.values, width=0.9, color='skyblue') | ||||
|  | ||||
| for bar, percentage in zip(bars, percentages_in_bins): | ||||
|     plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), | ||||
|             f'{percentage:.2f}%', ha='center', va='bottom') | ||||
|  | ||||
| # Add a title and axis labels | ||||
| plt.title('Distribution of Percentages in 10% Intervals') | ||||
| plt.xlabel('Percentage Interval') | ||||
| plt.ylabel('Count') | ||||
|  | ||||
| # Save the chart | ||||
| plt.xticks(rotation=45) | ||||
| plt.savefig('barplog.png') | ||||
|  | ||||
| @@ -1,4 +1,5 @@ | ||||
| # These imports are tricky because they use c++, do not move them | ||||
| from tqdm import tqdm | ||||
| import os, shutil | ||||
| import warnings | ||||
|  | ||||
| @@ -144,10 +145,32 @@ def main(cfg: DictConfig): | ||||
|     else: | ||||
|         trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) | ||||
|  | ||||
| from accelerate import Accelerator | ||||
| from accelerate.utils import set_seed, ProjectConfiguration | ||||
|  | ||||
| @hydra.main( | ||||
|     version_base="1.1", config_path="../configs", config_name="config" | ||||
| ) | ||||
| def test(cfg: DictConfig): | ||||
|     os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number | ||||
|     accelerator_config = ProjectConfiguration( | ||||
|         project_dir=os.path.join(cfg.general.log_dir, cfg.general.name), | ||||
|         automatic_checkpoint_naming=True, | ||||
|         total_limit=cfg.general.number_checkpoint_limit, | ||||
|     ) | ||||
|     accelerator = Accelerator( | ||||
|         mixed_precision='no', | ||||
|         project_config=accelerator_config, | ||||
|         # gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs, | ||||
|         gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,  | ||||
|     ) | ||||
|  | ||||
|     # Debug: confirm which devices are available | ||||
|     print(f"Available GPUs: {torch.cuda.device_count()}") | ||||
|     print(f"Using device: {accelerator.device}") | ||||
|  | ||||
|     set_seed(cfg.train.seed, device_specific=True) | ||||
|  | ||||
|     datamodule = dataset.DataModule(cfg) | ||||
|     datamodule.prepare_data() | ||||
|     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) | ||||
| @@ -169,40 +192,216 @@ def test(cfg: DictConfig): | ||||
|         "visualization_tools": visulization_tools, | ||||
|     } | ||||
|  | ||||
|     # Debug: confirm which devices are available | ||||
|     print(f"Available GPUs: {torch.cuda.device_count()}") | ||||
|     print(f"Using device: {accelerator.device}") | ||||
|  | ||||
|     if cfg.general.test_only: | ||||
|         cfg, _ = get_resume(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.test_only.split("checkpoints")[0]) | ||||
|     elif cfg.general.resume is not None: | ||||
|         cfg, _ = get_resume_adaptive(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.resume.split("checkpoints")[0]) | ||||
|     # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number | ||||
|     model = Graph_DiT(cfg=cfg, **model_kwargs) | ||||
|     trainer = Trainer( | ||||
|         gradient_clip_val=cfg.train.clip_grad, | ||||
|         # accelerator="cpu", | ||||
|         accelerator="gpu" | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else "cpu", | ||||
|         devices=[cfg.general.gpu_number] | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else None, | ||||
|         max_epochs=cfg.train.n_epochs, | ||||
|         enable_checkpointing=False, | ||||
|         check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, | ||||
|         val_check_interval=cfg.train.val_check_interval, | ||||
|         strategy="ddp" if cfg.general.gpus > 1 else "auto", | ||||
|         enable_progress_bar=cfg.general.enable_progress_bar, | ||||
|         callbacks=[], | ||||
|         reload_dataloaders_every_n_epochs=0, | ||||
|         logger=[], | ||||
|     ) | ||||
|     graph_dit_model = model | ||||
|  | ||||
|     if not cfg.general.test_only: | ||||
|         print("start testing fit method") | ||||
|         trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) | ||||
|         if cfg.general.save_model: | ||||
|             trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt") | ||||
|         trainer.test(model, datamodule=datamodule) | ||||
|     inference_dtype = torch.float32 | ||||
|     graph_dit_model.to(accelerator.device, dtype=inference_dtype) | ||||
|  | ||||
|  | ||||
|     # optional: freeze the model | ||||
|     # graph_dit_model.model.requires_grad_(True) | ||||
|  | ||||
|     import torch.nn.functional as F | ||||
|     optimizer = graph_dit_model.configure_optimizers() | ||||
|     train_dataloader = accelerator.prepare(datamodule.train_dataloader()) | ||||
|     optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model) | ||||
|     # start training | ||||
|     for epoch in range(cfg.train.n_epochs): | ||||
|         graph_dit_model.train()  # set the model to training mode | ||||
|         print(f"Epoch {epoch}", end="\n") | ||||
|         graph_dit_model.on_train_epoch_start() | ||||
|         for data in train_dataloader:  # fetch one batch from the data loader | ||||
|             # data.to(accelerator.device) | ||||
|             # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index] | ||||
|             # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float() | ||||
|             # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes) | ||||
|             # dense_data = dense_data.mask(node_mask) | ||||
|             # X, E = dense_data.X, dense_data.E | ||||
|             # noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask) | ||||
|             # pred = graph_dit_model.forward(noisy_data) | ||||
|             # loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y, | ||||
|             #                     true_X=X, true_E=E, true_y=data.y, node_mask=node_mask, | ||||
|             #                     log=epoch % graph_dit_model.log_every_steps == 0) | ||||
|             # # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}') | ||||
|             # graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E, | ||||
|             #                 log=epoch % graph_dit_model.log_every_steps == 0) | ||||
|             # graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True) | ||||
|             # print(f"training loss: {loss}") | ||||
|             # with open("training-loss.csv", "a") as f: | ||||
|             #     f.write(f"{loss}, {epoch}\n") | ||||
|             loss = graph_dit_model.training_step(data, epoch) | ||||
|             loss = loss['loss'] | ||||
|  | ||||
|             accelerator.backward(loss) | ||||
|             optimizer.step() | ||||
|             optimizer.zero_grad() | ||||
|             # return {'loss': loss} | ||||
|         graph_dit_model.on_train_epoch_end() | ||||
|         if epoch % cfg.train.check_val_every_n_epoch == 0: | ||||
|             print(f'print validation loss') | ||||
|             graph_dit_model.eval() | ||||
|             graph_dit_model.on_validation_epoch_start() | ||||
|             graph_dit_model.validation_step(data, epoch) | ||||
|             graph_dit_model.on_validation_epoch_end() | ||||
|      | ||||
|     # start testing | ||||
|     print("start testing") | ||||
|     graph_dit_model.eval() | ||||
|     test_dataloader = accelerator.prepare(datamodule.test_dataloader()) | ||||
|     graph_dit_model.on_test_epoch_start() | ||||
|     for data in test_dataloader: | ||||
|         nll = graph_dit_model.test_step(data, epoch) | ||||
|         # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index] | ||||
|         # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float() | ||||
|  | ||||
|         # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes) | ||||
|         # dense_data = dense_data.mask(node_mask) | ||||
|         # noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) | ||||
|         # pred = graph_dit_model.forward(noisy_data) | ||||
|         # nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True) | ||||
|         # graph_dit_model.test_y_collection.append(data.y) | ||||
|         print(f'test loss: {nll}') | ||||
|      | ||||
|     graph_dit_model.on_test_epoch_end() | ||||
|  | ||||
|     # start sampling | ||||
|  | ||||
|     # samples_left_to_generate = cfg.general.final_model_samples_to_generate | ||||
|     # samples_left_to_save = cfg.general.final_model_samples_to_save | ||||
|     # chains_left_to_save = cfg.general.final_model_chains_to_save | ||||
|  | ||||
|     # samples, all_ys, batch_id = [], [], 0 | ||||
|     # samples_with_log_probs = [] | ||||
|     # test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0) | ||||
|     # num_examples = test_y_collection.size(0) | ||||
|     # if cfg.general.final_model_samples_to_generate > num_examples: | ||||
|     #     ratio = cfg.general.final_model_samples_to_generate // num_examples | ||||
|     #     test_y_collection = test_y_collection.repeat(ratio+1, 1) | ||||
|     #     num_examples = test_y_collection.size(0) | ||||
|      | ||||
|     # Normal reward function | ||||
|     # from nas_201_api import NASBench201API as API | ||||
|     # api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') | ||||
|     # def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'): | ||||
|     #     rewards = [] | ||||
|     #     if reward_model == 'swap': | ||||
|     #         import csv | ||||
|     #         with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: | ||||
|     #             reader = csv.reader(f) | ||||
|     #             header = next(reader) | ||||
|     #             data = [row for row in reader] | ||||
|     #             swap_scores = [float(row[0]) for row in data] | ||||
|     #             for graph in graphs: | ||||
|     #                 node_tensor = graph[0] | ||||
|     #                 node = node_tensor.cpu().numpy().tolist() | ||||
|  | ||||
|     #                 def nodes_to_arch_str(nodes): | ||||
|     #                     num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
|     #                     nodes_str = [num_to_op[node] for node in nodes] | ||||
|     #                     arch_str = '|' + nodes_str[1] + '~0|+' + \ | ||||
|     #                             '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\ | ||||
|     #                             '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'  | ||||
|     #                     return arch_str | ||||
|                      | ||||
|     #                 arch_str = nodes_to_arch_str(node) | ||||
|     #                 reward = swap_scores[api.query_index_by_arch(arch_str)] | ||||
|     #                 rewards.append(reward) | ||||
|                  | ||||
|     #     # for graph in graphs: | ||||
|     #     #     reward = 1.0 | ||||
|     #     #     rewards.append(reward) | ||||
|     #     return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device) | ||||
|     # old_log_probs = None | ||||
|     # while samples_left_to_generate > 0: | ||||
|     #     print(f'samples left to generate: {samples_left_to_generate}/' | ||||
|     #         f'{cfg.general.final_model_samples_to_generate}', end='', flush=True) | ||||
|     #     bs = 1 * cfg.train.batch_size | ||||
|     #     to_generate = min(samples_left_to_generate, bs) | ||||
|     #     to_save = min(samples_left_to_save, bs) | ||||
|     #     chains_save = min(chains_left_to_save, bs) | ||||
|     #     # batch_y = test_y_collection[batch_id : batch_id + to_generate] | ||||
|     #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device) | ||||
|  | ||||
|     #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||
|     #                                     keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps) | ||||
|     #     log_probs = torch.sum(log_probs, dim=-1).unsqueeze(1) | ||||
|     #     samples = samples + cur_sample | ||||
|     #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device) | ||||
|     #     advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6) | ||||
|     #     print(f'reward: {reward.shape}, advantages: {advantages.shape}, log_probs: {log_probs.shape}, cur_sample: {len(cur_sample)}') | ||||
|     #     if old_log_probs is None: | ||||
|     #         old_log_probs = log_probs.clone() | ||||
|     #     ratio = torch.exp(log_probs - old_log_probs) | ||||
|     #     unclipped_loss = -advantages * ratio | ||||
|     #     clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param) | ||||
|     #     loss = torch.mean(torch.max(unclipped_loss, clipped_loss)) | ||||
|     #     accelerator.backward(loss) | ||||
|     #     optimizer.step() | ||||
|     #     optimizer.zero_grad() | ||||
|  | ||||
|  | ||||
|     #     samples_with_log_probs.append((cur_sample, log_probs, reward)) | ||||
|          | ||||
|     #     all_ys.append(batch_y) | ||||
|     #     batch_id += to_generate | ||||
|  | ||||
|     #     samples_left_to_save -= to_save | ||||
|     #     samples_left_to_generate -= to_generate | ||||
|     #     chains_left_to_save -= chains_save | ||||
|          | ||||
|     # print(f"final Computing sampling metrics...") | ||||
|     # graph_dit_model.sampling_metrics.reset() | ||||
|     # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True) | ||||
|     # graph_dit_model.sampling_metrics.reset() | ||||
|     # print(f"Done.") | ||||
|  | ||||
|     # # save samples | ||||
|     # print("Samples:") | ||||
|     # print(samples) | ||||
|  | ||||
|     # ======================== | ||||
|      | ||||
|  | ||||
|      | ||||
|  | ||||
|  | ||||
|     # trainer = Trainer( | ||||
|     #     gradient_clip_val=cfg.train.clip_grad, | ||||
|     #     # accelerator="cpu", | ||||
|     #     accelerator="gpu" | ||||
|     #     if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|     #     else "cpu", | ||||
|     #     devices=[cfg.general.gpu_number] | ||||
|     #     if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|     #     else None, | ||||
|     #     max_epochs=cfg.train.n_epochs, | ||||
|     #     enable_checkpointing=False, | ||||
|     #     check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, | ||||
|     #     val_check_interval=cfg.train.val_check_interval, | ||||
|     #     strategy="ddp" if cfg.general.gpus > 1 else "auto", | ||||
|     #     enable_progress_bar=cfg.general.enable_progress_bar, | ||||
|     #     callbacks=[], | ||||
|     #     reload_dataloaders_every_n_epochs=0, | ||||
|     #     logger=[], | ||||
|     # ) | ||||
|  | ||||
|     # if not cfg.general.test_only: | ||||
|     #     print("start testing fit method") | ||||
|     #     trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) | ||||
|     #     if cfg.general.save_model: | ||||
|     #         trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt") | ||||
|     #     trainer.test(model, datamodule=datamodule) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     test() | ||||
|   | ||||
| @@ -76,6 +76,8 @@ class CategoricalEmbedder(nn.Module): | ||||
|             embeddings = embeddings + noise | ||||
|         return embeddings | ||||
|      | ||||
| # cluster similar conditions together | ||||
| # size  | ||||
| class ClusterContinuousEmbedder(nn.Module): | ||||
|     def __init__(self, input_size, hidden_size, dropout_prob): | ||||
|         super().__init__() | ||||
| @@ -108,6 +110,8 @@ class ClusterContinuousEmbedder(nn.Module): | ||||
|          | ||||
|         if drop_ids is not None: | ||||
|             embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device) | ||||
|             # print(labels[~drop_ids].shape) | ||||
|             # torch.Size([1200]) | ||||
|             embeddings[~drop_ids] = self.mlp(labels[~drop_ids]) | ||||
|             embeddings[drop_ids] += self.embedding_drop.weight[0] | ||||
|         else: | ||||
|   | ||||
| @@ -17,20 +17,22 @@ class Denoiser(nn.Module): | ||||
|         num_heads=16, | ||||
|         mlp_ratio=4.0, | ||||
|         drop_condition=0.1, | ||||
|         Xdim=118, | ||||
|         Edim=5, | ||||
|         ydim=3, | ||||
|         Xdim=7, | ||||
|         Edim=2, | ||||
|         ydim=1, | ||||
|         task_type='regression', | ||||
|     ): | ||||
|         super().__init__() | ||||
|         print(f"Denoiser, xdim: {Xdim}, edim: {Edim}, ydim: {ydim}, hidden_size: {hidden_size}, depth: {depth}, num_heads: {num_heads}, mlp_ratio: {mlp_ratio}, drop_condition: {drop_condition}") | ||||
|         self.num_heads = num_heads | ||||
|         self.ydim = ydim | ||||
|         self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False) | ||||
|  | ||||
|         self.t_embedder = TimestepEmbedder(hidden_size) | ||||
|         #  | ||||
|         self.y_embedding_list = torch.nn.ModuleList() | ||||
|  | ||||
|         self.y_embedding_list.append(ClusterContinuousEmbedder(2, hidden_size, drop_condition)) | ||||
|         self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition)) | ||||
|         for i in range(ydim - 2): | ||||
|             if task_type == 'regression': | ||||
|                 self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition)) | ||||
| @@ -88,6 +90,8 @@ class Denoiser(nn.Module): | ||||
|          | ||||
|         # print("Denoiser Forward") | ||||
|         # print(x.shape, e.shape, y.shape, t.shape, unconditioned) | ||||
|         # torch.Size([1200, 8, 7]) torch.Size([1200, 8, 8, 2]) torch.Size([1200, 2]) torch.Size([1200, 1]) False | ||||
|         # print(y) | ||||
|         force_drop_id = torch.zeros_like(y.sum(-1)) | ||||
|         # drop the nan values | ||||
|         force_drop_id[torch.isnan(y.sum(-1))] = 1 | ||||
| @@ -109,11 +113,12 @@ class Denoiser(nn.Module): | ||||
|         c1 = self.t_embedder(t) | ||||
|         # print("C1 after t_embedder") | ||||
|         # print(c1.shape) | ||||
|         for i in range(1, self.ydim): | ||||
|             if i == 1: | ||||
|                 c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t) | ||||
|             else: | ||||
|                 c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t) | ||||
|         c2 = self.y_embedding_list[0](y[:,0].unsqueeze(-1), self.training, force_drop_id, t) | ||||
|         # for i in range(1, self.ydim): | ||||
|         #     if i == 1: | ||||
|         #         c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t) | ||||
|         #     else: | ||||
|                 # c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t) | ||||
|         # print("C2 after y_embedding_list") | ||||
|         # print(c2.shape) | ||||
|         # print("C1 + C2") | ||||
|   | ||||
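The commented-out sampling loop in the `main.py` diff above (SWAP-score reward, normalized advantages, importance ratio, clamp) follows the standard PPO clipped-surrogate update. A self-contained sketch of that objective; `clip_param` corresponds to `cfg.ppo.clip_param` in the diff, everything else here is illustrative rather than the repository's API:

``` python
import torch

def ppo_clipped_loss(log_probs, old_log_probs, rewards, clip_param=0.2):
    # Normalize rewards into advantages (zero mean, unit variance).
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
    # Importance ratio between the current policy and the sampling policy.
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Element-wise pessimistic choice, averaged over the batch.
    return torch.max(unclipped, clipped).mean()

# Hypothetical usage, mirroring the commented-out loop:
#   loss = ppo_clipped_loss(log_probs, old_log_probs, reward)
#   accelerator.backward(loss); optimizer.step(); optimizer.zero_grad()
```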
1 graph_dit/nasbench-201-meta.json (Normal file)
							| @@ -0,0 +1 @@ | ||||
| {"source": "nasbench-201", "num_graph": 15625, "n_nodes_per_graph": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "max_n_nodes": 8, "max_n_edges": 8, "node_type_list": [0.125, 0.15, 0.15, 0.15, 0.15, 0.15, 0.125, 0.0], "edge_type_list": [0.6666666666666666, 0.3333333333333333], "valencies": [0.125, 0.15, 0.15, 0.15, 0.15, 0.15, 0.125, 0.0], "active_nodes": ["*", "input", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3", "skip_connect", "none"], "num_active_nodes": 7, "transition_E": [[[1.0, 0.0], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [1.0, 0.0], [1.0, 0.0]], [[0.5, 0.5], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.5, 0.5], [1.0, 0.0]], [[0.5, 0.5], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.5, 0.5], [1.0, 0.0]], [[0.5, 0.5], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.5, 0.5], [1.0, 0.0]], [[0.5, 0.5], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.5, 0.5], [1.0, 0.0]], [[0.5, 0.5], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.7333333333333333, 0.26666666666666666], [0.5, 0.5], [1.0, 0.0]], [[1.0, 0.0], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]]} | ||||
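For orientation, a small sketch of reading this meta file; the loading code is illustrative, not taken from the repository:

``` python
import json

# Illustrative only: load the dataset statistics written above.
with open('graph_dit/nasbench-201-meta.json') as f:
    meta = json.load(f)

print(meta['num_graph'])     # 15625 -- the full NAS-Bench-201 space
print(meta['max_n_nodes'])   # 8 -- nodes per cell graph
print(meta['active_nodes'])  # operation vocabulary, e.g. 'nor_conv_3x3'
# 'node_type_list' / 'edge_type_list' are marginal categorical frequencies and
# 'transition_E' an 8x8x2 edge table -- the kind of statistics a discrete
# graph-diffusion model uses as its noise prior.
```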
15626 graph_dit/swap_results_aircraft.csv (Normal file)
File diff suppressed because it is too large.
144 graph_dit/test_perf.py (Normal file)
							| @@ -0,0 +1,144 @@ | ||||
| from nas_201_api import NASBench201API as API | ||||
| import re | ||||
| import pandas as pd | ||||
| import json | ||||
| import numpy as np | ||||
| import argparse | ||||
|  | ||||
| api = API('./NAS-Bench-201-v1_1-096897.pth') | ||||
|  | ||||
| parser = argparse.ArgumentParser(description='Process some integers.') | ||||
|  | ||||
| parser.add_argument('--file_path', type=str, default='211035.txt',) | ||||
| parser.add_argument('--datasets', type=str, default='cifar10',) | ||||
| args = parser.parse_args() | ||||
|  | ||||
| def process_graph_data(text): | ||||
|     # Split the input text into sections for each graph | ||||
|     graph_sections = text.strip().split('nodes:') | ||||
|      | ||||
|     # Prepare lists to store data | ||||
|     nodes_list = [] | ||||
|     edges_list = [] | ||||
|     results_list = [] | ||||
|      | ||||
|     for section in graph_sections[1:]: | ||||
|         # Extract nodes | ||||
|         nodes_section = section.split('edges:')[0] | ||||
|         nodes_match = re.search(r'(tensor\(\d+\) ?)+', section) | ||||
|         if nodes_match: | ||||
|             nodes = re.findall(r'tensor\((\d+)\)', nodes_match.group(0)) | ||||
|             nodes_list.append(nodes) | ||||
|          | ||||
|         # Extract edges | ||||
|         edge_section = section.split('edges:')[1] | ||||
|         edges_match = re.search(r'edges:', section) | ||||
|         if edges_match: | ||||
|             edges = re.findall(r'tensor\((\d+)\)', edge_section) | ||||
|             edges_list.append(edges) | ||||
|          | ||||
|         # Extract the last floating point number as a result | ||||
|      | ||||
|     # Create a DataFrame to store the extracted data | ||||
|     data = { | ||||
|         'nodes': nodes_list, | ||||
|         'edges': edges_list, | ||||
|     } | ||||
|     data['nodes'] = [[int(x) for x in node] for node in data['nodes']] | ||||
|     data['edges'] = [[int(x) for x in edge] for edge in data['edges']] | ||||
|     def split_list(input_list, chunk_size): | ||||
|         return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)] | ||||
|     data['edges'] = [split_list(edge, 8) for edge in data['edges']] | ||||
|  | ||||
|     print(data) | ||||
|     df = pd.DataFrame(data) | ||||
|     print('df') | ||||
|     print(df['nodes'][0], df['edges'][0]) | ||||
|     return df | ||||
|  | ||||
| def is_valid_nasbench201(adj, ops): | ||||
|     print(ops) | ||||
|     if ops[0] != 0 or ops[-1] != 6: | ||||
|         return False | ||||
|     for i in range(2, len(ops) - 1): | ||||
|         if ops[i] not in [1, 2, 3, 4, 5]: | ||||
|             return False | ||||
|     adj_mat = [ [0, 1, 1, 0, 1, 0, 0, 0], | ||||
|                 [0, 0, 0, 1, 0, 1 ,0 ,0], | ||||
|                 [0, 0, 0, 0, 0, 0, 1, 0], | ||||
|                 [0, 0, 0, 0, 0, 0, 1, 0], | ||||
|                 [0, 0, 0, 0, 0, 0, 0, 1], | ||||
|                 [0, 0, 0, 0, 0, 0, 0, 1], | ||||
|                 [0, 0, 0, 0, 0, 0, 0, 1], | ||||
|                 [0, 0, 0, 0, 0, 0, 0, 0]] | ||||
|   | ||||
|     for i in range(len(adj)): | ||||
|         for j in range(len(adj[i])): | ||||
|             if adj[i][j] not in [0, 1]: | ||||
|                 return False | ||||
|             if j > i: | ||||
|                 if adj[i][j] != adj_mat[i][j]: | ||||
|                     return False | ||||
|     return True | ||||
|  | ||||
| num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
| def nodes_to_arch_str(nodes): | ||||
|     nodes_str = [num_to_op[node] for node in nodes] | ||||
|     arch_str = '|' + nodes_str[1] + '~0|+' + \ | ||||
|                '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\ | ||||
|                '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'  | ||||
|     return arch_str | ||||
|  | ||||
| filename = args.file_path | ||||
| datasets_name = args.datasets | ||||
|  | ||||
| with open('./output_graphs/' + filename, 'r') as f: | ||||
|     texts = f.read() | ||||
|     df = process_graph_data(texts) | ||||
|     valid = 0 | ||||
|     not_valid = 0 | ||||
|     scores = [] | ||||
|  | ||||
|     # Define the accuracy thresholds and distribution buckets per dataset | ||||
|     thresholds = { | ||||
|         'cifar10': [90, 91, 92, 93, 94], | ||||
|         'cifar100': [68,69,70, 71, 72, 73] | ||||
|     } | ||||
|     dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]} | ||||
|     dist[f'>{thresholds[datasets_name][-1]}'] = 0 | ||||
|  | ||||
|     for i in range(len(df)): | ||||
|         nodes = df['nodes'][i] | ||||
|         edges = df['edges'][i] | ||||
|         result = is_valid_nasbench201(edges, nodes) | ||||
|         if result: | ||||
|             valid += 1 | ||||
|             arch_str = nodes_to_arch_str(nodes) | ||||
|             index = api.query_index_by_arch(arch_str) | ||||
|             res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False) | ||||
|             acc = res['test-accuracy'] | ||||
|             scores.append((index, acc)) | ||||
|  | ||||
|             # Update the distribution according to the thresholds | ||||
|             updated = False | ||||
|             for threshold in thresholds[datasets_name]: | ||||
|                 if acc < threshold: | ||||
|                     dist[f'<{threshold}'] += 1 | ||||
|                     updated = True | ||||
|                     break | ||||
|             if not updated: | ||||
|                 dist[f'>{thresholds[datasets_name][-1]}'] += 1 | ||||
|         else: | ||||
|             not_valid += 1 | ||||
|  | ||||
|     with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f: | ||||
|         json.dump(scores, f) | ||||
|  | ||||
|     print(scores) | ||||
|     print(valid, not_valid) | ||||
|     print(dist) | ||||
|     print("mean: ", np.mean([x[1] for x in scores])) | ||||
|     print("max: ", np.max([x[1] for x in scores])) | ||||
|     print("min: ", np.min([x[1] for x in scores])) | ||||
|  | ||||
|          | ||||
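`test_perf.py` is run as, e.g., `python test_perf.py --file_path 211035.txt --datasets cifar10` (both values are the argparse defaults). An illustrative round-trip from op indices to the architecture string built by `nodes_to_arch_str`, using a hypothetical node list:

``` python
# Hypothetical 8-node cell: first entry 'input', last entry 'output',
# six operations in between (indices follow num_to_op in test_perf.py).
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3',
             'skip_connect', 'none', 'output']
nodes = [0, 2, 2, 4, 3, 5, 1, 6]

ops = [num_to_op[n] for n in nodes]
arch_str = ('|' + ops[1] + '~0|+'
            '|' + ops[2] + '~0|' + ops[3] + '~1|+'
            '|' + ops[4] + '~0|' + ops[5] + '~1|' + ops[6] + '~2|')
print(arch_str)
# |nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|avg_pool_3x3~0|none~1|nor_conv_1x1~2|
# api.query_index_by_arch(arch_str) maps this string back to a NAS-Bench-201
# index, which the script then uses to query ground-truth test accuracy.
```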