try to pack the naswot
This commit is contained in:
parent
13f77963d0
commit
fd4f0452f9
@ -11,7 +11,7 @@ __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_ci
|
||||
]
|
||||
|
||||
# useful modules
|
||||
from config_utils import dict2config
|
||||
from naswot.config_utils import dict2config
|
||||
from .SharedUtils import change_key
|
||||
from .cell_searchs import CellStructure, CellArchitectures
|
||||
|
@ -1,16 +1,16 @@
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from naswot.models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nas_201_api import NASBench201API as API
|
||||
from nasbench import api as nasbench101api
|
||||
from nas_101_api.model import Network
|
||||
from nas_101_api.model_spec import ModelSpec
|
||||
from naswot.nas_101_api.model import Network
|
||||
from naswot.nas_101_api.model_spec import ModelSpec
|
||||
import itertools
|
||||
import random
|
||||
import numpy as np
|
||||
from models.cell_searchs.genotypes import Structure
|
||||
from naswot.models.cell_searchs.genotypes import Structure
|
||||
from copy import deepcopy
|
||||
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
|
||||
from pycls.models.anynet import AnyNet
|
||||
from pycls.models.nas.genotypes import GENOTYPES, Genotype
|
||||
from naswot.pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
|
||||
from naswot.pycls.models.anynet import AnyNet
|
||||
from naswot.pycls.models.nas.genotypes import GENOTYPES, Genotype
|
||||
import json
|
||||
import torch
|
||||
|
||||
@ -26,6 +26,7 @@ class Nasbench201:
|
||||
print(config)
|
||||
config['num_classes'] = 1
|
||||
network = get_cell_based_tiny_net(config)
|
||||
print(network)
|
||||
return network
|
||||
def __iter__(self):
|
||||
for uid in range(len(self)):
|
0
graph_dit/naswot/naswot/pycls/core/__init__.py
Normal file
0
graph_dit/naswot/naswot/pycls/core/__init__.py
Normal file
0
graph_dit/naswot/naswot/pycls/models/__init__.py
Normal file
0
graph_dit/naswot/naswot/pycls/models/__init__.py
Normal file
@ -1,16 +1,16 @@
|
||||
import argparse
|
||||
import nasspace
|
||||
from naswot import nasspace
|
||||
import datasets
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
from scores import get_score_func
|
||||
from naswot.scores import get_score_func
|
||||
from scipy import stats
|
||||
import time
|
||||
# from pycls.models.nas.nas import Cell
|
||||
from models import get_cell_based_tiny_net
|
||||
from utils import add_dropout, init_network
|
||||
from naswot.models import get_cell_based_tiny_net
|
||||
from naswot.utils import add_dropout, init_network
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS Without Training')
|
||||
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
|
||||
@ -57,11 +57,95 @@ def get_batch_jacobian(net, x, target, device, args=None):
|
||||
jacob = x.grad.detach()
|
||||
return jacob, target.detach(), y.detach(), out.detach()
|
||||
|
||||
def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device):
|
||||
def get_config_by_nodes(nodes):
|
||||
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
|
||||
arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \
|
||||
num_to_op[nodes[2]] + '~0|' + num_to_op[nodes[3]] + '~1|+|' + \
|
||||
num_to_op[nodes[4]] + '~0|' + num_to_op[nodes[5]] + '~1|' + num_to_op[nodes[6]] + '~2|'
|
||||
config = {
|
||||
'name': 'infer.tiny',
|
||||
'C': 16,
|
||||
'N': 5,
|
||||
'arch_str': arch_str,
|
||||
'num_classes': 10,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device):
|
||||
assert len(nodes) == 8
|
||||
network = get_cell_based_tiny_net(get_config_by_nodes(nodes))
|
||||
try:
|
||||
if args.dropout:
|
||||
add_dropout(network, args.sigma)
|
||||
if args.init != '':
|
||||
init_network(network, args.init)
|
||||
if 'hook_' in args.score:
|
||||
network.K = np.zeros((args.batch_size, args.batch_size))
|
||||
def counting_forward_hook(module, inp, out):
|
||||
try:
|
||||
if not module.visited_backwards:
|
||||
return
|
||||
if isinstance(inp, tuple):
|
||||
# print(len(inp))
|
||||
inp = inp[0]
|
||||
inp = inp.view(inp.size(0), -1)
|
||||
x = (inp > 0).float()
|
||||
K = x @ x.t()
|
||||
K2 = (1.-x) @ (1.-x.t())
|
||||
network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def counting_backward_hook(module, inp, out):
|
||||
module.visited_backwards = True
|
||||
|
||||
|
||||
for name, module in network.named_modules():
|
||||
if 'ReLU' in str(type(module)):
|
||||
#hooks[name] = module.register_forward_hook(counting_hook)
|
||||
module.register_forward_hook(counting_forward_hook)
|
||||
module.register_backward_hook(counting_backward_hook)
|
||||
|
||||
network = network.to(device)
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
s = []
|
||||
for j in range(args.maxofn):
|
||||
data_iterator = iter(train_loader)
|
||||
x, target = next(data_iterator)
|
||||
x2 = torch.clone(x)
|
||||
x2 = x2.to(device)
|
||||
x, target = x.to(device), target.to(device)
|
||||
jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args)
|
||||
|
||||
|
||||
if 'hook_' in args.score:
|
||||
network(x2.to(device))
|
||||
s.append(get_score_func(args.score)(network.K, target))
|
||||
else:
|
||||
s.append(get_score_func(args.score)(jacobs, labels))
|
||||
return np.mean(s)
|
||||
scores[i] = np.mean(s)
|
||||
accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval)
|
||||
accs_ = accs[~np.isnan(scores)]
|
||||
scores_ = scores[~np.isnan(scores)]
|
||||
numnan = np.isnan(scores).sum()
|
||||
tau, p = stats.kendalltau(accs_[:max(i-numnan, 1)], scores_[:max(i-numnan, 1)])
|
||||
print(f'{tau}')
|
||||
if i % 1000 == 0:
|
||||
np.save(filename, scores)
|
||||
np.save(accfilename, accs)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print('final result')
|
||||
return np.nan
|
||||
|
||||
|
||||
|
||||
|
||||
def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device):
|
||||
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
# searchspace = nasspace.get_search_space(args)
|
||||
@ -181,12 +265,19 @@ if 'valid' in args.dataset:
|
||||
args.dataset = args.dataset.replace('-valid', '')
|
||||
print('start to get search space')
|
||||
start_time = time.time()
|
||||
print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6]))
|
||||
end_time = time.time()
|
||||
start_time = time.time()
|
||||
searchspace = nasspace.get_search_space(args)
|
||||
end_time = time.time()
|
||||
print(f'search space time: {end_time - start_time}')
|
||||
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
||||
print('start to get score')
|
||||
print('5374')
|
||||
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
|
||||
start_time = time.time()
|
||||
print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
|
||||
end_time = time.time()
|
||||
start_time = time.time()
|
||||
print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
|
||||
end_time = time.time()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user