Compare commits
	
		
			19 Commits
		
	
	
		
			6d9db64a48
			...
			nasbench
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 91d4e3c7ad | |||
| c867aef5a6 | |||
| 1ad520d248 | |||
| 74a629fdcc | |||
| 94fe13756f | |||
| 2ac17caa3c | |||
| 0c60171c71 | |||
| 97fbdf91c7 | |||
| 297261d666 | |||
| 5dccf590e7 | |||
| 0c4b597dd2 | |||
| 11d9697e06 | |||
| 244b159c26 | |||
| 63ca6c716e | |||
| d36e1d1077 | |||
| 82183d3df7 | |||
| c86db9b6ba | |||
| a0473008a1 | |||
| 05ee34e355 | 
							
								
								
									
										24
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								README.md
									
									
									
									
									
								
							| @@ -1,14 +1,34 @@ | ||||
| Graph Diffusion Transformer for Multi-Conditional Molecular Generation | ||||
| ================================================================ | ||||
|  | ||||
| ## Initial Setup | ||||
|  | ||||
| Please download NASBench201 dataset(NAS-Bench-201-v1_1-096897.pth) from | ||||
| https://drive.google.com/file/d/16Y0UwGisiouVRxW-W5hEtbxmcHw_0hF_/view | ||||
|  | ||||
| and put it in the `/path/to/repo/graph_dit` folder. | ||||
|  | ||||
| ## Running the code | ||||
|  | ||||
| start command: | ||||
| ``` bash | ||||
| python main.py --config-name=config.yaml \ | ||||
| model.ensure_connected=True \ | ||||
| dataset.task_name='nasbench201' \ | ||||
| dataset.guidance_target='regression' | ||||
| ``` | ||||
|  | ||||
| This repository contains the code for the paper "Inverse Molecular Design with Multi-Conditional Diffusion Guidance" by Gang Liu, Jiaxin Xu, Tengfei Luo, and Meng Jiang. | ||||
|  | ||||
|  | ||||
| Paper: https://arxiv.org/abs/2401.13858 | ||||
|  | ||||
| This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like: | ||||
| <!-- This is the code for Graph DiT. The denoising model architecture in `graph_dit/models` looks like: | ||||
|  | ||||
| <div style="display: flex;" markdown="1"> | ||||
|       <img src="asset/reverse.png" style="width: 45%;" alt="Description of the first image"> | ||||
|       <img src="asset/arch.png" style="width: 45%;" alt="Description of the second image"> | ||||
| </div> | ||||
| </div> --> | ||||
|  | ||||
|  | ||||
| ## Requirements | ||||
|   | ||||
| @@ -771,9 +771,10 @@ class Dataset(InMemoryDataset): | ||||
|             edge_type = torch.tensor(edge_type, dtype=torch.long) | ||||
|             edge_attr = edge_type | ||||
|             # y = torch.tensor([0, 0], dtype=torch.float).view(1, -1) | ||||
|             y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device) | ||||
|             # y = get_nasbench201_idx_score(idx, train_loader, searchspace, args, device) | ||||
|             y = self.swap_scores[idx] | ||||
|             print(y, idx) | ||||
|             if y > 1600: | ||||
|             if y > 60000: | ||||
|                 print(f'idx={idx}, y={y}') | ||||
|                 y = torch.tensor([1, 1], dtype=torch.float).view(1, -1) | ||||
|                 data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i) | ||||
| @@ -812,6 +813,14 @@ class Dataset(InMemoryDataset): | ||||
|         args.num_labels = 1 | ||||
|         searchspace = nasspace.get_search_space(args) | ||||
|         train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
|         self.swap_scores = [] | ||||
|         import csv | ||||
|         # with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: | ||||
|         with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv', 'r') as f: | ||||
|             reader = csv.reader(f) | ||||
|             header = next(reader) | ||||
|             data = [row for row in reader] | ||||
|             self.swap_scores = [float(row[0]) for row in data] | ||||
|         device = torch.device('cuda:2') | ||||
|         with tqdm(total = len_data) as pbar: | ||||
|             active_nodes = set() | ||||
| @@ -823,14 +832,8 @@ class Dataset(InMemoryDataset): | ||||
|             flex_graph_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/flex-nasbench201-graph.json' | ||||
|             for graph in graph_list: | ||||
|                 print(f'iterate every graph in graph_list, here is {i}') | ||||
|                 # arch_info = self.api.query_meta_info_by_index(i) | ||||
|                 # results = self.api.query_by_index(i, 'cifar100') | ||||
|                 arch_info = graph['arch_str'] | ||||
|                 # results =  | ||||
|                 # nodes, edges = parse_architecture_string(arch_info.arch_str) | ||||
|                 # ops, adj_matrix = parse_architecture_string(arch_info.arch_str, padding=4) | ||||
|                 ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(arch_info, padding=4) | ||||
|                 # adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges) | ||||
|                 for op in ops: | ||||
|                     if op not in active_nodes: | ||||
|                         active_nodes.add(op) | ||||
| @@ -839,12 +842,6 @@ class Dataset(InMemoryDataset): | ||||
|                 if data is None: | ||||
|                     pbar.update(1) | ||||
|                     continue | ||||
|                 # with open(flex_graph_path, 'a') as f: | ||||
|                 #     flex_graph = { | ||||
|                 #         'adj_matrix': adj_matrix, | ||||
|                 #         'ops': ops, | ||||
|                 #     } | ||||
|                 #     json.dump(flex_graph, f) | ||||
|                 flex_graph_list.append({ | ||||
|                     'adj_matrix':adj_matrix, | ||||
|                     'ops': ops, | ||||
|   | ||||
| @@ -195,15 +195,18 @@ class Graph_DiT(pl.LightningModule): | ||||
|         # print("Size of the input features Xdim {}, Edim {}, ydim {}".format(self.Xdim, self.Edim, self.ydim)) | ||||
|  | ||||
|     def on_train_epoch_start(self) -> None: | ||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs)) | ||||
|         # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|         if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             # print("Starting train epoch {}/{}...".format(self.current_epoch, self.trainer.max_epochs)) | ||||
|             print("Starting train epoch {}/{}...".format(self.current_epoch, self.cfg.train.n_epochs)) | ||||
|         self.start_epoch_time = time.time() | ||||
|         self.train_loss.reset() | ||||
|         self.train_metrics.reset() | ||||
|  | ||||
|     def on_train_epoch_end(self) -> None: | ||||
|  | ||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|         # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|         if self.current_epoch / self.cfg.train.n_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             log = True | ||||
|         else: | ||||
|             log = False | ||||
| @@ -239,8 +242,8 @@ class Graph_DiT(pl.LightningModule): | ||||
|          | ||||
|                    self.val_X_logp.compute(), self.val_E_logp.compute()] | ||||
|          | ||||
|         if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|             print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ", | ||||
|         # if self.current_epoch / self.trainer.max_epochs in [0.25, 0.5, 0.75, 1.0]: | ||||
|         print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ", | ||||
|                 f"Val Edge type KL: {metrics[2] :.2f}", 'Val loss: %.2f \t Best :  %.2f\n' % (metrics[0], self.best_val_nll)) | ||||
|         with open("validation-metrics.csv", "a") as f: | ||||
|             # save the metrics as csv file | ||||
| @@ -286,7 +289,7 @@ class Graph_DiT(pl.LightningModule): | ||||
|                 samples.extend(self.sample_batch(batch_id=ident, batch_size=to_generate, y=batch_y, | ||||
|                                                 save_final=to_save, | ||||
|                                                 keep_chain=chains_save, | ||||
|                                                 number_chain_steps=self.number_chain_steps)) | ||||
|                                                 number_chain_steps=self.number_chain_steps)[0]) | ||||
|                 ident += to_generate | ||||
|                 start_index += to_generate | ||||
|  | ||||
| @@ -360,7 +363,7 @@ class Graph_DiT(pl.LightningModule): | ||||
|             batch_y = torch.ones(to_generate, self.ydim_output, device=self.device) | ||||
|  | ||||
|             cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||
|                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps) | ||||
|                                             keep_chain=chains_save, number_chain_steps=self.number_chain_steps)[0] | ||||
|             samples = samples + cur_sample | ||||
|              | ||||
|             all_ys.append(batch_y) | ||||
| @@ -601,6 +604,9 @@ class Graph_DiT(pl.LightningModule): | ||||
|  | ||||
|         assert (E == torch.transpose(E, 1, 2)).all() | ||||
|  | ||||
|         total_log_probs = torch.zeros([self.cfg.general.final_model_samples_to_generate,10], device=self.device) | ||||
|         # total_log_probs = torch.zeros([self.cfg.general.samples_to_generate,10], device=self.device) | ||||
|  | ||||
|         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. | ||||
|         for s_int in reversed(range(0, self.T)): | ||||
|             s_array = s_int * torch.ones((batch_size, 1)).type_as(y) | ||||
| @@ -609,21 +615,24 @@ class Graph_DiT(pl.LightningModule): | ||||
|             t_norm = t_array / self.T | ||||
|  | ||||
|             # Sample z_s | ||||
|             sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) | ||||
|             sampled_s, discrete_sampled_s, log_probs= self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask) | ||||
|             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||
|             print(f'sampled_s.X shape: {sampled_s.X.shape}, sampled_s.E shape: {sampled_s.E.shape}') | ||||
|             print(f'log_probs shape: {log_probs.shape}') | ||||
|             total_log_probs += log_probs | ||||
|  | ||||
|         # Sample | ||||
|         sampled_s = sampled_s.mask(node_mask, collapse=True) | ||||
|         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y | ||||
|          | ||||
|         molecule_list = [] | ||||
|         graph_list = [] | ||||
|         for i in range(batch_size): | ||||
|             n = n_nodes[i] | ||||
|             atom_types = X[i, :n].cpu() | ||||
|             node_types = X[i, :n].cpu() | ||||
|             edge_types = E[i, :n, :n].cpu() | ||||
|             molecule_list.append([atom_types, edge_types]) | ||||
|             graph_list.append([node_types, edge_types]) | ||||
|          | ||||
|         return molecule_list | ||||
|         return graph_list, total_log_probs | ||||
|  | ||||
|     def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask): | ||||
|         """Samples from zs ~ p(zs | zt). Only used during sampling. | ||||
| @@ -635,6 +644,7 @@ class Graph_DiT(pl.LightningModule): | ||||
|  | ||||
|         # Neural net predictions | ||||
|         noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask} | ||||
|         print(f"sample p zs given zt X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}, node_mask shape: {node_mask.shape}") | ||||
|          | ||||
|         def get_prob(noisy_data, unconditioned=False): | ||||
|             pred = self.forward(noisy_data, unconditioned=unconditioned) | ||||
| @@ -674,7 +684,19 @@ class Graph_DiT(pl.LightningModule): | ||||
|         # with condition = P_t(G_{t-1} |G_t, C) | ||||
|         # with condition = P_t(A_{t-1} |A_t, y) | ||||
|         prob_X, prob_E, pred = get_prob(noisy_data) | ||||
|         print(f'prob_X shape: {prob_X.shape}, prob_E shape: {prob_E.shape}') | ||||
|         print(f'X_t shape: {X_t.shape}, E_t shape: {E_t.shape}, y_t shape: {y_t.shape}') | ||||
|         print(f'X_t: {X_t}') | ||||
|         log_prob_X = torch.log(torch.gather(prob_X, -1, X_t.long()).squeeze(-1))  # bs, n | ||||
|         log_prob_E = torch.log(torch.gather(prob_E, -1, E_t.long()).squeeze(-1))  # bs, n, n | ||||
|  | ||||
|         # Sum the log_prob across dimensions for total log_prob | ||||
|         log_prob_X = log_prob_X.sum(dim=-1) | ||||
|         log_prob_E = log_prob_E.sum(dim=(1, 2)) | ||||
|         print(f'log_prob_X shape: {log_prob_X.shape}, log_prob_E shape: {log_prob_E.shape}') | ||||
|         # log_probs = log_prob_E + log_prob_X | ||||
|         log_probs = torch.cat([log_prob_X, log_prob_E], dim=-1)  # (batch_size, 2) | ||||
|         print(f'log_probs shape: {log_probs.shape}') | ||||
|         ### Guidance | ||||
|         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: | ||||
|             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True) | ||||
| @@ -810,4 +832,4 @@ class Graph_DiT(pl.LightningModule): | ||||
|         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||
|         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=y_t) | ||||
|  | ||||
|         return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t) | ||||
|         return out_one_hot.mask(node_mask).type_as(y_t), out_discrete.mask(node_mask, collapse=True).type_as(y_t), log_probs | ||||
|   | ||||
							
								
								
									
										
											BIN
										
									
								
								graph_dit/exp_201/barplog.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								graph_dit/exp_201/barplog.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 30 KiB | 
| @@ -2,44 +2,45 @@ | ||||
| import matplotlib.pyplot as plt | ||||
| import pandas as pd | ||||
| from nas_201_api import NASBench201API as API | ||||
| from naswot.score_networks import get_nasbench201_idx_score | ||||
| from naswot import datasets as dt | ||||
| from naswot import nasspace | ||||
| # from naswot.score_networks import get_nasbench201_idx_score | ||||
| # from naswot import datasets as dt | ||||
| # from naswot import nasspace | ||||
|  | ||||
| class Args(): | ||||
|     pass | ||||
| args = Args() | ||||
| args.trainval = True | ||||
| args.augtype = 'none' | ||||
| args.repeat = 1 | ||||
| args.score = 'hook_logdet' | ||||
| args.sigma = 0.05 | ||||
| args.nasspace = 'nasbench201' | ||||
| args.batch_size = 128 | ||||
| args.GPU = '0' | ||||
| args.dataset = 'cifar10' | ||||
| args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| args.data_loc = '../cifardata/' | ||||
| args.seed = 777 | ||||
| args.init = '' | ||||
| args.save_loc = 'results' | ||||
| args.save_string = 'naswot' | ||||
| args.dropout = False | ||||
| args.maxofn = 1 | ||||
| args.n_samples = 100 | ||||
| args.n_runs = 500 | ||||
| args.stem_out_channels = 16 | ||||
| args.num_stacks = 3 | ||||
| args.num_modules_per_stack = 3 | ||||
| args.num_labels = 1 | ||||
| searchspace = nasspace.get_search_space(args) | ||||
| train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
| device = torch.device('cuda:2') | ||||
| # class Args(): | ||||
| #     pass | ||||
| # args = Args() | ||||
| # args.trainval = True | ||||
| # args.augtype = 'none' | ||||
| # args.repeat = 1 | ||||
| # args.score = 'hook_logdet' | ||||
| # args.sigma = 0.05 | ||||
| # args.nasspace = 'nasbench201' | ||||
| # args.batch_size = 128 | ||||
| # args.GPU = '0' | ||||
| # args.dataset = 'cifar10' | ||||
| # args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| # args.data_loc = '../cifardata/' | ||||
| # args.seed = 777 | ||||
| # args.init = '' | ||||
| # args.save_loc = 'results' | ||||
| # args.save_string = 'naswot' | ||||
| # args.dropout = False | ||||
| # args.maxofn = 1 | ||||
| # args.n_samples = 100 | ||||
| # args.n_runs = 500 | ||||
| # args.stem_out_channels = 16 | ||||
| # args.num_stacks = 3 | ||||
| # args.num_modules_per_stack = 3 | ||||
| # args.num_labels = 1 | ||||
| # searchspace = nasspace.get_search_space(args) | ||||
| # train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
| # device = torch.device('cuda:2') | ||||
|  | ||||
|  | ||||
| source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| api = API(source) | ||||
|  | ||||
|  | ||||
|           | ||||
| # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| # api = API(source) | ||||
|  | ||||
|  | ||||
|  | ||||
| @@ -50,8 +51,10 @@ percentages = [] | ||||
| len_201 = 15625 | ||||
|  | ||||
| for i in range(len_201): | ||||
|     percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) | ||||
|     percentages.append(percentage) | ||||
|     # percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) | ||||
|     results = api.query_by_index(i, 'cifar10') | ||||
|     result = results[111].get_eval('ori-test') | ||||
|     percentages.append(result) | ||||
|  | ||||
| # 定义10%区间 | ||||
| bins = [i for i in range(0, 101, 10)] | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| # These imports are tricky because they use c++, do not move them | ||||
| from tqdm import tqdm | ||||
| import os, shutil | ||||
| import warnings | ||||
|  | ||||
| @@ -144,10 +145,32 @@ def main(cfg: DictConfig): | ||||
|     else: | ||||
|         trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) | ||||
|  | ||||
| from accelerate import Accelerator | ||||
| from accelerate.utils import set_seed, ProjectConfiguration | ||||
|  | ||||
| @hydra.main( | ||||
|     version_base="1.1", config_path="../configs", config_name="config" | ||||
| ) | ||||
| def test(cfg: DictConfig): | ||||
|     os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number | ||||
|     accelerator_config = ProjectConfiguration( | ||||
|         project_dir=os.path.join(cfg.general.log_dir, cfg.general.name), | ||||
|         automatic_checkpoint_naming=True, | ||||
|         total_limit=cfg.general.number_checkpoint_limit, | ||||
|     ) | ||||
|     accelerator = Accelerator( | ||||
|         mixed_precision='no', | ||||
|         project_config=accelerator_config, | ||||
|         # gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs, | ||||
|         gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,  | ||||
|     ) | ||||
|  | ||||
|     # Debug: 确认可用设备 | ||||
|     print(f"Available GPUs: {torch.cuda.device_count()}") | ||||
|     print(f"Using device: {accelerator.device}") | ||||
|  | ||||
|     set_seed(cfg.train.seed, device_specific=True) | ||||
|  | ||||
|     datamodule = dataset.DataModule(cfg) | ||||
|     datamodule.prepare_data() | ||||
|     dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset) | ||||
| @@ -169,40 +192,216 @@ def test(cfg: DictConfig): | ||||
|         "visualization_tools": visulization_tools, | ||||
|     } | ||||
|  | ||||
|     # Debug: 确认可用设备 | ||||
|     print(f"Available GPUs: {torch.cuda.device_count()}") | ||||
|     print(f"Using device: {accelerator.device}") | ||||
|  | ||||
|     if cfg.general.test_only: | ||||
|         cfg, _ = get_resume(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.test_only.split("checkpoints")[0]) | ||||
|     elif cfg.general.resume is not None: | ||||
|         cfg, _ = get_resume_adaptive(cfg, model_kwargs) | ||||
|         os.chdir(cfg.general.resume.split("checkpoints")[0]) | ||||
|     # os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number | ||||
|     model = Graph_DiT(cfg=cfg, **model_kwargs) | ||||
|     trainer = Trainer( | ||||
|         gradient_clip_val=cfg.train.clip_grad, | ||||
|         # accelerator="cpu", | ||||
|         accelerator="gpu" | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else "cpu", | ||||
|         devices=[cfg.general.gpu_number] | ||||
|         if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|         else None, | ||||
|         max_epochs=cfg.train.n_epochs, | ||||
|         enable_checkpointing=False, | ||||
|         check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, | ||||
|         val_check_interval=cfg.train.val_check_interval, | ||||
|         strategy="ddp" if cfg.general.gpus > 1 else "auto", | ||||
|         enable_progress_bar=cfg.general.enable_progress_bar, | ||||
|         callbacks=[], | ||||
|         reload_dataloaders_every_n_epochs=0, | ||||
|         logger=[], | ||||
|     ) | ||||
|     graph_dit_model = model | ||||
|  | ||||
|     if not cfg.general.test_only: | ||||
|         print("start testing fit method") | ||||
|         trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) | ||||
|         if cfg.general.save_model: | ||||
|             trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt") | ||||
|         trainer.test(model, datamodule=datamodule) | ||||
|     inference_dtype = torch.float32 | ||||
|     graph_dit_model.to(accelerator.device, dtype=inference_dtype) | ||||
|  | ||||
|  | ||||
|     # optional: freeze the model | ||||
|     # graph_dit_model.model.requires_grad_(True) | ||||
|  | ||||
|     import torch.nn.functional as F | ||||
|     optimizer = graph_dit_model.configure_optimizers() | ||||
|     train_dataloader = accelerator.prepare(datamodule.train_dataloader()) | ||||
|     optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model) | ||||
|     # start training | ||||
|     for epoch in range(cfg.train.n_epochs): | ||||
|         graph_dit_model.train()  # 设置模型为训练模式 | ||||
|         print(f"Epoch {epoch}", end="\n") | ||||
|         graph_dit_model.on_train_epoch_start() | ||||
|         for data in train_dataloader:  # 从数据加载器中获取一个批次的数据 | ||||
|             # data.to(accelerator.device) | ||||
|             # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index] | ||||
|             # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float() | ||||
|             # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes) | ||||
|             # dense_data = dense_data.mask(node_mask) | ||||
|             # X, E = dense_data.X, dense_data.E | ||||
|             # noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask) | ||||
|             # pred = graph_dit_model.forward(noisy_data) | ||||
|             # loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y, | ||||
|             #                     true_X=X, true_E=E, true_y=data.y, node_mask=node_mask, | ||||
|             #                     log=epoch % graph_dit_model.log_every_steps == 0) | ||||
|             # # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}') | ||||
|             # graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E, | ||||
|             #                 log=epoch % graph_dit_model.log_every_steps == 0) | ||||
|             # graph_dit_model.log(f'loss', loss, batch_size=X.size(0), sync_dist=True) | ||||
|             # print(f"training loss: {loss}") | ||||
|             # with open("training-loss.csv", "a") as f: | ||||
|             #     f.write(f"{loss}, {epoch}\n") | ||||
|             loss = graph_dit_model.training_step(data, epoch) | ||||
|             loss = loss['loss'] | ||||
|  | ||||
|             accelerator.backward(loss) | ||||
|             optimizer.step() | ||||
|             optimizer.zero_grad() | ||||
|             # return {'loss': loss} | ||||
|         graph_dit_model.on_train_epoch_end() | ||||
|         if epoch % cfg.train.check_val_every_n_epoch == 0: | ||||
|             print(f'print validation loss') | ||||
|             graph_dit_model.eval() | ||||
|             graph_dit_model.on_validation_epoch_start() | ||||
|             graph_dit_model.validation_step(data, epoch) | ||||
|             graph_dit_model.on_validation_epoch_end() | ||||
|      | ||||
|     # start testing | ||||
|     print("start testing") | ||||
|     graph_dit_model.eval() | ||||
|     test_dataloader = accelerator.prepare(datamodule.test_dataloader()) | ||||
|     graph_dit_model.on_test_epoch_start() | ||||
|     for data in test_dataloader: | ||||
|         nll = graph_dit_model.test_step(data, epoch) | ||||
|         # data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index] | ||||
|         # data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float() | ||||
|  | ||||
|         # dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes) | ||||
|         # dense_data = dense_data.mask(node_mask) | ||||
|         # noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask) | ||||
|         # pred = graph_dit_model.forward(noisy_data) | ||||
|         # nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True) | ||||
|         # graph_dit_model.test_y_collection.append(data.y) | ||||
|         print(f'test loss: {nll}') | ||||
|      | ||||
|     graph_dit_model.on_test_epoch_end() | ||||
|  | ||||
|     # start sampling | ||||
|  | ||||
|     # samples_left_to_generate = cfg.general.final_model_samples_to_generate | ||||
|     # samples_left_to_save = cfg.general.final_model_samples_to_save | ||||
|     # chains_left_to_save = cfg.general.final_model_chains_to_save | ||||
|  | ||||
|     # samples, all_ys, batch_id = [], [], 0 | ||||
|     # samples_with_log_probs = [] | ||||
|     # test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0) | ||||
|     # num_examples = test_y_collection.size(0) | ||||
|     # if cfg.general.final_model_samples_to_generate > num_examples: | ||||
|     #     ratio = cfg.general.final_model_samples_to_generate // num_examples | ||||
|     #     test_y_collection = test_y_collection.repeat(ratio+1, 1) | ||||
|     #     num_examples = test_y_collection.size(0) | ||||
|      | ||||
|     # Normal reward function | ||||
|     # from nas_201_api import NASBench201API as API | ||||
|     # api = API('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth') | ||||
|     # def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'): | ||||
|     #     rewards = [] | ||||
|     #     if reward_model == 'swap': | ||||
|     #         import csv | ||||
|     #         with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f: | ||||
|     #             reader = csv.reader(f) | ||||
|     #             header = next(reader) | ||||
|     #             data = [row for row in reader] | ||||
|     #             swap_scores = [float(row[0]) for row in data] | ||||
|     #             for graph in graphs: | ||||
|     #                 node_tensor = graph[0] | ||||
|     #                 node = node_tensor.cpu().numpy().tolist() | ||||
|  | ||||
|     #                 def nodes_to_arch_str(nodes): | ||||
|     #                     num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
|     #                     nodes_str = [num_to_op[node] for node in nodes] | ||||
|     #                     arch_str = '|' + nodes_str[1] + '~0|+' + \ | ||||
|     #                             '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\ | ||||
|     #                             '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'  | ||||
|     #                     return arch_str | ||||
|                      | ||||
|     #                 arch_str = nodes_to_arch_str(node) | ||||
|     #                 reward = swap_scores[api.query_index_by_arch(arch_str)] | ||||
|     #                 rewards.append(reward) | ||||
|                  | ||||
|     #     # for graph in graphs: | ||||
|     #     #     reward = 1.0 | ||||
|     #     #     rewards.append(reward) | ||||
|     #     return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device) | ||||
|     # old_log_probs = None | ||||
|     # while samples_left_to_generate > 0: | ||||
|     #     print(f'samples left to generate: {samples_left_to_generate}/' | ||||
|     #         f'{cfg.general.final_model_samples_to_generate}', end='', flush=True) | ||||
|     #     bs = 1 * cfg.train.batch_size | ||||
|     #     to_generate = min(samples_left_to_generate, bs) | ||||
|     #     to_save = min(samples_left_to_save, bs) | ||||
|     #     chains_save = min(chains_left_to_save, bs) | ||||
|     #     # batch_y = test_y_collection[batch_id : batch_id + to_generate] | ||||
|     #     batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device) | ||||
|  | ||||
|     #     cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, | ||||
|     #                                     keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps) | ||||
|     #     log_probs = torch.sum(log_probs, dim=-1).unsqueeze(1) | ||||
|     #     samples = samples + cur_sample | ||||
|     #     reward = graph_reward_fn(cur_sample, device=graph_dit_model.device) | ||||
|     #     advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6) | ||||
|     #     print(f'reward: {reward.shape}, advantages: {advantages.shape}, log_probs: {log_probs.shape}, cur_sample: {len(cur_sample)}') | ||||
|     #     if old_log_probs is None: | ||||
|     #         old_log_probs = log_probs.clone() | ||||
|     #     ratio = torch.exp(log_probs - old_log_probs) | ||||
|     #     unclipped_loss = -advantages * ratio | ||||
|     #     clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param) | ||||
|     #     loss = torch.mean(torch.max(unclipped_loss, clipped_loss)) | ||||
|     #     accelerator.backward(loss) | ||||
|     #     optimizer.step() | ||||
|     #     optimizer.zero_grad() | ||||
|  | ||||
|  | ||||
|     #     samples_with_log_probs.append((cur_sample, log_probs, reward)) | ||||
|          | ||||
|     #     all_ys.append(batch_y) | ||||
|     #     batch_id += to_generate | ||||
|  | ||||
|     #     samples_left_to_save -= to_save | ||||
|     #     samples_left_to_generate -= to_generate | ||||
|     #     chains_left_to_save -= chains_save | ||||
|          | ||||
|     # print(f"final Computing sampling metrics...") | ||||
|     # graph_dit_model.sampling_metrics.reset() | ||||
|     # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True) | ||||
|     # graph_dit_model.sampling_metrics.reset() | ||||
|     # print(f"Done.") | ||||
|  | ||||
|     # # save samples | ||||
|     # print("Samples:") | ||||
|     # print(samples) | ||||
|  | ||||
|     # ======================== | ||||
|      | ||||
|  | ||||
|      | ||||
|  | ||||
|  | ||||
|     # trainer = Trainer( | ||||
|     #     gradient_clip_val=cfg.train.clip_grad, | ||||
|     #     # accelerator="cpu", | ||||
|     #     accelerator="gpu" | ||||
|     #     if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|     #     else "cpu", | ||||
|     #     devices=[cfg.general.gpu_number] | ||||
|     #     if torch.cuda.is_available() and cfg.general.gpus > 0 | ||||
|     #     else None, | ||||
|     #     max_epochs=cfg.train.n_epochs, | ||||
|     #     enable_checkpointing=False, | ||||
|     #     check_val_every_n_epoch=cfg.train.check_val_every_n_epoch, | ||||
|     #     val_check_interval=cfg.train.val_check_interval, | ||||
|     #     strategy="ddp" if cfg.general.gpus > 1 else "auto", | ||||
|     #     enable_progress_bar=cfg.general.enable_progress_bar, | ||||
|     #     callbacks=[], | ||||
|     #     reload_dataloaders_every_n_epochs=0, | ||||
|     #     logger=[], | ||||
|     # ) | ||||
|  | ||||
|     # if not cfg.general.test_only: | ||||
|     #     print("start testing fit method") | ||||
|     #     trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) | ||||
|     #     if cfg.general.save_model: | ||||
|     #         trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt") | ||||
|     #     trainer.test(model, datamodule=datamodule) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     test() | ||||
|   | ||||
| @@ -76,6 +76,8 @@ class CategoricalEmbedder(nn.Module): | ||||
|             embeddings = embeddings + noise | ||||
|         return embeddings | ||||
|      | ||||
| # 相似的condition cluster起来 | ||||
| # size  | ||||
| class ClusterContinuousEmbedder(nn.Module): | ||||
|     def __init__(self, input_size, hidden_size, dropout_prob): | ||||
|         super().__init__() | ||||
| @@ -108,6 +110,8 @@ class ClusterContinuousEmbedder(nn.Module): | ||||
|          | ||||
|         if drop_ids is not None: | ||||
|             embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device) | ||||
|             # print(labels[~drop_ids].shape) | ||||
|             # torch.Size([1200]) | ||||
|             embeddings[~drop_ids] = self.mlp(labels[~drop_ids]) | ||||
|             embeddings[drop_ids] += self.embedding_drop.weight[0] | ||||
|         else: | ||||
|   | ||||
| @@ -17,20 +17,22 @@ class Denoiser(nn.Module): | ||||
|         num_heads=16, | ||||
|         mlp_ratio=4.0, | ||||
|         drop_condition=0.1, | ||||
|         Xdim=118, | ||||
|         Edim=5, | ||||
|         ydim=3, | ||||
|         Xdim=7, | ||||
|         Edim=2, | ||||
|         ydim=1, | ||||
|         task_type='regression', | ||||
|     ): | ||||
|         super().__init__() | ||||
|         print(f"Denoiser, xdim: {Xdim}, edim: {Edim}, ydim: {ydim}, hidden_size: {hidden_size}, depth: {depth}, num_heads: {num_heads}, mlp_ratio: {mlp_ratio}, drop_condition: {drop_condition}") | ||||
|         self.num_heads = num_heads | ||||
|         self.ydim = ydim | ||||
|         self.x_embedder = nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False) | ||||
|  | ||||
|         self.t_embedder = TimestepEmbedder(hidden_size) | ||||
|         #  | ||||
|         self.y_embedding_list = torch.nn.ModuleList() | ||||
|  | ||||
|         self.y_embedding_list.append(ClusterContinuousEmbedder(2, hidden_size, drop_condition)) | ||||
|         self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition)) | ||||
|         for i in range(ydim - 2): | ||||
|             if task_type == 'regression': | ||||
|                 self.y_embedding_list.append(ClusterContinuousEmbedder(1, hidden_size, drop_condition)) | ||||
| @@ -88,6 +90,8 @@ class Denoiser(nn.Module): | ||||
|          | ||||
|         # print("Denoiser Forward") | ||||
|         # print(x.shape, e.shape, y.shape, t.shape, unconditioned) | ||||
|         # torch.Size([1200, 8, 7]) torch.Size([1200, 8, 8, 2]) torch.Size([1200, 2]) torch.Size([1200, 1]) False | ||||
|         # print(y) | ||||
|         force_drop_id = torch.zeros_like(y.sum(-1)) | ||||
|         # drop the nan values | ||||
|         force_drop_id[torch.isnan(y.sum(-1))] = 1 | ||||
| @@ -109,11 +113,12 @@ class Denoiser(nn.Module): | ||||
|         c1 = self.t_embedder(t) | ||||
|         # print("C1 after t_embedder") | ||||
|         # print(c1.shape) | ||||
|         for i in range(1, self.ydim): | ||||
|             if i == 1: | ||||
|                 c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t) | ||||
|             else: | ||||
|                 c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t) | ||||
|         c2 = self.y_embedding_list[0](y[:,0].unsqueeze(-1), self.training, force_drop_id, t) | ||||
|         # for i in range(1, self.ydim): | ||||
|         #     if i == 1: | ||||
|         #         c2 = self.y_embedding_list[i-1](y[:, :2], self.training, force_drop_id, t) | ||||
|         #     else: | ||||
|                 # c2 = c2 + self.y_embedding_list[i-1](y[:, i:i+1], self.training, force_drop_id, t) | ||||
|         # print("C2 after y_embedding_list") | ||||
|         # print(c2.shape) | ||||
|         # print("C1 + C2") | ||||
|   | ||||
							
								
								
									
										15626
									
								
								graph_dit/swap_results_aircraft.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15626
									
								
								graph_dit/swap_results_aircraft.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										144
									
								
								graph_dit/test_perf.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								graph_dit/test_perf.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,144 @@ | ||||
| from nas_201_api import NASBench201API as API | ||||
| import re | ||||
| import pandas as pd | ||||
| import json | ||||
| import numpy as np | ||||
| import argparse | ||||
|  | ||||
| api = API('./NAS-Bench-201-v1_1-096897.pth') | ||||
|  | ||||
| parser = argparse.ArgumentParser(description='Process some integers.') | ||||
|  | ||||
| parser.add_argument('--file_path', type=str, default='211035.txt',) | ||||
| parser.add_argument('--datasets', type=str, default='cifar10',) | ||||
| args = parser.parse_args() | ||||
|  | ||||
| def process_graph_data(text): | ||||
|     # Split the input text into sections for each graph | ||||
|     graph_sections = text.strip().split('nodes:') | ||||
|      | ||||
|     # Prepare lists to store data | ||||
|     nodes_list = [] | ||||
|     edges_list = [] | ||||
|     results_list = [] | ||||
|      | ||||
|     for section in graph_sections[1:]: | ||||
|         # Extract nodes | ||||
|         nodes_section = section.split('edges:')[0] | ||||
|         nodes_match = re.search(r'(tensor\(\d+\) ?)+', section) | ||||
|         if nodes_match: | ||||
|             nodes = re.findall(r'tensor\((\d+)\)', nodes_match.group(0)) | ||||
|             nodes_list.append(nodes) | ||||
|          | ||||
|         # Extract edges | ||||
|         edge_section = section.split('edges:')[1] | ||||
|         edges_match = re.search(r'edges:', section) | ||||
|         if edges_match: | ||||
|             edges = re.findall(r'tensor\((\d+)\)', edge_section) | ||||
|             edges_list.append(edges) | ||||
|          | ||||
|         # Extract the last floating point number as a result | ||||
|      | ||||
|     # Create a DataFrame to store the extracted data | ||||
|     data = { | ||||
|         'nodes': nodes_list, | ||||
|         'edges': edges_list, | ||||
|     } | ||||
|     data['nodes'] = [[int(x) for x in node] for node in data['nodes']] | ||||
|     data['edges'] = [[int(x) for x in edge] for edge in data['edges']] | ||||
|     def split_list(input_list, chunk_size): | ||||
|         return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)] | ||||
|     data['edges'] = [split_list(edge, 8) for edge in data['edges']] | ||||
|  | ||||
|     print(data) | ||||
|     df = pd.DataFrame(data) | ||||
|     print('df') | ||||
|     print(df['nodes'][0], df['edges'][0]) | ||||
|     return df | ||||
|  | ||||
| def is_valid_nasbench201(adj, ops): | ||||
|     print(ops) | ||||
|     if ops[0] != 0 or ops[-1] != 6: | ||||
|         return False | ||||
|     for i in range(2, len(ops) - 1): | ||||
|         if ops[i] not in [1, 2, 3, 4, 5]: | ||||
|             return False | ||||
|     adj_mat = [ [0, 1, 1, 0, 1, 0, 0, 0], | ||||
|                 [0, 0, 0, 1, 0, 1 ,0 ,0], | ||||
|                 [0, 0, 0, 0, 0, 0, 1, 0], | ||||
|                 [0, 0, 0, 0, 0, 0, 1, 0], | ||||
|                 [0, 0, 0, 0, 0, 0, 0, 1], | ||||
|                 [0, 0, 0, 0, 0, 0, 0, 1], | ||||
|                 [0, 0, 0, 0, 0, 0, 0, 1], | ||||
|                 [0, 0, 0, 0, 0, 0, 0, 0]] | ||||
|   | ||||
|     for i in range(len(adj)): | ||||
|         for j in range(len(adj[i])): | ||||
|             if adj[i][j] not in [0, 1]: | ||||
|                 return False | ||||
|             if j > i: | ||||
|                 if adj[i][j] != adj_mat[i][j]: | ||||
|                     return False | ||||
|     return True | ||||
|  | ||||
| num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
| def nodes_to_arch_str(nodes): | ||||
|     nodes_str = [num_to_op[node] for node in nodes] | ||||
|     arch_str = '|' + nodes_str[1] + '~0|+' + \ | ||||
|                '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' +\ | ||||
|                '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'  | ||||
|     return arch_str | ||||
|  | ||||
| filename = args.file_path | ||||
| datasets_name = args.datasets | ||||
|  | ||||
| with open('./output_graphs/' + filename, 'r') as f: | ||||
|     texts = f.read() | ||||
|     df = process_graph_data(texts) | ||||
|     valid = 0 | ||||
|     not_valid = 0 | ||||
|     scores = [] | ||||
|  | ||||
|     # 定义分类标准和分布字典的映射 | ||||
|     thresholds = { | ||||
|         'cifar10': [90, 91, 92, 93, 94], | ||||
|         'cifar100': [68,69,70, 71, 72, 73] | ||||
|     } | ||||
|     dist = {f'<{threshold}': 0 for threshold in thresholds[datasets_name]} | ||||
|     dist[f'>{thresholds[datasets_name][-1]}'] = 0 | ||||
|  | ||||
|     for i in range(len(df)): | ||||
|         nodes = df['nodes'][i] | ||||
|         edges = df['edges'][i] | ||||
|         result = is_valid_nasbench201(edges, nodes) | ||||
|         if result: | ||||
|             valid += 1 | ||||
|             arch_str = nodes_to_arch_str(nodes) | ||||
|             index = api.query_index_by_arch(arch_str) | ||||
|             res = api.get_more_info(index, datasets_name, None, hp=200, is_random=False) | ||||
|             acc = res['test-accuracy'] | ||||
|             scores.append((index, acc)) | ||||
|  | ||||
|             # 根据阈值更新分布 | ||||
|             updated = False | ||||
|             for threshold in thresholds[datasets_name]: | ||||
|                 if acc < threshold: | ||||
|                     dist[f'<{threshold}'] += 1 | ||||
|                     updated = True | ||||
|                     break | ||||
|             if not updated: | ||||
|                 dist[f'>{thresholds[datasets_name][-1]}'] += 1 | ||||
|         else: | ||||
|             not_valid += 1 | ||||
|  | ||||
|     with open('./output_graphs/' + filename + '_' + datasets_name +'.json', 'w') as f: | ||||
|         json.dump(scores, f) | ||||
|  | ||||
|     print(scores) | ||||
|     print(valid, not_valid) | ||||
|     print(dist) | ||||
|     print("mean: ", np.mean([x[1] for x in scores])) | ||||
|     print("max: ", np.max([x[1] for x in scores])) | ||||
|     print("min: ", np.min([x[1] for x in scores])) | ||||
|  | ||||
|          | ||||
		Reference in New Issue
	
	Block a user