try to pack the naswot
This commit is contained in:
		| @@ -11,7 +11,7 @@ __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_ci | ||||
|            ] | ||||
| 
 | ||||
| # useful modules | ||||
| from config_utils import dict2config | ||||
| from naswot.config_utils import dict2config | ||||
| from .SharedUtils import change_key | ||||
| from .cell_searchs import CellStructure, CellArchitectures | ||||
| 
 | ||||
| @@ -1,16 +1,16 @@ | ||||
| from models import get_cell_based_tiny_net, get_search_spaces | ||||
| from naswot.models import get_cell_based_tiny_net, get_search_spaces | ||||
| from nas_201_api import NASBench201API as API | ||||
| from nasbench import api as nasbench101api | ||||
| from nas_101_api.model import Network | ||||
| from nas_101_api.model_spec import ModelSpec | ||||
| from naswot.nas_101_api.model import Network | ||||
| from naswot.nas_101_api.model_spec import ModelSpec | ||||
| import itertools | ||||
| import random | ||||
| import numpy as np | ||||
| from models.cell_searchs.genotypes import Structure | ||||
| from naswot.models.cell_searchs.genotypes import Structure | ||||
| from copy import deepcopy | ||||
| from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR | ||||
| from pycls.models.anynet import AnyNet | ||||
| from pycls.models.nas.genotypes import GENOTYPES, Genotype | ||||
| from naswot.pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR | ||||
| from naswot.pycls.models.anynet import AnyNet | ||||
| from naswot.pycls.models.nas.genotypes import GENOTYPES, Genotype | ||||
| import json | ||||
| import torch | ||||
| 
 | ||||
| @@ -26,6 +26,7 @@ class Nasbench201: | ||||
|         print(config) | ||||
|         config['num_classes'] = 1 | ||||
|         network = get_cell_based_tiny_net(config) | ||||
|         print(network) | ||||
|         return network | ||||
|     def __iter__(self): | ||||
|         for uid in range(len(self)): | ||||
							
								
								
									
										0
									
								
								graph_dit/naswot/naswot/pycls/core/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								graph_dit/naswot/naswot/pycls/core/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								graph_dit/naswot/naswot/pycls/models/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								graph_dit/naswot/naswot/pycls/models/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -1,16 +1,16 @@ | ||||
| import argparse | ||||
| import nasspace | ||||
| from naswot import nasspace | ||||
| import datasets | ||||
| import random | ||||
| import numpy as np | ||||
| import torch | ||||
| import os | ||||
| from scores import get_score_func | ||||
| from naswot.scores import get_score_func | ||||
| from scipy import stats | ||||
| import time | ||||
| # from pycls.models.nas.nas import Cell | ||||
| from models import get_cell_based_tiny_net | ||||
| from utils import add_dropout, init_network  | ||||
| from naswot.models import get_cell_based_tiny_net | ||||
| from naswot.utils import add_dropout, init_network  | ||||
| 
 | ||||
| parser = argparse.ArgumentParser(description='NAS Without Training') | ||||
| parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder') | ||||
| @@ -57,11 +57,95 @@ def get_batch_jacobian(net, x, target, device, args=None): | ||||
|     jacob = x.grad.detach() | ||||
|     return jacob, target.detach(), y.detach(), out.detach() | ||||
| 
 | ||||
| def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): | ||||
| def get_config_by_nodes(nodes): | ||||
|     num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
|     arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \ | ||||
|                 num_to_op[nodes[2]] + '~0|' + num_to_op[nodes[3]] + '~1|+|' + \ | ||||
|                 num_to_op[nodes[4]] + '~0|' + num_to_op[nodes[5]] + '~1|' + num_to_op[nodes[6]] + '~2|' | ||||
|     config = { | ||||
|         'name': 'infer.tiny', | ||||
|         'C': 16, | ||||
|         'N': 5, | ||||
|         'arch_str': arch_str, | ||||
|         'num_classes': 10, | ||||
|     } | ||||
|     return config | ||||
| 
 | ||||
| 
 | ||||
| def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): | ||||
|     assert len(nodes) == 8 | ||||
|     network = get_cell_based_tiny_net(get_config_by_nodes(nodes)) | ||||
|     try: | ||||
|         if args.dropout: | ||||
|             add_dropout(network, args.sigma) | ||||
|         if args.init != '': | ||||
|             init_network(network, args.init) | ||||
|         if 'hook_' in args.score: | ||||
|             network.K = np.zeros((args.batch_size, args.batch_size)) | ||||
|             def counting_forward_hook(module, inp, out): | ||||
|                 try: | ||||
|                     if not module.visited_backwards: | ||||
|                         return | ||||
|                     if isinstance(inp, tuple): | ||||
|                         # print(len(inp)) | ||||
|                         inp = inp[0] | ||||
|                     inp = inp.view(inp.size(0), -1) | ||||
|                     x = (inp > 0).float() | ||||
|                     K = x @ x.t() | ||||
|                     K2 = (1.-x) @ (1.-x.t()) | ||||
|                     network.K = network.K + K.cpu().numpy() + K2.cpu().numpy() | ||||
|                 except: | ||||
|                     pass | ||||
| 
 | ||||
|                  | ||||
|             def counting_backward_hook(module, inp, out): | ||||
|                 module.visited_backwards = True | ||||
| 
 | ||||
|                  | ||||
|             for name, module in network.named_modules(): | ||||
|                 if 'ReLU' in str(type(module)): | ||||
|                     #hooks[name] = module.register_forward_hook(counting_hook) | ||||
|                     module.register_forward_hook(counting_forward_hook) | ||||
|                     module.register_backward_hook(counting_backward_hook) | ||||
| 
 | ||||
|         network = network.to(device) | ||||
|         random.seed(args.seed) | ||||
|         np.random.seed(args.seed) | ||||
|         torch.manual_seed(args.seed) | ||||
|         s = [] | ||||
|         for j in range(args.maxofn): | ||||
|             data_iterator = iter(train_loader) | ||||
|             x, target = next(data_iterator) | ||||
|             x2 = torch.clone(x) | ||||
|             x2 = x2.to(device) | ||||
|             x, target = x.to(device), target.to(device) | ||||
|             jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args) | ||||
| 
 | ||||
| 
 | ||||
|             if 'hook_' in args.score: | ||||
|                 network(x2.to(device)) | ||||
|                 s.append(get_score_func(args.score)(network.K, target)) | ||||
|             else: | ||||
|                 s.append(get_score_func(args.score)(jacobs, labels)) | ||||
|         return np.mean(s) | ||||
|         scores[i] = np.mean(s) | ||||
|         accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval) | ||||
|         accs_ = accs[~np.isnan(scores)] | ||||
|         scores_ = scores[~np.isnan(scores)] | ||||
|         numnan = np.isnan(scores).sum() | ||||
|         tau, p = stats.kendalltau(accs_[:max(i-numnan, 1)], scores_[:max(i-numnan, 1)]) | ||||
|         print(f'{tau}') | ||||
|         if i % 1000 == 0: | ||||
|             np.save(filename, scores) | ||||
|             np.save(accfilename, accs) | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|     print('final result') | ||||
|     return np.nan | ||||
|      | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): | ||||
|     # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||||
|     # searchspace = nasspace.get_search_space(args) | ||||
| @@ -181,12 +265,19 @@ if 'valid' in args.dataset: | ||||
|     args.dataset = args.dataset.replace('-valid', '') | ||||
| print('start to get search space') | ||||
| start_time = time.time() | ||||
| print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6])) | ||||
| end_time = time.time() | ||||
| start_time = time.time() | ||||
| searchspace = nasspace.get_search_space(args) | ||||
| end_time = time.time() | ||||
| print(f'search space time: {end_time - start_time}') | ||||
| train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
| print('start to get score') | ||||
| print('5374') | ||||
| num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
| start_time = time.time() | ||||
| print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| end_time = time.time() | ||||
| start_time = time.time() | ||||
| print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| end_time = time.time() | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user