Add best-score part: score sampled architectures with NASWOT and keep the best-scoring sample
This commit is contained in:
		| @@ -3,7 +3,9 @@ import torch.nn.functional as F | |||||||
| import pytorch_lightning as pl | import pytorch_lightning as pl | ||||||
| import time | import time | ||||||
| import os | import os | ||||||
|  | from naswot.score_networks import get_nasbench201_nodes_score | ||||||
|  | from naswot import nasspace | ||||||
|  | from naswot import datasets | ||||||
| from models.transformer import Denoiser | from models.transformer import Denoiser | ||||||
| from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition | from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition | ||||||
|  |  | ||||||
| @@ -26,6 +28,43 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         nodes_dist = dataset_infos.nodes_dist |         nodes_dist = dataset_infos.nodes_dist | ||||||
|         active_index = dataset_infos.active_index |         active_index = dataset_infos.active_index | ||||||
|  |  | ||||||
|  |         class Args: | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         self.args = Args() | ||||||
|  |         self.args.trainval = True | ||||||
|  |         self.args.augtype = 'none' | ||||||
|  |         self.args.repeat = 1 | ||||||
|  |         self.args.score = 'hook_logdet' | ||||||
|  |         self.args.sigma = 0.05 | ||||||
|  |         self.args.nasspace = 'nasbench201' | ||||||
|  |         self.args.batch_size = 128 | ||||||
|  |         self.args.GPU = '0' | ||||||
|  |         self.args.dataset = 'cifar10-valid' | ||||||
|  |         self.args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||||
|  |         self.args.data_loc = '../cifardata/' | ||||||
|  |         self.args.seed = 777 | ||||||
|  |         self.args.init = '' | ||||||
|  |         self.args.save_loc = 'results' | ||||||
|  |         self.args.save_string = 'naswot' | ||||||
|  |         self.args.dropout = False | ||||||
|  |         self.args.maxofn = 1 | ||||||
|  |         self.args.n_samples = 100 | ||||||
|  |         self.args.n_runs = 500 | ||||||
|  |         self.args.stem_out_channels = 16 | ||||||
|  |         self.args.num_stacks = 3 | ||||||
|  |         self.args.num_modules_per_stack = 3 | ||||||
|  |         self.args.num_labels = 1 | ||||||
|  |  | ||||||
|  |         if 'valid' in self.args.dataset: | ||||||
|  |             self.args.dataset = self.args.dataset.replace('-valid', '') | ||||||
|  |         print('graph_dit starts to get searchspace of nasbench201') | ||||||
|  |         self.searchspace = nasspace.get_search_space(self.args) | ||||||
|  |         print('searchspace of nasbench201 is obtained') | ||||||
|  |         print('graphdit starts to get train_loader') | ||||||
|  |         self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args) | ||||||
|  |         print('train_loader is obtained') | ||||||
|  |  | ||||||
|         self.cfg = cfg |         self.cfg = cfg | ||||||
|         self.name = cfg.general.name |         self.name = cfg.general.name | ||||||
|         self.T = cfg.model.diffusion_steps |         self.T = cfg.model.diffusion_steps | ||||||
| @@ -629,15 +668,15 @@ class Graph_DiT(pl.LightningModule): | |||||||
|             prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)  # bs, n, d_t-1 |             prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)  # bs, n, d_t-1 | ||||||
|             prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1]) |             prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1]) | ||||||
|  |  | ||||||
|             return prob_X, prob_E |             return prob_X, prob_E, pred | ||||||
|         # diffusion nag: P_t(G_{t-1} |G_t, C) = P_t(G_{t-1} |G_t) + P_t(C | G_{t-1}, G_t) |         # diffusion nag: P_t(G_{t-1} |G_t, C) = P_t(G_{t-1} |G_t) + P_t(C | G_{t-1}, G_t) | ||||||
|         # with condition = P_t(G_{t-1} |G_t, C) |         # with condition = P_t(G_{t-1} |G_t, C) | ||||||
|         # with condition = P_t(A_{t-1} |A_t, y) |         # with condition = P_t(A_{t-1} |A_t, y) | ||||||
|         prob_X, prob_E = get_prob(noisy_data) |         prob_X, prob_E, pred = get_prob(noisy_data) | ||||||
|  |  | ||||||
|         ### Guidance |         ### Guidance | ||||||
|         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: |         if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1: | ||||||
|             uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True) |             uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True) | ||||||
|             prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale   |             prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale   | ||||||
|             prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale   |             prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale   | ||||||
|             prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10) |             prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10) | ||||||
| @@ -647,32 +686,120 @@ class Graph_DiT(pl.LightningModule): | |||||||
|         assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all() |         assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all() | ||||||
|         assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all() |         assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all() | ||||||
|  |  | ||||||
|         sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item()) |         # sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item()) | ||||||
|  |  | ||||||
|         # sample multiple times and get the best score arch... |         # sample multiple times and get the best score arch... | ||||||
|  |  | ||||||
# Operation vocabulary: the position of a name in this list is the integer
# label that appears in the sampled node tensor.
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
# Inverse mapping (operation name -> integer label).
op_type = {
    'input': 0,
    'nor_conv_1x1': 1,
    'nor_conv_3x3': 2,
    'avg_pool_3x3': 3,
    'skip_connect': 4,
    'none': 5,
    'output': 6,
}


def check_valid_graph(nodes, edges):
    """Return True when the sampled (nodes, edges) pair is structurally valid.

    A graph is accepted when all of the following hold:
      * every node label is a valid index into ``num_to_op``;
      * ``edges`` is a square matrix with one row/column per node;
      * the first node is 'input' and the last node is 'output';
      * no diagonal entry of ``edges`` equals 1 (no self-loops);
      * no intermediate node is 'input' or 'output';
      * at least one entry in the last column equals 1 (something feeds
        the output node).

    Args:
        nodes: sequence of integer operation labels (indices into
            ``num_to_op``), e.g. a 1-D NumPy array.
        edges: 2-D array exposing ``.shape`` and ``edges[i, j]`` indexing;
            an entry equal to 1 marks an edge.

    Returns:
        bool: True if every check passes, False otherwise. Malformed input
        (empty node list, out-of-range or negative labels, non-square
        ``edges``) yields False instead of raising.
    """
    n = len(nodes)
    # Reject empty graphs and labels outside the vocabulary up front; a
    # negative label would otherwise wrap around silently via list indexing.
    if n == 0 or any(not 0 <= idx < len(num_to_op) for idx in nodes):
        return False
    if tuple(edges.shape) != (n, n):
        return False

    ops = [num_to_op[idx] for idx in nodes]
    if ops[0] != 'input' or ops[-1] != 'output':
        return False
    if any(edges[i][i] == 1 for i in range(n)):
        return False
    if any(op in ('input', 'output') for op in ops[1:-1]):
        return False

    # The previous implementation also scanned the upper triangle (j >= i)
    # for edges into 'input' or out of 'output'. After the checks above the
    # only 'input' is node 0 and the only 'output' is node n-1, so the only
    # upper-triangle entries those scans could reach are diagonal ones that
    # were already rejected. They are omitted here; the lower triangle is
    # intentionally still not inspected, so symmetric matrices stay valid.

    # The output node must receive at least one edge.
    return any(edges[i, -1] == 1 for i in range(n))
|  |  | ||||||
# Empty attribute container; kept as-is (not referenced by get_score).
class Args:
    pass


def get_score(sampled_s):
    """Score every graph of a sampled batch.

    Each sample is split off along dim 0 of ``sampled_s.X`` / ``sampled_s.E``,
    checked with ``check_valid_graph`` and, when valid, scored with
    ``get_nasbench201_nodes_score`` using the train loader, search space and
    args stored on ``self``. Invalid samples receive the sentinel score -1.

    Args:
        sampled_s: object exposing tensor attributes ``X`` (node labels) and
            ``E`` (edges), batched along dim 0.

    Returns:
        A float32 tensor with one score per sample, placed on the device of
        ``sampled_s.X``.
    """
    # Uses the module-level ``get_nasbench201_nodes_score`` import; the former
    # function-local import pulled the same name from a different package
    # path and shadowed it on every call.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    score = []
    for x, e in zip(sampled_s.X.unbind(dim=0), sampled_s.E.unbind(dim=0)):
        # Convert once per sample and reuse for both validation and scoring.
        node_ids = x.cpu().numpy()
        if check_valid_graph(node_ids, e.cpu().numpy()):
            nodes = [num_to_op[j] for j in node_ids]
            score.append(get_nasbench201_nodes_score(nodes, train_loader=self.train_loader, searchspace=self.searchspace, device=device, args=self.args))
        else:
            score.append(-1)
    # NOTE(review): this tensor is built from Python numbers, so it has no
    # autograd connection to the network predictions the samples were drawn
    # from; requires_grad=True is kept only to preserve existing behaviour.
    # Confirm whether callers expect gradients to reach the model through it.
    return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(sampled_s.X.device)
|  |  | ||||||
|  |         sample_num = 10 | ||||||
|         best_arch = None |         best_arch = None | ||||||
|         best_score = -1e8 |         best_score_int = -1e8 | ||||||
|  |         score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8 | ||||||
|  |  | ||||||
|         for i in range(sample_num): |         for i in range(sample_num): | ||||||
|             sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item()) |             sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item()) | ||||||
|             score = get_score(sampled_s) |             score = get_score(sampled_s) | ||||||
|             if score > best_score: |             print(f'score: {score}') | ||||||
|  |             print(f'score.shape: {score.shape}') | ||||||
|  |             print(f'torch.sum(score): {torch.sum(score)}') | ||||||
|  |             sum_score = torch.sum(score) | ||||||
|  |             print(f'sum_score: {sum_score}') | ||||||
|  |             if sum_score > best_score_int: | ||||||
|  |                 best_score_int = sum_score | ||||||
|                 best_score = score |                 best_score = score | ||||||
|                 best_arch = sampled_s |                 best_arch = sampled_s | ||||||
|  |  | ||||||
|         X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float() |         # print(f'prob_X: {prob_X.shape}, prob_E: {prob_E.shape}') | ||||||
|         E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float() |          | ||||||
|  |         # best_arch = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item()) | ||||||
|  |  | ||||||
|  |         # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float() | ||||||
|  |         # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float() | ||||||
|  |         print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}') # 100 8 8, bs n n, 100 8 8 2, bs n n 2 | ||||||
|  |  | ||||||
|  |         print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}') | ||||||
|  |         X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float() | ||||||
|  |         E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float() | ||||||
|  |         print(f'X_s: {X_s}, E_s: {E_s}') | ||||||
|  |  | ||||||
|         # NASWOT score |         # NASWOT score | ||||||
|         target_score = torch.tensor([3000.0]) |         target_score = torch.ones(100, requires_grad=True) * 2000.0 | ||||||
|  |         target_score = target_score.to(X_s.device) | ||||||
|  |  | ||||||
|         # compute loss mse(cur_score - target_score) |         # compute loss mse(cur_score - target_score) | ||||||
|  |         mse_loss = torch.nn.MSELoss() | ||||||
|  |         print(f'best_score: {best_score.shape}, target_score: {target_score.shape}') | ||||||
|  |         print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}') | ||||||
|  |         loss = mse_loss(best_score, target_score) | ||||||
|  |         loss.backward(retain_graph=True) | ||||||
|  |  | ||||||
|         # loss backward = gradient |         # loss backward = gradient | ||||||
|  |  | ||||||
|         # get prob.X, prob_E gradient |         # get prob.X, prob_E gradient | ||||||
|  |         x_grad = pred.X.grad | ||||||
|  |         e_grad = pred.E.grad | ||||||
|  |  | ||||||
|  |         beta_ratio = 0.5 | ||||||
|  |         # x_current = pred.X - beta_ratio * x_grad | ||||||
|  |         # e_current = pred.E - beta_ratio * e_grad | ||||||
|  |         E_s = pred.X - beta_ratio * x_grad | ||||||
|  |         X_s = pred.E - beta_ratio * e_grad | ||||||
|  |  | ||||||
|         # update prob.X prob_E with using gradient |         # update prob.X prob_E with using gradient | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user