try to pack the naswot
This commit is contained in:
		| @@ -11,7 +11,7 @@ __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_ci | |||||||
|            ] |            ] | ||||||
| 
 | 
 | ||||||
| # useful modules | # useful modules | ||||||
| from config_utils import dict2config | from naswot.config_utils import dict2config | ||||||
| from .SharedUtils import change_key | from .SharedUtils import change_key | ||||||
| from .cell_searchs import CellStructure, CellArchitectures | from .cell_searchs import CellStructure, CellArchitectures | ||||||
| 
 | 
 | ||||||
| @@ -1,16 +1,16 @@ | |||||||
| from models import get_cell_based_tiny_net, get_search_spaces | from naswot.models import get_cell_based_tiny_net, get_search_spaces | ||||||
| from nas_201_api import NASBench201API as API | from nas_201_api import NASBench201API as API | ||||||
| from nasbench import api as nasbench101api | from nasbench import api as nasbench101api | ||||||
| from nas_101_api.model import Network | from naswot.nas_101_api.model import Network | ||||||
| from nas_101_api.model_spec import ModelSpec | from naswot.nas_101_api.model_spec import ModelSpec | ||||||
| import itertools | import itertools | ||||||
| import random | import random | ||||||
| import numpy as np | import numpy as np | ||||||
| from models.cell_searchs.genotypes import Structure | from naswot.models.cell_searchs.genotypes import Structure | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR | from naswot.pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR | ||||||
| from pycls.models.anynet import AnyNet | from naswot.pycls.models.anynet import AnyNet | ||||||
| from pycls.models.nas.genotypes import GENOTYPES, Genotype | from naswot.pycls.models.nas.genotypes import GENOTYPES, Genotype | ||||||
| import json | import json | ||||||
| import torch | import torch | ||||||
| 
 | 
 | ||||||
| @@ -26,6 +26,7 @@ class Nasbench201: | |||||||
|         print(config) |         print(config) | ||||||
|         config['num_classes'] = 1 |         config['num_classes'] = 1 | ||||||
|         network = get_cell_based_tiny_net(config) |         network = get_cell_based_tiny_net(config) | ||||||
|  |         print(network) | ||||||
|         return network |         return network | ||||||
|     def __iter__(self): |     def __iter__(self): | ||||||
|         for uid in range(len(self)): |         for uid in range(len(self)): | ||||||
							
								
								
									
										0
									
								
								graph_dit/naswot/naswot/pycls/core/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								graph_dit/naswot/naswot/pycls/core/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								graph_dit/naswot/naswot/pycls/models/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								graph_dit/naswot/naswot/pycls/models/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -1,16 +1,16 @@ | |||||||
| import argparse | import argparse | ||||||
| import nasspace | from naswot import nasspace | ||||||
| import datasets | import datasets | ||||||
| import random | import random | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch | import torch | ||||||
| import os | import os | ||||||
| from scores import get_score_func | from naswot.scores import get_score_func | ||||||
| from scipy import stats | from scipy import stats | ||||||
| import time | import time | ||||||
| # from pycls.models.nas.nas import Cell | # from pycls.models.nas.nas import Cell | ||||||
| from models import get_cell_based_tiny_net | from naswot.models import get_cell_based_tiny_net | ||||||
| from utils import add_dropout, init_network  | from naswot.utils import add_dropout, init_network  | ||||||
| 
 | 
 | ||||||
| parser = argparse.ArgumentParser(description='NAS Without Training') | parser = argparse.ArgumentParser(description='NAS Without Training') | ||||||
| parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder') | parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder') | ||||||
| @@ -57,8 +57,92 @@ def get_batch_jacobian(net, x, target, device, args=None): | |||||||
|     jacob = x.grad.detach() |     jacob = x.grad.detach() | ||||||
|     return jacob, target.detach(), y.detach(), out.detach() |     return jacob, target.detach(), y.detach(), out.detach() | ||||||
| 
 | 
 | ||||||
| def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): | def get_config_by_nodes(nodes): | ||||||
|     num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] |     num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||||
|  |     arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \ | ||||||
|  |                 num_to_op[nodes[2]] + '~0|' + num_to_op[nodes[3]] + '~1|+|' + \ | ||||||
|  |                 num_to_op[nodes[4]] + '~0|' + num_to_op[nodes[5]] + '~1|' + num_to_op[nodes[6]] + '~2|' | ||||||
|  |     config = { | ||||||
|  |         'name': 'infer.tiny', | ||||||
|  |         'C': 16, | ||||||
|  |         'N': 5, | ||||||
|  |         'arch_str': arch_str, | ||||||
|  |         'num_classes': 10, | ||||||
|  |     } | ||||||
|  |     return config | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): | ||||||
|  |     assert len(nodes) == 8 | ||||||
|  |     network = get_cell_based_tiny_net(get_config_by_nodes(nodes)) | ||||||
|  |     try: | ||||||
|  |         if args.dropout: | ||||||
|  |             add_dropout(network, args.sigma) | ||||||
|  |         if args.init != '': | ||||||
|  |             init_network(network, args.init) | ||||||
|  |         if 'hook_' in args.score: | ||||||
|  |             network.K = np.zeros((args.batch_size, args.batch_size)) | ||||||
|  |             def counting_forward_hook(module, inp, out): | ||||||
|  |                 try: | ||||||
|  |                     if not module.visited_backwards: | ||||||
|  |                         return | ||||||
|  |                     if isinstance(inp, tuple): | ||||||
|  |                         # print(len(inp)) | ||||||
|  |                         inp = inp[0] | ||||||
|  |                     inp = inp.view(inp.size(0), -1) | ||||||
|  |                     x = (inp > 0).float() | ||||||
|  |                     K = x @ x.t() | ||||||
|  |                     K2 = (1.-x) @ (1.-x.t()) | ||||||
|  |                     network.K = network.K + K.cpu().numpy() + K2.cpu().numpy() | ||||||
|  |                 except: | ||||||
|  |                     pass | ||||||
|  | 
 | ||||||
|  |                  | ||||||
|  |             def counting_backward_hook(module, inp, out): | ||||||
|  |                 module.visited_backwards = True | ||||||
|  | 
 | ||||||
|  |                  | ||||||
|  |             for name, module in network.named_modules(): | ||||||
|  |                 if 'ReLU' in str(type(module)): | ||||||
|  |                     #hooks[name] = module.register_forward_hook(counting_hook) | ||||||
|  |                     module.register_forward_hook(counting_forward_hook) | ||||||
|  |                     module.register_backward_hook(counting_backward_hook) | ||||||
|  | 
 | ||||||
|  |         network = network.to(device) | ||||||
|  |         random.seed(args.seed) | ||||||
|  |         np.random.seed(args.seed) | ||||||
|  |         torch.manual_seed(args.seed) | ||||||
|  |         s = [] | ||||||
|  |         for j in range(args.maxofn): | ||||||
|  |             data_iterator = iter(train_loader) | ||||||
|  |             x, target = next(data_iterator) | ||||||
|  |             x2 = torch.clone(x) | ||||||
|  |             x2 = x2.to(device) | ||||||
|  |             x, target = x.to(device), target.to(device) | ||||||
|  |             jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |             if 'hook_' in args.score: | ||||||
|  |                 network(x2.to(device)) | ||||||
|  |                 s.append(get_score_func(args.score)(network.K, target)) | ||||||
|  |             else: | ||||||
|  |                 s.append(get_score_func(args.score)(jacobs, labels)) | ||||||
|  |         return np.mean(s) | ||||||
|  |         scores[i] = np.mean(s) | ||||||
|  |         accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval) | ||||||
|  |         accs_ = accs[~np.isnan(scores)] | ||||||
|  |         scores_ = scores[~np.isnan(scores)] | ||||||
|  |         numnan = np.isnan(scores).sum() | ||||||
|  |         tau, p = stats.kendalltau(accs_[:max(i-numnan, 1)], scores_[:max(i-numnan, 1)]) | ||||||
|  |         print(f'{tau}') | ||||||
|  |         if i % 1000 == 0: | ||||||
|  |             np.save(filename, scores) | ||||||
|  |             np.save(accfilename, accs) | ||||||
|  |     except Exception as e: | ||||||
|  |         print(e) | ||||||
|  |     print('final result') | ||||||
|  |     return np.nan | ||||||
|  |      | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @@ -181,12 +265,19 @@ if 'valid' in args.dataset: | |||||||
|     args.dataset = args.dataset.replace('-valid', '') |     args.dataset = args.dataset.replace('-valid', '') | ||||||
| print('start to get search space') | print('start to get search space') | ||||||
| start_time = time.time() | start_time = time.time() | ||||||
|  | print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6])) | ||||||
|  | end_time = time.time() | ||||||
|  | start_time = time.time() | ||||||
| searchspace = nasspace.get_search_space(args) | searchspace = nasspace.get_search_space(args) | ||||||
| end_time = time.time() | end_time = time.time() | ||||||
| print(f'search space time: {end_time - start_time}') | print(f'search space time: {end_time - start_time}') | ||||||
| train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||||
| print('start to get score') | print('start to get score') | ||||||
| print('5374') | print('5374') | ||||||
|  | num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||||
|  | start_time = time.time() | ||||||
|  | print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||||
|  | end_time = time.time() | ||||||
| start_time = time.time() | start_time = time.time() | ||||||
| print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||||
| end_time = time.time() | end_time = time.time() | ||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user