Compare commits

19 Commits: 6d9db64a48 ... nasbench
| SHA1 |
|---|
| 91d4e3c7ad |
| c867aef5a6 |
| 1ad520d248 |
| 74a629fdcc |
| 94fe13756f |
| 2ac17caa3c |
| 0c60171c71 |
| 97fbdf91c7 |
| 297261d666 |
| 5dccf590e7 |
| 0c4b597dd2 |
| 11d9697e06 |
| 244b159c26 |
| 63ca6c716e |
| d36e1d1077 |
| 82183d3df7 |
| c86db9b6ba |
| a0473008a1 |
| 05ee34e355 |

README.md (24 lines changed)

							| @@ -1,14 +1,34 @@ | |||||||
| Graph Diffusion Transformer for Multi-Conditional Molecular Generation | Graph Diffusion Transformer for Multi-Conditional Molecular Generation | ||||||
| ================================================================ | ================================================================ | ||||||
|  |  | ||||||
|  | ## Initial Setup | ||||||
|  |  | ||||||
|  | Please download the NAS-Bench-201 dataset (`NAS-Bench-201-v1_1-096897.pth`) from | ||||||
|  | https://drive.google.com/file/d/16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_/view | ||||||
|  |  | ||||||
|  | and put it in the `/path/to/repo/graph_dit` folder. | ||||||
|  |  | ||||||
|  | ## Running the code | ||||||
|  |  | ||||||
|  | Start command: | ||||||
|  | ``` bash | ||||||
|  | python main.py --config-name=config.yaml \ | ||||||
|  | model.ensure_connected=True \ | ||||||
|  | dataset.task_name='nasbench201' \ | ||||||
|  | dataset.guidance_target='regression' | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This repository contains the code for the paper "Inverse Molecular Design with Multi-Conditional Diffusion Guidance" by Gang Liu, Jiaxin Xu, Tengfei Luo, and Meng Jiang. | ||||||
|  |  | ||||||
|  |  | ||||||
| Paper: https://arxiv.org/abs/2401.13858 | Paper: https://arxiv.org/abs/2401.13858 | ||||||
|  |  | ||||||
| This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like: | <!-- This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like: | ||||||
|  |  | ||||||
| <div style="display: flex;" markdown="1"> | <div style="display: flex;" markdown="1"> | ||||||
|       <img src="asset/reverse.png" style="width: 45%;" alt="Description of the first image"> |       <img src="asset/reverse.png" style="width: 45%;" alt="Description of the first image"> | ||||||
|       <img src="asset/arch.png" style="width: 45%;" alt="Description of the second image"> |       <img src="asset/arch.png" style="width: 45%;" alt="Description of the second image"> | ||||||
| </div> | </div> --> | ||||||
|  |  | ||||||
|  |  | ||||||
| ## Requirements | ## Requirements | ||||||
|   | |||||||
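
A quick way to verify the setup step in the new README is to load the checkpoint with the `nas_201_api` package the repository already uses elsewhere. A minimal sketch, assuming the package is installed and the file sits in `graph_dit/` as described above:

```python
# Sanity check that the downloaded NAS-Bench-201 checkpoint is readable (sketch).
from nas_201_api import NASBench201API as API

api = API('graph_dit/NAS-Bench-201-v1_1-096897.pth')
print(len(api))                          # 15625 architectures in NAS-Bench-201
print(api.query_by_index(0, 'cifar10'))  # per-seed results for architecture 0
```
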
| @@ -771,9 +771,10 @@ class Dataset(InMemoryDataset): | |||||||
|             edge_type = torch.tensor(edge_type, dtype=torch.long) |             edge_type = torch.tensor(edge_type, dtype=torch.long) | ||||||
|             edge_attr = edge_type |             edge_attr = edge_type | ||||||
|             # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) |             # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) | ||||||
|             y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device) |             # y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device) | ||||||
|  |             y = self.swap_scores[idx] | ||||||
|             print(y, idx) |             print(y, idx) | ||||||
|             if y > 1600: |             if y > 60000: | ||||||
|                 print(f'idx={idx}, y={y}') |                 print(f'idx={idx}, y={y}') | ||||||
|                 y = torch.tensor([1, 1], dtype=torch.float).view(1, -1) |                 y = torch.tensor([1, 1], dtype=torch.float).view(1, -1) | ||||||
|                 data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) |                 data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) | ||||||
| @@ -812,6 +813,14 @@ class Dataset(InMemoryDataset): | |||||||
|         args.num_labels = 1 |         args.num_labels = 1 | ||||||
|         searchspace = nasspace.get_search_space(args) |         searchspace = nasspace.get_search_space(args) | ||||||
|         train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) |         train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||||
|  |         self.swap_scores = [] | ||||||
|  |         import csv | ||||||
|  |         # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: | ||||||
|  |         with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f: | ||||||
|  |             reader = csv.reader(f) | ||||||
|  |             header = next(reader) | ||||||
|  |             data = [row for row in reader] | ||||||
|  |             self.swap_scores = [float(row[0]) for row in data] | ||||||
|         device = torch.device('cuda:2') |         device = torch.device('cuda:2') | ||||||
|         with tqdm(total = len_data) as pbar: |         with tqdm(total = len_data) as pbar: | ||||||
|             active_nodes = set() |             active_nodes = set() | ||||||
| @@ -823,14 +832,8 @@ class Dataset(InMemoryDataset): | |||||||
|             flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json' |             flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json' | ||||||
|             for graph in graph_list: |             for graph in graph_list: | ||||||
|                 print(f'iterate every graph in graph_list, here is {i}') |                 print(f'iterate every graph in graph_list, here is {i}') | ||||||
|                 # arch_info = self.api.query_meta_info_by_index(i) |  | ||||||
|                 # results = self.api.query_by_index(i, 'cifar100') |  | ||||||
|                 arch_info = graph['arch_str'] |                 arch_info = graph['arch_str'] | ||||||
|                 # results =  |  | ||||||
|                 # nodes, edges = parse_architecture_string(arch_info.arch_str) |  | ||||||
|                 # ops, adj_matrix = parse_architecture_string(arch_info.arch_str, padding=4) |  | ||||||
|                 ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(arch_info, padding=4) |                 ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(arch_info, padding=4) | ||||||
|                 # adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) |  | ||||||
|                 for op in ops: |                 for op in ops: | ||||||
|                     if op not in active_nodes: |                     if op not in active_nodes: | ||||||
|                         active_nodes.add(op) |                         active_nodes.add(op) | ||||||
| @@ -839,12 +842,6 @@ class Dataset(InMemoryDataset): | |||||||
|                 if data is None: |                 if data is None: | ||||||
|                     pbar.update(1) |                     pbar.update(1) | ||||||
|                     continue |                     continue | ||||||
|                 # with open(flex_graph_path, 'a') as f: |  | ||||||
|                 #     flex_graph = { |  | ||||||
|                 #         'adj_matrix': adj_matrix, |  | ||||||
|                 #         'ops': ops, |  | ||||||
|                 #     } |  | ||||||
|                 #     json.dump(flex_graph, f) |  | ||||||
|                 flex_graph_list.append({ |                 flex_graph_list.append({ | ||||||
|                     'adj_matrix':adj_matrix, |                     'adj_matrix':adj_matrix, | ||||||
|                     'ops': ops, |                     'ops': ops, | ||||||
|   | |||||||
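
The dataset change above swaps the on-the-fly NASWOT scoring (`get_nasbench201_idx_score`) for a lookup into precomputed SWAP scores. A minimal sketch of that lookup, assuming the CSV has a header row and one score per NAS-Bench-201 index in the first column, as the code above reads it:

```python
# Load precomputed SWAP scores indexed by architecture id (sketch).
import csv

def load_swap_scores(path):
    with open(path, newline='') as f:
        reader = csv.reader(f)
        next(reader)                      # skip the header row
        return [float(row[0]) for row in reader]

swap_scores = load_swap_scores('graph_dit/swap_results_cifar100.csv')
print(swap_scores[0])                     # score used as the label y for idx 0
```
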
| @@ -195,15 +195,18 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim)) |         # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim)) | ||||||
|  |  | ||||||
|     def on_train_epoch_start(self) -> None: |     def on_train_epoch_start(self) -> None: | ||||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: |         # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||||
|             print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs)) |         if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||||
|  |             # print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs)) | ||||||
|  |             print("Starting train epoch {}/{}...".format(self.current_epoch, self.cfg.train.n_epochs)) | ||||||
|         self.start_epoch_time = time.time() |         self.start_epoch_time = time.time() | ||||||
|         self.train_loss.reset() |         self.train_loss.reset() | ||||||
|         self.train_metrics.reset() |         self.train_metrics.reset() | ||||||
|  |  | ||||||
|     def on_train_epoch_end(self) -> None: |     def on_train_epoch_end(self) -> None: | ||||||
|  |  | ||||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: |         # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||||
|  |         if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||||
|             log = True |             log = True | ||||||
|         else: |         else: | ||||||
|             log = False |             log = False | ||||||
| @@ -239,8 +242,8 @@ class Graph_DiT(pl.LightningModule): | |||||||
|          |          | ||||||
|                    self.val_X_logp.compute(), self.val_E_logp.compute()] |                    self.val_X_logp.compute(), self.val_E_logp.compute()] | ||||||
|          |          | ||||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: |         # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||||
|             print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ", |         print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ", | ||||||
|                 f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll)) |                 f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll)) | ||||||
|         with open("validation-metrics.csv", "a") as f: |         with open("validation-metrics.csv", "a") as f: | ||||||
|             # save the metrics as csv file |             # save the metrics as csv file | ||||||
| @@ -286,7 +289,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|                 samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, |                 samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, | ||||||
|                                                 save_final=to_save, |                                                 save_final=to_save, | ||||||
|                                                 keep_chain=chains_save, |                                                 keep_chain=chains_save, | ||||||
|                                                 number_chain_steps=self.number_chain_steps)) |                                                 number_chain_steps=self.number_chain_steps)[0]) | ||||||
|                 ident += to_generate |                 ident += to_generate | ||||||
|                 start_index += to_generate |                 start_index += to_generate | ||||||
|  |  | ||||||
| @@ -360,7 +363,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|             batch_y = torch.ones(to_generate, self.ydim_output, device=self.device) |             batch_y = torch.ones(to_generate, self.ydim_output, device=self.device) | ||||||
|  |  | ||||||
|             cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, |             cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||||
|                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps) |                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps)[0] | ||||||
|             samples = samples + cur_sample |             samples = samples + cur_sample | ||||||
|              |              | ||||||
|             all_ys.append(batch_y) |             all_ys.append(batch_y) | ||||||
| @@ -601,6 +604,9 @@ class Graph_DiT(pl.LightningModule): | |||||||
|  |  | ||||||
|         assert (E == torch.transpose(E, 1, 2)).all() |         assert (E == torch.transpose(E, 1, 2)).all() | ||||||
|  |  | ||||||
|  |         total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device) | ||||||
|  |         # total_log_probs = torch.zeros([self.cfg.general.samples_to_generate,10], device=self.device) | ||||||
|  |  | ||||||
|         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. |         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. | ||||||
|         for s_int in reversed(range(0, self.T)): |         for s_int in reversed(range(0, self.T)): | ||||||
|             s_array = s_int * torch.ones((batch_size, 1)).type_as(y) |             s_array = s_int * torch.ones((batch_size, 1)).type_as(y) | ||||||
| @@ -609,21 +615,24 @@ class Graph_DiT(pl.LightningModule): | |||||||
|             t_norm = t_array / self.T |             t_norm = t_array / self.T | ||||||
|  |  | ||||||
|             # Sample z_s |             # Sample z_s | ||||||
|             sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) |             sampled_s, discrete_sampled_s, log_probs = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) | ||||||
|             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y |             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||||
|  |             print(f'sampled_s.X shape: {sampled_s.X.shape}, sampled_s.E shape: {sampled_s.E.shape}') | ||||||
|  |             print(f'log_probs shape: {log_probs.shape}') | ||||||
|  |             total_log_probs += log_probs | ||||||
|  |  | ||||||
|         # Sample |         # Sample | ||||||
|         sampled_s = sampled_s.mask(node_mask, collapse=True) |         sampled_s = sampled_s.mask(node_mask, collapse=True) | ||||||
|         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y |         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||||
|          |          | ||||||
|         molecule_list = [] |         graph_list = [] | ||||||
|         for i in range(batch_size): |         for i in range(batch_size): | ||||||
|             n = n_nodes[i] |             n = n_nodes[i] | ||||||
|             atom_types = X[i, :n].cpu() |             node_types = X[i, :n].cpu() | ||||||
|             edge_types = E[i, :n, :n].cpu() |             edge_types = E[i, :n, :n].cpu() | ||||||
|             molecule_list.append([atom_types, edge_types]) |             graph_list.append([node_types, edge_types]) | ||||||
|          |          | ||||||
|         return molecule_list |         return graph_list, total_log_probs | ||||||
|  |  | ||||||
|     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask): |     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask): | ||||||
|         """Samples from zs ~ p(zs | zt). Only used during sampling. |         """Samples from zs ~ p(zs | zt). Only used during sampling. | ||||||
| @@ -635,6 +644,7 @@ class Graph_DiT(pl.LightningModule): | |||||||
|  |  | ||||||
|         # Neural net predictions |         # Neural net predictions | ||||||
|         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask} |         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask} | ||||||
|  |         print(f"sample p zs given zt X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}, node_mask shape: {node_mask.shape}") | ||||||
|          |          | ||||||
|         def get_prob(noisy_data, unconditioned=False): |         def get_prob(noisy_data, unconditioned=False): | ||||||
|             pred = self.forward(noisy_data, unconditioned=unconditioned) |             pred = self.forward(noisy_data, unconditioned=unconditioned) | ||||||
| @@ -674,7 +684,19 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         # with condition = P_t(G_{t-1} |G_t, C) |         # with condition = P_t(G_{t-1} |G_t, C) | ||||||
|         # with condition = P_t(A_{t-1} |A_t, y) |         # with condition = P_t(A_{t-1} |A_t, y) | ||||||
|         prob_X, prob_E, pred = get_prob(noisy_data) |         prob_X, prob_E, pred = get_prob(noisy_data) | ||||||
|  |         print(f'prob_X shape: {prob_X.shape}, prob_E shape: {prob_E.shape}') | ||||||
|  |         print(f'X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}') | ||||||
|  |         print(f'X_t: {X_t}') | ||||||
|  |         log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1))  # bs, n | ||||||
|  |         log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1))  # bs, n, n | ||||||
|  |  | ||||||
|  |         # Sum the log_prob across dimensions for total log_prob | ||||||
|  |         log_prob_X = log_prob_X.sum(dim=-1) | ||||||
|  |         log_prob_E = log_prob_E.sum(dim=(1, 2)) | ||||||
|  |         print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}') | ||||||
|  |         # log_probs = log_prob_E + log_prob_X | ||||||
|  |         log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1)  # (batch_size, 2) | ||||||
|  |         print(f'log_probs shape: {log_probs.shape}') | ||||||
|         ### Guidance |         ### Guidance | ||||||
|         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: |         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: | ||||||
|             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True) |             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True) | ||||||
| @@ -810,4 +832,4 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) |         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||||
|         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) |         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||||
|  |  | ||||||
|         return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t) |         return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs | ||||||
|   | |||||||
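
The sampling changes above thread a per-step log-probability through the reverse loop: at each step the log-probability of the sampled node and edge categories is gathered from the predicted distributions, summed per graph, and accumulated into `total_log_probs`. A minimal sketch of that gather-and-sum pattern on made-up shapes (the repository code works on one-hot tensors, so its indexing details differ):

```python
import torch

bs, n, k = 4, 8, 7                                          # hypothetical batch / nodes / classes
probs = torch.softmax(torch.randn(bs, n, k), dim=-1)        # predicted categorical distributions
idx = torch.multinomial(probs.view(-1, k), 1).view(bs, n)   # sampled classes

# log-probability of each sampled entry, then one scalar per graph
log_p = torch.log(probs.gather(-1, idx.unsqueeze(-1)).squeeze(-1))  # (bs, n)
log_p_per_graph = log_p.sum(dim=-1)                                 # (bs,)

# accumulated across all T reverse steps:
# total_log_probs += log_p_per_graph   (one term per step)
```
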
							
								
								
									
										
graph_dit/exp_201/barplog.png (BIN, new file, 30 KiB): binary file not shown.

| @@ -2,44 +2,45 @@ | |||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
| import pandas as pd | import pandas as pd | ||||||
| from nas_201_api import NASBench201API as API | from nas_201_api import NASBench201API as API | ||||||
| from naswot.score_networks import get_nasbench201_idx_score | # from naswot.score_networks import get_nasbench201_idx_score | ||||||
| from naswot import datasets as dt | # from naswot import datasets as dt | ||||||
| from naswot import nasspace | # from naswot import nasspace | ||||||
|  |  | ||||||
| class Args(): | # class Args(): | ||||||
|     pass | #     pass | ||||||
| args = Args() | # args = Args() | ||||||
| args.trainval = True | # args.trainval = True | ||||||
| args.augtype = 'none' | # args.augtype = 'none' | ||||||
| args.repeat = 1 | # args.repeat = 1 | ||||||
| args.score = 'hook_logdet' | # args.score = 'hook_logdet' | ||||||
| args.sigma = 0.05 | # args.sigma = 0.05 | ||||||
| args.nasspace = 'nasbench201' | # args.nasspace = 'nasbench201' | ||||||
| args.batch_size = 128 | # args.batch_size = 128 | ||||||
| args.GPU = '0' | # args.GPU = '0' | ||||||
| args.dataset = 'cifar10' | # args.dataset = 'cifar10' | ||||||
| args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | # args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||||
| args.data_loc = '../cifardata/' | # args.data_loc = '../cifardata/' | ||||||
| args.seed = 777 | # args.seed = 777 | ||||||
| args.init = '' | # args.init = '' | ||||||
| args.save_loc = 'results' | # args.save_loc = 'results' | ||||||
| args.save_string = 'naswot' | # args.save_string = 'naswot' | ||||||
| args.dropout = False | # args.dropout = False | ||||||
| args.maxofn = 1 | # args.maxofn = 1 | ||||||
| args.n_samples = 100 | # args.n_samples = 100 | ||||||
| args.n_runs = 500 | # args.n_runs = 500 | ||||||
| args.stem_out_channels = 16 | # args.stem_out_channels = 16 | ||||||
| args.num_stacks = 3 | # args.num_stacks = 3 | ||||||
| args.num_modules_per_stack = 3 | # args.num_modules_per_stack = 3 | ||||||
| args.num_labels = 1 | # args.num_labels = 1 | ||||||
| searchspace = nasspace.get_search_space(args) | # searchspace = nasspace.get_search_space(args) | ||||||
| train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | # train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||||
| device = torch.device('cuda:2') | # device = torch.device('cuda:2') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||||
|  | api = API(source) | ||||||
|  |  | ||||||
|  |  | ||||||
|           |  | ||||||
| # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' |  | ||||||
| # api = API(source) |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -50,8 +51,10 @@ percentages = [] | |||||||
| len_201 = 15625 | len_201 = 15625 | ||||||
|  |  | ||||||
| for i in range(len_201): | for i in range(len_201): | ||||||
|     percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) |     # percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) | ||||||
|     percentages.append(percentage) |     results = api.query_by_index(i, 'cifar10') | ||||||
|  |     result = results[111].get_eval('ori-test') | ||||||
|  |     percentages.append(result) | ||||||
|  |  | ||||||
# Define the 10% bins | # Define the 10% bins | ||||||
| bins = [i for i in range(0, 101, 10)] | bins = [i for i in range(0, 101, 10)] | ||||||
|   | |||||||
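
The rewritten loop queries the benchmark's reported `ori-test` accuracy instead of computing NASWOT scores, and the script then bins the values into the 10-point intervals defined below it. A minimal sketch of that binning step, with random stand-in data:

```python
import numpy as np
import matplotlib.pyplot as plt

percentages = np.random.uniform(85, 95, size=15625)   # stand-in for the queried accuracies
bins = list(range(0, 101, 10))                        # the 10-point bins from the script
counts, _ = np.histogram(percentages, bins=bins)

plt.bar([f'{lo}-{lo + 10}' for lo in bins[:-1]], counts)
plt.xlabel('ori-test accuracy (%)')
plt.ylabel('architectures')
plt.savefig('barplot.png')                            # hypothetical output name
```
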
| @@ -1,4 +1,5 @@ | |||||||
| # These imports are tricky because they use c++, do not move them | # These imports are tricky because they use c++, do not move them | ||||||
|  | from tqdm import tqdm | ||||||
| import os, shutil | import os, shutil | ||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
| @@ -144,10 +145,32 @@ def main(cfg: DictConfig): | |||||||
|     else: |     else: | ||||||
|         trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) |         trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) | ||||||
|  |  | ||||||
|  | from accelerate import Accelerator | ||||||
|  | from accelerate.utils import set_seed, ProjectConfiguration | ||||||
|  |  | ||||||
| @hydra.main( | @hydra.main( | ||||||
|     version_base="1.1", config_path="../configs", config_name="config" |     version_base="1.1", config_path="../configs", config_name="config" | ||||||
| ) | ) | ||||||
| def test(cfg: DictConfig): | def test(cfg: DictConfig): | ||||||
|  |     os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number | ||||||
|  |     accelerator_config = ProjectConfiguration( | ||||||
|  |         project_dir=os.path.join(cfg.general.log_dir, cfg.general.name), | ||||||
|  |         automatic_checkpoint_naming=True, | ||||||
|  |         total_limit=cfg.general.number_checkpoint_limit, | ||||||
|  |     ) | ||||||
|  |     accelerator = Accelerator( | ||||||
|  |         mixed_precision='no', | ||||||
|  |         project_config=accelerator_config, | ||||||
|  |         # gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs, | ||||||
|  |         gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,  | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Debug: confirm available devices | ||||||
|  |     print(f"Available GPUs: {torch.cuda.device_count()}") | ||||||
|  |     print(f"Using device: {accelerator.device}") | ||||||
|  |  | ||||||
|  |     set_seed(cfg.train.seed, device_specific=True) | ||||||
|  |  | ||||||
|     datamodule = dataset.DataModule(cfg) |     datamodule = dataset.DataModule(cfg) | ||||||
|     datamodule.prepare_data() |     datamodule.prepare_data() | ||||||
|     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) |     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) | ||||||
| @@ -169,40 +192,216 @@ def test(cfg: DictConfig): | |||||||
|         "visualization_tools": visulization_tools, |         "visualization_tools": visulization_tools, | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     # Debug: confirm available devices | ||||||
|  |     print(f"Available GPUs: {torch.cuda.device_count()}") | ||||||
|  |     print(f"Using device: {accelerator.device}") | ||||||
|  |  | ||||||
|     if cfg.general.test_only: |     if cfg.general.test_only: | ||||||
|         cfg, _ = get_resume(cfg, model_kwargs) |         cfg, _ = get_resume(cfg, model_kwargs) | ||||||
|         os.chdir(cfg.general.test_only.split("checkpoints")[0]) |         os.chdir(cfg.general.test_only.split("checkpoints")[0]) | ||||||
|     elif cfg.general.resume is not None: |     elif cfg.general.resume is not None: | ||||||
|         cfg, _ = get_resume_adaptive(cfg, model_kwargs) |         cfg, _ = get_resume_adaptive(cfg, model_kwargs) | ||||||
|         os.chdir(cfg.general.resume.split("checkpoints")[0]) |         os.chdir(cfg.general.resume.split("checkpoints")[0]) | ||||||
|     # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number |  | ||||||
|     model = Graph_DiT(cfg=cfg, **model_kwargs) |     model = Graph_DiT(cfg=cfg, **model_kwargs) | ||||||
|     trainer = Trainer( |     graph_dit_model = model | ||||||
|         gradient_clip_val=cfg.train.clip_grad, |  | ||||||
|         # accelerator="cpu", |  | ||||||
|         accelerator="gpu" |  | ||||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 |  | ||||||
|         else "cpu", |  | ||||||
|         devices=[cfg.general.gpu_number] |  | ||||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 |  | ||||||
|         else None, |  | ||||||
|         max_epochs=cfg.train.n_epochs, |  | ||||||
|         enable_checkpointing=False, |  | ||||||
|         check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, |  | ||||||
|         val_check_interval=cfg.train.val_check_interval, |  | ||||||
|         strategy="ddp" if cfg.general.gpus > 1 else "auto", |  | ||||||
|         enable_progress_bar=cfg.general.enable_progress_bar, |  | ||||||
|         callbacks=[], |  | ||||||
|         reload_dataloaders_every_n_epochs=0, |  | ||||||
|         logger=[], |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     if not cfg.general.test_only: |     inference_dtype = torch.float32 | ||||||
|         print("start testing fit method") |     graph_dit_model.to(accelerator.device, dtype=inference_dtype) | ||||||
|         trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) |  | ||||||
|         if cfg.general.save_model: |  | ||||||
|             trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt") |     # optional: freeze the model | ||||||
|         trainer.test(model, datamodule=datamodule) |     # graph_dit_model.model.requires_grad_(True) | ||||||
|  |  | ||||||
|  |     import torch.nn.functional as F | ||||||
|  |     optimizer = graph_dit_model.configure_optimizers() | ||||||
|  |     train_dataloader = accelerator.prepare(datamodule.train_dataloader()) | ||||||
|  |     optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model) | ||||||
|  |     # start training | ||||||
|  |     for epoch in range(cfg.train.n_epochs): | ||||||
|  |         graph_dit_model.train()  # set the model to training mode | ||||||
|  |         print(f"Epoch {epoch}", end="\n") | ||||||
|  |         graph_dit_model.on_train_epoch_start() | ||||||
|  |         for data in train_dataloader:  # fetch one batch from the dataloader | ||||||
|  |             # data.to(accelerator.device) | ||||||
|  |             # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index] | ||||||
|  |             # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float() | ||||||
|  |             # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes) | ||||||
|  |             # dense_data = dense_data.mask(node_mask) | ||||||
|  |             # X, E = dense_data.X, dense_data.E | ||||||
|  |             # noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask) | ||||||
|  |             # pred = graph_dit_model.forward(noisy_data) | ||||||
|  |             # loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y, | ||||||
|  |             #                     true_X=X, true_E=E, true_y=data.y, node_mask=node_mask, | ||||||
|  |             #                     log=epoch % graph_dit_model.log_every_steps == 0) | ||||||
|  |             # # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}') | ||||||
|  |             # graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E, | ||||||
|  |             #                 log=epoch % graph_dit_model.log_every_steps == 0) | ||||||
|  |             # graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True) | ||||||
|  |             # print(f"training loss: {loss}") | ||||||
|  |             # with open("training-loss.csv", "a") as f: | ||||||
|  |             #     f.write(f"{loss}, {epoch}\n") | ||||||
|  |             loss = graph_dit_model.training_step(data, epoch) | ||||||
|  |             loss = loss['loss'] | ||||||
|  |  | ||||||
|  |             accelerator.backward(loss) | ||||||
|  |             optimizer.step() | ||||||
|  |             optimizer.zero_grad() | ||||||
|  |             # return {'loss': loss} | ||||||
|  |         graph_dit_model.on_train_epoch_end() | ||||||
|  |         if epoch % cfg.train.check_val_every_n_epoch == 0: | ||||||
|  |             print(f'print validation loss') | ||||||
|  |             graph_dit_model.eval() | ||||||
|  |             graph_dit_model.on_validation_epoch_start() | ||||||
|  |             graph_dit_model.validation_step(data, epoch) | ||||||
|  |             graph_dit_model.on_validation_epoch_end() | ||||||
|  |      | ||||||
|  |     # start testing | ||||||
|  |     print("start testing") | ||||||
|  |     graph_dit_model.eval() | ||||||
|  |     test_dataloader = accelerator.prepare(datamodule.test_dataloader()) | ||||||
|  |     graph_dit_model.on_test_epoch_start() | ||||||
|  |     for data in test_dataloader: | ||||||
|  |         nll = graph_dit_model.test_step(data, epoch) | ||||||
|  |         # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index] | ||||||
|  |         # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float() | ||||||
|  |  | ||||||
|  |         # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes) | ||||||
|  |         # dense_data = dense_data.mask(node_mask) | ||||||
|  |         # noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) | ||||||
|  |         # pred = graph_dit_model.forward(noisy_data) | ||||||
|  |         # nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True) | ||||||
|  |         # graph_dit_model.test_y_collection.append(data.y) | ||||||
|  |         print(f'test loss: {nll}') | ||||||
|  |      | ||||||
|  |     graph_dit_model.on_test_epoch_end() | ||||||
|  |  | ||||||
|  |     # start sampling | ||||||
|  |  | ||||||
|  |     # samples_left_to_generate = cfg.general.final_model_samples_to_generate | ||||||
|  |     # samples_left_to_save = cfg.general.final_model_samples_to_save | ||||||
|  |     # chains_left_to_save = cfg.general.final_model_chains_to_save | ||||||
|  |  | ||||||
|  |     # samples, all_ys, batch_id = [], [], 0 | ||||||
|  |     # samples_with_log_probs = [] | ||||||
|  |     # test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0) | ||||||
|  |     # num_examples = test_y_collection.size(0) | ||||||
|  |     # if cfg.general.final_model_samples_to_generate > num_examples: | ||||||
|  |     #     ratio = cfg.general.final_model_samples_to_generate // num_examples | ||||||
|  |     #     test_y_collection = test_y_collection.repeat(ratio+1, 1) | ||||||
|  |     #     num_examples = test_y_collection.size(0) | ||||||
|  |      | ||||||
|  |     # Normal reward function | ||||||
|  |     # from nas_201_api import NASBench201API as API | ||||||
|  |     # api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') | ||||||
|  |     # def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'): | ||||||
|  |     #     rewards = [] | ||||||
|  |     #     if reward_model == 'swap': | ||||||
|  |     #         import csv | ||||||
|  |     #         with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: | ||||||
|  |     #             reader = csv.reader(f) | ||||||
|  |     #             header = next(reader) | ||||||
|  |     #             data = [row for row in reader] | ||||||
|  |     #             swap_scores = [float(row[0]) for row in data] | ||||||
|  |     #             for graph in graphs: | ||||||
|  |     #                 node_tensor = graph[0] | ||||||
|  |     #                 node = node_tensor.cpu().numpy().tolist() | ||||||
|  |  | ||||||
|  |     #                 def nodes_to_arch_str(nodes): | ||||||
|  |     #                     num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||||
|  |     #                     nodes_str = [num_to_op[node] for node in nodes] | ||||||
|  |     #                     arch_str = '|' + nodes_str[1] + '~0|+' + \ | ||||||
|  |     #                             '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\ | ||||||
|  |     #                             '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'  | ||||||
|  |     #                     return arch_str | ||||||
|  |                      | ||||||
|  |     #                 arch_str = nodes_to_arch_str(node) | ||||||
|  |     #                 reward = swap_scores[api.query_index_by_arch(arch_str)] | ||||||
|  |     #                 rewards.append(reward) | ||||||
|  |                  | ||||||
|  |     #     # for graph in graphs: | ||||||
|  |     #     #     reward = 1.0 | ||||||
|  |     #     #     rewards.append(reward) | ||||||
|  |     #     return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device) | ||||||
|  |     # old_log_probs = None | ||||||
|  |     # while samples_left_to_generate > 0: | ||||||
|  |     #     print(f'samples left to generate: {samples_left_to_generate}/' | ||||||
|  |     #         f'{cfg.general.final_model_samples_to_generate}', end='', flush=True) | ||||||
|  |     #     bs = 1 * cfg.train.batch_size | ||||||
|  |     #     to_generate = min(samples_left_to_generate, bs) | ||||||
|  |     #     to_save = min(samples_left_to_save, bs) | ||||||
|  |     #     chains_save = min(chains_left_to_save, bs) | ||||||
|  |     #     # batch_y = test_y_collection[batch_id : batch_id + to_generate] | ||||||
|  |     #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device) | ||||||
|  |  | ||||||
|  |     #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||||
|  |     #                                     keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps) | ||||||
|  |     #     log_probs = torch.sum(log_probs, dim=-1).unsqueeze(1) | ||||||
|  |     #     samples = samples + cur_sample | ||||||
|  |     #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device) | ||||||
|  |     #     advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6) | ||||||
|  |     #     print(f'reward: {reward.shape}, advantages: {advantages.shape}, log_probs: {log_probs.shape}, cur_sample: {len(cur_sample)}') | ||||||
|  |     #     if old_log_probs is None: | ||||||
|  |     #         old_log_probs = log_probs.clone() | ||||||
|  |     #     ratio = torch.exp(log_probs - old_log_probs) | ||||||
|  |     #     unclipped_loss = -advantages * ratio | ||||||
|  |     #     clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param) | ||||||
|  |     #     loss = torch.mean(torch.max(unclipped_loss, clipped_loss)) | ||||||
|  |     #     accelerator.backward(loss) | ||||||
|  |     #     optimizer.step() | ||||||
|  |     #     optimizer.zero_grad() | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     #     samples_with_log_probs.append((cur_sample, log_probs, reward)) | ||||||
|  |          | ||||||
|  |     #     all_ys.append(batch_y) | ||||||
|  |     #     batch_id += to_generate | ||||||
|  |  | ||||||
|  |     #     samples_left_to_save -= to_save | ||||||
|  |     #     samples_left_to_generate -= to_generate | ||||||
|  |     #     chains_left_to_save -= chains_save | ||||||
|  |          | ||||||
|  |     # print(f"final Computing sampling metrics...") | ||||||
|  |     # graph_dit_model.sampling_metrics.reset() | ||||||
|  |     # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True) | ||||||
|  |     # graph_dit_model.sampling_metrics.reset() | ||||||
|  |     # print(f"Done.") | ||||||
|  |  | ||||||
|  |     # # save samples | ||||||
|  |     # print("Samples:") | ||||||
|  |     # print(samples) | ||||||
|  |  | ||||||
|  |     # ======================== | ||||||
|  |      | ||||||
|  |  | ||||||
|  |      | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     # trainer = Trainer( | ||||||
|  |     #     gradient_clip_val=cfg.train.clip_grad, | ||||||
|  |     #     # accelerator="cpu", | ||||||
|  |     #     accelerator="gpu" | ||||||
|  |     #     if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||||
|  |     #     else "cpu", | ||||||
|  |     #     devices=[cfg.general.gpu_number] | ||||||
|  |     #     if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||||
|  |     #     else None, | ||||||
|  |     #     max_epochs=cfg.train.n_epochs, | ||||||
|  |     #     enable_checkpointing=False, | ||||||
|  |     #     check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, | ||||||
|  |     #     val_check_interval=cfg.train.val_check_interval, | ||||||
|  |     #     strategy="ddp" if cfg.general.gpus > 1 else "auto", | ||||||
|  |     #     enable_progress_bar=cfg.general.enable_progress_bar, | ||||||
|  |     #     callbacks=[], | ||||||
|  |     #     reload_dataloaders_every_n_epochs=0, | ||||||
|  |     #     logger=[], | ||||||
|  |     # ) | ||||||
|  |  | ||||||
|  |     # if not cfg.general.test_only: | ||||||
|  |     #     print("start testing fit method") | ||||||
|  |     #     trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) | ||||||
|  |     #     if cfg.general.save_model: | ||||||
|  |     #         trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt") | ||||||
|  |     #     trainer.test(model, datamodule=datamodule) | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     test() |     test() | ||||||
|   | |||||||
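
The large commented-out block above sketches a PPO-style fine-tuning loop over sampled graphs: per-sample log-probabilities come back from `sample_batch`, rewards come from a SWAP-score table, and the update uses the clipped surrogate objective. That objective, isolated as a minimal self-contained sketch (`clip_param` stands in for `cfg.ppo.clip_param`; the default value here is an assumption):

```python
import torch

def ppo_clipped_loss(log_probs, old_log_probs, rewards, clip_param=0.2):
    # standardize rewards into advantages, as the commented loop does
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    return torch.max(unclipped, clipped).mean()

loss = ppo_clipped_loss(torch.randn(8), torch.randn(8), torch.rand(8))
print(loss)
```
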
| @@ -76,6 +76,8 @@ class CategoricalEmbedder(nn.Module): | |||||||
|             embeddings = embeddings + noise |             embeddings = embeddings + noise | ||||||
|         return embeddings |         return embeddings | ||||||
|      |      | ||||||
|  | # cluster similar conditions together | ||||||
|  | # size  | ||||||
| class ClusterContinuousEmbedder(nn.Module): | class ClusterContinuousEmbedder(nn.Module): | ||||||
|     def __init__(self, input_size, hidden_size, dropout_prob): |     def __init__(self, input_size, hidden_size, dropout_prob): | ||||||
|         super().__init__() |         super().__init__() | ||||||
| @@ -108,6 +110,8 @@ class ClusterContinuousEmbedder(nn.Module): | |||||||
|          |          | ||||||
|         if drop_ids is not None: |         if drop_ids is not None: | ||||||
|             embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device) |             embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device) | ||||||
|  |             # print(labels[~drop_ids].shape) | ||||||
|  |             # torch.Size([1200]) | ||||||
|             embeddings[~drop_ids] = self.mlp(labels[~drop_ids]) |             embeddings[~drop_ids] = self.mlp(labels[~drop_ids]) | ||||||
|             embeddings[drop_ids] += self.embedding_drop.weight[0] |             embeddings[drop_ids] += self.embedding_drop.weight[0] | ||||||
|         else: |         else: | ||||||
|   | |||||||
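
For context on the embedder above: when a condition is dropped (during training for classifier-free guidance, or because its label is NaN), `ClusterContinuousEmbedder` substitutes a learned null embedding for the MLP output. A standalone sketch of that masking pattern, with hypothetical sizes:

```python
import torch
import torch.nn as nn

bs, hidden_size = 4, 16                         # hypothetical sizes
mlp = nn.Linear(1, hidden_size)
embedding_drop = nn.Embedding(1, hidden_size)   # the learned null embedding

labels = torch.randn(bs, 1)
drop_ids = torch.tensor([False, True, False, True])

embeddings = torch.zeros(bs, hidden_size)
embeddings[~drop_ids] = mlp(labels[~drop_ids])    # kept conditions go through the MLP
embeddings[drop_ids] += embedding_drop.weight[0]  # dropped ones get the null embedding
```
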
| @@ -17,20 +17,22 @@ class Denoiser(nn.Module): | |||||||
|         num_heads=16, |         num_heads=16, | ||||||
|         mlp_ratio=4.0, |         mlp_ratio=4.0, | ||||||
|         drop_condition=0.1, |         drop_condition=0.1, | ||||||
|         Xdim=118, |         Xdim=7, | ||||||
|         Edim=5, |         Edim=2, | ||||||
|         ydim=3, |         ydim=1, | ||||||
|         task_type='regression', |         task_type='regression', | ||||||
|     ): |     ): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |         print(f"Denoiser, xdim: {Xdim}, edim: {Edim}, ydim: {ydim}, hidden_size: {hidden_size}, depth: {depth}, num_heads: {num_heads}, mlp_ratio: {mlp_ratio}, drop_condition: {drop_condition}") | ||||||
|         self.num_heads = num_heads |         self.num_heads = num_heads | ||||||
|         self.ydim = ydim |         self.ydim = ydim | ||||||
|         self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False) |         self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False) | ||||||
|  |  | ||||||
|         self.t_embedder = TimestepEmbedder(hidden_size) |         self.t_embedder = TimestepEmbedder(hidden_size) | ||||||
|  |         #  | ||||||
|         self.y_embedding_list = torch.nn.ModuleList() |         self.y_embedding_list = torch.nn.ModuleList() | ||||||
|  |  | ||||||
|         self.y_embedding_list.append(ClusterContinuousEmbedder(2, hidden_size, drop_condition)) |         self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition)) | ||||||
|         for i in range(ydim - 2): |         for i in range(ydim - 2): | ||||||
|             if task_type == 'regression': |             if task_type == 'regression': | ||||||
|                 self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition)) |                 self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition)) | ||||||
| @@ -88,6 +90,8 @@ class Denoiser(nn.Module): | |||||||
|          |          | ||||||
|         # print("Denoiser Forward") |         # print("Denoiser Forward") | ||||||
|         # print(x.shape, e.shape, y.shape, t.shape, unconditioned) |         # print(x.shape, e.shape, y.shape, t.shape, unconditioned) | ||||||
|  |         # torch.Size([1200, 8, 7]) torch.Size([1200, 8, 8, 2]) torch.Size([1200, 2]) torch.Size([1200, 1]) False | ||||||
|  |         # print(y) | ||||||
|         force_drop_id = torch.zeros_like(y.sum(-1)) |         force_drop_id = torch.zeros_like(y.sum(-1)) | ||||||
|         # drop the nan values |         # drop the nan values | ||||||
|         force_drop_id[torch.isnan(y.sum(-1))] = 1 |         force_drop_id[torch.isnan(y.sum(-1))] = 1 | ||||||
| @@ -109,11 +113,12 @@ class Denoiser(nn.Module): | |||||||
|         c1 = self.t_embedder(t) |         c1 = self.t_embedder(t) | ||||||
|         # print("C1 after t_embedder") |         # print("C1 after t_embedder") | ||||||
|         # print(c1.shape) |         # print(c1.shape) | ||||||
|         for i in range(1, self.ydim): |         c2 = self.y_embedding_list[0](y[:,0].unsqueeze(-1), self.training, force_drop_id, t) | ||||||
|             if i == 1: |         # for i in range(1, self.ydim): | ||||||
|                 c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t) |         #     if i == 1: | ||||||
|             else: |         #         c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t) | ||||||
|                 c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t) |         #     else: | ||||||
|  |                 # c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t) | ||||||
|         # print("C2 after y_embedding_list") |         # print("C2 after y_embedding_list") | ||||||
|         # print(c2.shape) |         # print(c2.shape) | ||||||
|         # print("C1 + C2") |         # print("C1 + C2") | ||||||
|   | |||||||
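
With a single regression target, the rewrite above embeds only `y[:, 0]`, so the conditioning vector reduces to a timestep embedding plus one property embedding. Schematically, under assumed shapes (the linear layers stand in for `TimestepEmbedder` and `ClusterContinuousEmbedder`):

```python
import torch
import torch.nn as nn

bs, hidden_size = 4, 16
t_embedder = nn.Linear(1, hidden_size)   # stands in for TimestepEmbedder
y_embedder = nn.Linear(1, hidden_size)   # stands in for ClusterContinuousEmbedder(1, ...)

t = torch.rand(bs, 1)
y = torch.rand(bs, 2)                    # only the first column is used now

c = t_embedder(t) + y_embedder(y[:, 0].unsqueeze(-1))  # conditioning vector for the blocks
print(c.shape)                           # torch.Size([4, 16])
```
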
							
								
								
									
graph_dit/swap_results_aircraft.csv (15626 lines, new file): file diff suppressed because it is too large and one or more lines are too long.

graph_dit/test_perf.py (144 lines, new file)

							| @@ -0,0 +1,144 @@ | |||||||
|  | from nas_201_api import NASBench201API as API | ||||||
|  | import re | ||||||
|  | import pandas as pd | ||||||
|  | import json | ||||||
|  | import numpy as np | ||||||
|  | import argparse | ||||||
|  |  | ||||||
|  | api = API('./NAS-Bench-201-v1_1-096897.pth') | ||||||
|  |  | ||||||
|  | parser = argparse.ArgumentParser(description='Process some integers.') | ||||||
|  |  | ||||||
|  | parser.add_argument('--file_path', type=str, default='211035.txt',) | ||||||
|  | parser.add_argument('--datasets', type=str, default='cifar10',) | ||||||
|  | args = parser.parse_args() | ||||||
|  |  | ||||||
|  | def process_graph_data(text): | ||||||
|  |     # Split the input text into sections for each graph | ||||||
|  |     graph_sections = text.strip().split('nodes:') | ||||||
|  |      | ||||||
|  |     # Prepare lists to store data | ||||||
|  |     nodes_list = [] | ||||||
|  |     edges_list = [] | ||||||
|  |     results_list = [] | ||||||
|  |      | ||||||
|  |     for section in graph_sections[1:]: | ||||||
|  |         # Extract nodes | ||||||
|  |         nodes_section = section.split('edges:')[0] | ||||||
|  |         nodes_match = re.search(r'(tensor\(\d+\) ?)+', section) | ||||||
|  |         if nodes_match: | ||||||
|  |             nodes = re.findall(r'tensor\((\d+)\)', nodes_match.group(0)) | ||||||
|  |             nodes_list.append(nodes) | ||||||
|  |          | ||||||
|  |         # Extract edges | ||||||
|  |         edge_section = section.split('edges:')[1] | ||||||
|  |         edges_match = re.search(r'edges:', section) | ||||||
|  |         if edges_match: | ||||||
|  |             edges = re.findall(r'tensor\((\d+)\)', edge_section) | ||||||
|  |             edges_list.append(edges) | ||||||
|  |          | ||||||
|  |         # Extract the last floating point number as a result | ||||||
|  |      | ||||||
|  |     # Create a DataFrame to store the extracted data | ||||||
|  |     data = { | ||||||
|  |         'nodes': nodes_list, | ||||||
|  |         'edges': edges_list, | ||||||
|  |     } | ||||||
|  |     data['nodes'] = [[int(x) for x in node] for node in data['nodes']] | ||||||
|  |     data['edges'] = [[int(x) for x in edge] for edge in data['edges']] | ||||||
|  |     def split_list(input_list, chunk_size): | ||||||
|  |         return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)] | ||||||
|  |     data['edges'] = [split_list(edge, 8) for edge in data['edges']] | ||||||
|  |  | ||||||
|  |     print(data) | ||||||
|  |     df = pd.DataFrame(data) | ||||||
|  |     print('df') | ||||||
|  |     print(df['nodes'][0], df['edges'][0]) | ||||||
|  |     return df | ||||||
|  |  | ||||||
|  | def is_valid_nasbench201(adj, ops): | ||||||
|  |     print(ops) | ||||||
|  |     if ops[0] != 0 or ops[-1] != 6: | ||||||
|  |         return False | ||||||
|  |     for i in range(2, len(ops) - 1): | ||||||
|  |         if ops[i] not in [1, 2, 3, 4, 5]: | ||||||
|  |             return False | ||||||
|  |     adj_mat = [ [0, 1, 1, 0, 1, 0, 0, 0], | ||||||
|  |                 [0, 0, 0, 1, 0, 1 ,0 ,0], | ||||||
|  |                 [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|  |                 [0, 0, 0, 0, 0, 0, 1, 0], | ||||||
|  |                 [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                 [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                 [0, 0, 0, 0, 0, 0, 0, 1], | ||||||
|  |                 [0, 0, 0, 0, 0, 0, 0, 0]] | ||||||
|  |   | ||||||
|  |     for i in range(len(adj)): | ||||||
|  |         for j in range(len(adj[i])): | ||||||
|  |             if adj[i][j] not in [0, 1]: | ||||||
|  |                 return False | ||||||
|  |             if j > i: | ||||||
|  |                 if adj[i][j] != adj_mat[i][j]: | ||||||
|  |                     return False | ||||||
|  |     return True | ||||||
|  |  | ||||||
|  | num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||||
|  | def nodes_to_arch_str(nodes): | ||||||
|  |     nodes_str = [num_to_op[node] for node in nodes] | ||||||
|  |     arch_str = '|' + nodes_str[1] + '~0|+' + \ | ||||||
|  |                '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\ | ||||||
|  |                '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'  | ||||||
|  |     return arch_str | ||||||
|  |  | ||||||
|  | filename = args.file_path | ||||||
|  | datasets_name = args.datasets | ||||||
|  |  | ||||||
|  | with open('./output_graphs/' + filename, 'r') as f: | ||||||
|  |     texts = f.read() | ||||||
|  |     df = process_graph_data(texts) | ||||||
|  |     valid = 0 | ||||||
|  |     not_valid = 0 | ||||||
|  |     scores = [] | ||||||
|  |  | ||||||
|  |     # map the classification thresholds to the distribution dict | ||||||
|  |     thresholds = { | ||||||
|  |         'cifar10': [90, 91, 92, 93, 94], | ||||||
|  |         'cifar100': [68,69,70, 71, 72, 73] | ||||||
|  |     } | ||||||
|  |     dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]} | ||||||
|  |     dist[f'>{thresholds[datasets_name][-1]}'] = 0 | ||||||
|  |  | ||||||
|  |     for i in range(len(df)): | ||||||
|  |         nodes = df['nodes'][i] | ||||||
|  |         edges = df['edges'][i] | ||||||
|  |         result = is_valid_nasbench201(edges, nodes) | ||||||
|  |         if result: | ||||||
|  |             valid += 1 | ||||||
|  |             arch_str = nodes_to_arch_str(nodes) | ||||||
|  |             index = api.query_index_by_arch(arch_str) | ||||||
|  |             res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False) | ||||||
|  |             acc = res['test-accuracy'] | ||||||
|  |             scores.append((index, acc)) | ||||||
|  |  | ||||||
|  |             # update the distribution according to the thresholds | ||||||
|  |             updated = False | ||||||
|  |             for threshold in thresholds[datasets_name]: | ||||||
|  |                 if acc < threshold: | ||||||
|  |                     dist[f'<{threshold}'] += 1 | ||||||
|  |                     updated = True | ||||||
|  |                     break | ||||||
|  |             if not updated: | ||||||
|  |                 dist[f'>{thresholds[datasets_name][-1]}'] += 1 | ||||||
|  |         else: | ||||||
|  |             not_valid += 1 | ||||||
|  |  | ||||||
|  |     with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f: | ||||||
|  |         json.dump(scores, f) | ||||||
|  |  | ||||||
|  |     print(scores) | ||||||
|  |     print(valid, not_valid) | ||||||
|  |     print(dist) | ||||||
|  |     print("mean: ", np.mean([x[1] for x in scores])) | ||||||
|  |     print("max: ", np.max([x[1] for x in scores])) | ||||||
|  |     print("min: ", np.min([x[1] for x in scores])) | ||||||
|  |  | ||||||
|  |          | ||||||
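
For reference, `nodes_to_arch_str` above maps the 8-node list (input node, six operation nodes, output node) to a NAS-Bench-201 architecture string that `api.query_index_by_arch` accepts. A worked example of the mapping:

```python
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3',
             'skip_connect', 'none', 'output']

def nodes_to_arch_str(nodes):
    s = [num_to_op[n] for n in nodes]
    return ('|' + s[1] + '~0|+'
            '|' + s[2] + '~0|' + s[3] + '~1|+'
            '|' + s[4] + '~0|' + s[5] + '~1|' + s[6] + '~2|')

print(nodes_to_arch_str([0, 2, 4, 1, 3, 5, 2, 6]))
# |nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|avg_pool_3x3~0|none~1|nor_conv_3x3~2|
```

The script itself is driven by the two argparse flags shown above, e.g. `python test_perf.py --file_path 211035.txt --datasets cifar100`.
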