wrote the get_nasbench201_idx_score

This commit is contained in:
mhz 2024-07-27 17:47:40 +02:00
parent 93ced7700d
commit caf462f310

View File

@ -9,6 +9,7 @@ from scores import get_score_func
from scipy import stats
import time
# from pycls.models.nas.nas import Cell
from models import get_cell_based_tiny_net
from utils import add_dropout, init_network
parser = argparse.ArgumentParser(description='NAS Without Training')
@ -56,11 +57,22 @@ def get_batch_jacobian(net, x, target, device, args=None):
jacob = x.grad.detach()
return jacob, target.detach(), y.detach(), out.detach()
def get_nasbench201_idx_score(idx, train_loader, searchspace, args):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device):
op_type = {
'input': 0,
'nor_conv_1x1': 1,
'nor_conv_3x3': 2,
'avg_pool_3x3': 3,
'skip_connect': 4,
'none': 5,
'output': 6,
}
def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device):
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# searchspace = nasspace.get_search_space(args)
if 'valid' in args.dataset:
args.dataset = args.dataset.replace('-valid', '')
# if 'valid' in args.dataset:
# args.dataset = args.dataset.replace('-valid', '')
# train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
# os.makedirs(args.save_loc, exist_ok=True)
@ -182,17 +194,17 @@ train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, arg
print('start to get score')
print('5374')
start_time = time.time()
print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args))
print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()
print(f'5374 time: {end_time - start_time}')
print('5375')
start_time = time.time()
print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args))
print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()
print(f'5375 time: {end_time - start_time}')
print('5376')
start_time = time.time()
print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args))
print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
end_time = time.time()
print(f'5376 time: {end_time - start_time}')