add best score part
@@ -3,7 +3,9 @@ import torch.nn.functional as F
import pytorch_lightning as pl
import time
import os

from naswot.score_networks import get_nasbench201_nodes_score
from naswot import nasspace
from naswot import datasets
from models.transformer import Denoiser
from diffusion.noise_schedule import PredefinedNoiseScheduleDiscrete, MarginalTransition

@@ -26,6 +28,43 @@ class Graph_DiT(pl.LightningModule):
        nodes_dist = dataset_infos.nodes_dist
        active_index = dataset_infos.active_index

        # Namespace stub carrying the naswot CLI arguments expected by
        # nasspace.get_search_space and datasets.get_data.
        class Args:
            pass

        self.args = Args()
        self.args.trainval = True
        self.args.augtype = 'none'
        self.args.repeat = 1
        self.args.score = 'hook_logdet'
        self.args.sigma = 0.05
        self.args.nasspace = 'nasbench201'
        self.args.batch_size = 128
        self.args.GPU = '0'
        self.args.dataset = 'cifar10-valid'
        self.args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
        self.args.data_loc = '../cifardata/'
        self.args.seed = 777
        self.args.init = ''
        self.args.save_loc = 'results'
        self.args.save_string = 'naswot'
        self.args.dropout = False
        self.args.maxofn = 1
        self.args.n_samples = 100
        self.args.n_runs = 500
        self.args.stem_out_channels = 16
        self.args.num_stacks = 3
        self.args.num_modules_per_stack = 3
        self.args.num_labels = 1

        # the data loader expects the plain dataset name, without the '-valid' suffix
        if 'valid' in self.args.dataset:
            self.args.dataset = self.args.dataset.replace('-valid', '')
        print('graph_dit starts to get the nasbench201 search space')
        self.searchspace = nasspace.get_search_space(self.args)
        print('nasbench201 search space obtained')
        print('graph_dit starts to get the train_loader')
        self.train_loader = datasets.get_data(self.args.dataset, self.args.data_loc, self.args.trainval, self.args.batch_size, self.args.augtype, self.args.repeat, self.args)
        print('train_loader obtained')

        self.cfg = cfg
        self.name = cfg.general.name
        self.T = cfg.model.diffusion_steps

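For orientation, a minimal smoke-test sketch of what this setup enables, assuming it runs inside Graph_DiT after the constructor above; the keyword signature of get_nasbench201_nodes_score matches the call the sampling hunk makes further down, and the hand-written node list is only illustrative:

import torch

# hypothetical NAS-Bench-201-style cell, encoded as op names (see num_to_op below)
nodes = ['input', 'nor_conv_3x3', 'skip_connect', 'nor_conv_1x1',
         'avg_pool_3x3', 'none', 'output']
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
score = get_nasbench201_nodes_score(nodes,
                                    train_loader=self.train_loader,
                                    searchspace=self.searchspace,
                                    device=device,
                                    args=self.args)
print(f'naswot score of the sample cell: {score}')
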
@@ -629,15 +668,15 @@ class Graph_DiT(pl.LightningModule):
            prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)  # bs, n, d_t-1
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E
            return prob_X, prob_E, pred

        # diffusion nag: P_t(G_{t-1} | G_t, C) ∝ P_t(G_{t-1} | G_t) * P_t(C | G_{t-1}, G_t)
        # with condition = P_t(G_{t-1} | G_t, C)
        # with condition = P_t(A_{t-1} | A_t, y)
        prob_X, prob_E = get_prob(noisy_data)
        prob_X, prob_E, pred = get_prob(noisy_data)

        ### Guidance
        if self.guidance_target is not None and self.guide_scale is not None and self.guide_scale != 1:
            uncon_prob_X, uncon_prob_E = get_prob(noisy_data, unconditioned=True)
            uncon_prob_X, uncon_prob_E, pred = get_prob(noisy_data, unconditioned=True)
            prob_X = uncon_prob_X * (prob_X / uncon_prob_X.clamp_min(1e-10)) ** self.guide_scale
            prob_E = uncon_prob_E * (prob_E / uncon_prob_E.clamp_min(1e-10)) ** self.guide_scale
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-10)

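The guidance step above follows the classifier-free-guidance pattern: conditional and unconditional predictions are combined multiplicatively, then renormalized. A minimal self-contained sketch of that reweighting (the function name is the editor's, not from the repo):

import torch

def guide(p_cond: torch.Tensor, p_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    """Reweight categorical probabilities as p_uncond * (p_cond / p_uncond)**scale,
    renormalized over the last dimension."""
    p = p_uncond * (p_cond / p_uncond.clamp_min(1e-10)) ** scale
    return p / p.sum(dim=-1, keepdim=True).clamp_min(1e-10)

With scale == 1 this recovers p_cond (up to the numerical clamps); scale > 1 sharpens the distribution toward the conditional prediction.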
@@ -647,32 +686,120 @@ class Graph_DiT(pl.LightningModule):
        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
        # sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())

        # sample multiple times and keep the best-scoring arch
        sample_num = 100
        num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
        op_type = {
            'input': 0,
            'nor_conv_1x1': 1,
            'nor_conv_3x3': 2,
            'avg_pool_3x3': 3,
            'skip_connect': 4,
            'none': 5,
            'output': 6,
        }
        def check_valid_graph(nodes, edges):
            # Reject anything that is not a well-formed NAS-Bench-201-style cell:
            # node/edge shapes must agree, the first/last nodes must be the
            # 'input'/'output' stubs, there are no self-loops, intermediate
            # nodes carry real ops, no edge enters 'input' or leaves 'output',
            # and at least one edge reaches the output node.
            nodes = [num_to_op[i] for i in nodes]
            if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
                return False
            if nodes[0] != 'input' or nodes[-1] != 'output':
                return False
            for i in range(len(nodes)):
                if edges[i][i] == 1:
                    return False
            for i in range(1, len(nodes) - 1):
                if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
                    return False
            for i in range(len(nodes)):
                for j in range(i, len(nodes)):
                    if edges[i, j] == 1 and nodes[j] == 'input':
                        return False
            for i in range(len(nodes)):
                for j in range(i, len(nodes)):
                    if edges[i, j] == 1 and nodes[i] == 'output':
                        return False
            # the output node must be reached by at least one edge
            if not any(edges[i, -1] == 1 for i in range(len(nodes))):
                return False
            return True

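A small worked example of these validity rules, using 4-node graphs for brevity (the graphs sampled in this commit have 8 nodes); it assumes check_valid_graph and num_to_op are in scope as defined above:

import numpy as np

# indices into num_to_op: input -> nor_conv_3x3 -> skip_connect -> output
nodes = [0, 2, 4, 6]
edges = np.array([[0, 1, 1, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1],
                  [0, 0, 0, 0]])
assert check_valid_graph(nodes, edges)   # well-formed DAG cell

bad = edges.copy()
bad[1, 1] = 1                            # self-loop on node 1 -> rejected
assert not check_valid_graph(nodes, bad)
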
        def get_score(sampled_s):
            x_list = sampled_s.X.unbind(dim=0)
            e_list = sampled_s.E.unbind(dim=0)
            valid_rlt = [check_valid_graph(x_list[i].cpu().numpy(), e_list[i].cpu().numpy()) for i in range(len(x_list))]
            score = []

            for i in range(len(x_list)):
                if valid_rlt[i]:
                    nodes = [num_to_op[j] for j in x_list[i].cpu().numpy()]
                    # edges = e_list[i].cpu().numpy()
                    score.append(get_nasbench201_nodes_score(nodes, train_loader=self.train_loader, searchspace=self.searchspace, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), args=self.args))
                else:
                    score.append(-1)
            # note: the scores are plain numbers, so this tensor is a fresh leaf
            # that is not connected to the model's computation graph
            return torch.tensor(score, dtype=torch.float32, requires_grad=True).to(x_list[0].device)

        sample_num = 10
        best_arch = None
        best_score = -1e8
        best_score_int = -1e8
        score = torch.ones(100, dtype=torch.float32, requires_grad=True) * -1e8

        for i in range(sample_num):
            sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())
            score = get_score(sampled_s)
            print(f'score: {score}')
            print(f'score.shape: {score.shape}')
            sum_score = torch.sum(score)
            print(f'sum_score: {sum_score}')
            if sum_score > best_score_int:
                best_score_int = sum_score
                best_score = score
                best_arch = sampled_s

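Note that get_score returns one naswot score per graph in the batch, while the loop above keeps or discards the whole batch by its sum, so a batch with a high total can still contain poor individual graphs. A sketch of a per-graph alternative (editor's variant under that observation, not what the commit does), reusing the names defined above:

import torch

best_score = torch.full((prob_X.shape[0],), -1e8, device=prob_X.device)
best_X, best_E = None, None

for _ in range(sample_num):
    cand = diffusion_utils.sample_discrete_features(
        prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item())
    cand_score = get_score(cand)            # shape: (bs,)
    improved = cand_score > best_score      # per-graph improvement mask
    if best_X is None:
        best_X, best_E = cand.X.clone(), cand.E.clone()
    best_X[improved] = cand.X[improved]
    best_E[improved] = cand.E[improved]
    best_score = torch.where(improved, cand_score, best_score)
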
        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
        # print(f'prob_X: {prob_X.shape}, prob_E: {prob_E.shape}')

        # best_arch = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask, step=s[0,0].item())

        # X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
        # E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()
        print(f'best_arch.X: {best_arch.X.shape}, best_arch.E: {best_arch.E.shape}')  # 100 8 8, bs n n; 100 8 8 2, bs n n 2

        print(f'best_arch.X: {best_arch.X}, best_arch.E: {best_arch.E}')
        X_s = F.one_hot(best_arch.X, num_classes=self.Xdim_output).float()
        E_s = F.one_hot(best_arch.E, num_classes=self.Edim_output).float()
        print(f'X_s: {X_s}, E_s: {E_s}')

        # NASWOT score
        target_score = torch.tensor([3000.0])
        target_score = torch.ones(100, requires_grad=True) * 2000.0
        target_score = target_score.to(X_s.device)

        # compute loss = mse(cur_score, target_score)
        mse_loss = torch.nn.MSELoss()
        print(f'best_score: {best_score.shape}, target_score: {target_score.shape}')
        print(f'best_score.requires_grad: {best_score.requires_grad}, target_score.requires_grad: {target_score.requires_grad}')
        loss = mse_loss(best_score, target_score)
        loss.backward(retain_graph=True)

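As written, best_score is built with torch.tensor(..., requires_grad=True) inside get_score, so it is a fresh leaf with no path back to pred, and loss.backward() cannot populate the pred.X.grad / pred.E.grad read below. A hedged sketch of the wiring this step appears to intend, assuming some differentiable surrogate score_fn(X, E) is available (score_fn is hypothetical; the naswot score itself is computed from discrete samples and is not differentiable):

import torch

# detach() yields leaf tensors we can legally flip requires_grad on
X_in = pred.X.detach().requires_grad_(True)
E_in = pred.E.detach().requires_grad_(True)
surrogate = score_fn(X_in, E_in)  # hypothetical differentiable score, shape (bs,)
loss = torch.nn.functional.mse_loss(surrogate, target_score.detach())
x_grad, e_grad = torch.autograd.grad(loss, (X_in, E_in))

# gradient-guidance step, mirroring the update below
X_s = X_in - beta_ratio * x_grad
E_s = E_in - beta_ratio * e_grad
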
 | 
			
		||||
        # loss.backward() populates the gradients
        # read the gradients w.r.t. pred.X / pred.E
        x_grad = pred.X.grad
        e_grad = pred.E.grad

        beta_ratio = 0.5
        # x_current = pred.X - beta_ratio * x_grad
        # e_current = pred.E - beta_ratio * e_grad
        X_s = pred.X - beta_ratio * x_grad
        E_s = pred.E - beta_ratio * e_grad

        # update prob_X / prob_E using the gradients