try to pack the naswot

This commit is contained in:
mhz 2024-07-28 23:45:02 +02:00
parent 13f77963d0
commit fd4f0452f9
102 changed files with 112 additions and 13 deletions

View File

@ -11,7 +11,7 @@ __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_ci
]
# useful modules
from config_utils import dict2config
from naswot.config_utils import dict2config
from .SharedUtils import change_key
from .cell_searchs import CellStructure, CellArchitectures

View File

@ -1,16 +1,16 @@
from models import get_cell_based_tiny_net, get_search_spaces
from naswot.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
from nasbench import api as nasbench101api
from nas_101_api.model import Network
from nas_101_api.model_spec import ModelSpec
from naswot.nas_101_api.model import Network
from naswot.nas_101_api.model_spec import ModelSpec
import itertools
import random
import numpy as np
from models.cell_searchs.genotypes import Structure
from naswot.models.cell_searchs.genotypes import Structure
from copy import deepcopy
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from pycls.models.anynet import AnyNet
from pycls.models.nas.genotypes import GENOTYPES, Genotype
from naswot.pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from naswot.pycls.models.anynet import AnyNet
from naswot.pycls.models.nas.genotypes import GENOTYPES, Genotype
import json
import torch
@ -26,6 +26,7 @@ class Nasbench201:
print(config)
config['num_classes'] = 1
network = get_cell_based_tiny_net(config)
print(network)
return network
def __iter__(self):
for uid in range(len(self)):

View File

@ -1,16 +1,16 @@
import argparse
import nasspace
from naswot import nasspace
import datasets
import random
import numpy as np
import torch
import os
from scores import get_score_func
from naswot.scores import get_score_func
from scipy import stats
import time
# from pycls.models.nas.nas import Cell
from models import get_cell_based_tiny_net
from utils import add_dropout, init_network
from naswot.models import get_cell_based_tiny_net
from naswot.utils import add_dropout, init_network
parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
@ -57,11 +57,95 @@ def get_batch_jacobian(net, x, target, device, args=None):
jacob = x.grad.detach()
return jacob, target.detach(), y.detach(), out.detach()
def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device):
def get_config_by_nodes(nodes):
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \
num_to_op[nodes[2]] + '~0|' + num_to_op[nodes[3]] + '~1|+|' + \
num_to_op[nodes[4]] + '~0|' + num_to_op[nodes[5]] + '~1|' + num_to_op[nodes[6]] + '~2|'
config = {
'name': 'infer.tiny',
'C': 16,
'N': 5,
'arch_str': arch_str,
'num_classes': 10,
}
return config
def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device):
assert len(nodes) == 8
network = get_cell_based_tiny_net(get_config_by_nodes(nodes))
try:
if args.dropout:
add_dropout(network, args.sigma)
if args.init != '':
init_network(network, args.init)
if 'hook_' in args.score:
network.K = np.zeros((args.batch_size, args.batch_size))
def counting_forward_hook(module, inp, out):
try:
if not module.visited_backwards:
return
if isinstance(inp, tuple):
# print(len(inp))
inp = inp[0]
inp = inp.view(inp.size(0), -1)
x = (inp > 0).float()
K = x @ x.t()
K2 = (1.-x) @ (1.-x.t())
network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
except:
pass
def counting_backward_hook(module, inp, out):
module.visited_backwards = True
for name, module in network.named_modules():
if 'ReLU' in str(type(module)):
#hooks[name] = module.register_forward_hook(counting_hook)
module.register_forward_hook(counting_forward_hook)
module.register_backward_hook(counting_backward_hook)
network = network.to(device)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
s = []
for j in range(args.maxofn):
data_iterator = iter(train_loader)
x, target = next(data_iterator)
x2 = torch.clone(x)
x2 = x2.to(device)
x, target = x.to(device), target.to(device)
jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args)
if 'hook_' in args.score:
network(x2.to(device))
s.append(get_score_func(args.score)(network.K, target))
else:
s.append(get_score_func(args.score)(jacobs, labels))
return np.mean(s)
scores[i] = np.mean(s)
accs[i] = searchspace.get_final_accuracy(uid, acc_type, args.trainval)
accs_ = accs[~np.isnan(scores)]
scores_ = scores[~np.isnan(scores)]
numnan = np.isnan(scores).sum()
tau, p = stats.kendalltau(accs_[:max(i-numnan, 1)], scores_[:max(i-numnan, 1)])
print(f'{tau}')
if i % 1000 == 0:
np.save(filename, scores)
np.save(accfilename, accs)
except Exception as e:
print(e)
print('final result')
return np.nan
def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device):
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# searchspace = nasspace.get_search_space(args)
@ -181,12 +265,19 @@ if 'valid' in args.dataset:
args.dataset = args.dataset.replace('-valid', '')
print('start to get search space')
start_time = time.time()
print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6]))
end_time = time.time()
start_time = time.time()
searchspace = nasspace.get_search_space(args)
end_time = time.time()
print(f'search space time: {end_time - start_time}')
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
print('start to get score')
print('5374')
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
start_time = time.time()
print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()
start_time = time.time()
print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()

Some files were not shown because too many files have changed in this diff Show More