wrote the get_nasbench201_idx_score
This commit is contained in:
parent
93ced7700d
commit
caf462f310
@ -9,6 +9,7 @@ from scores import get_score_func
|
||||
from scipy import stats
|
||||
import time
|
||||
# from pycls.models.nas.nas import Cell
|
||||
from models import get_cell_based_tiny_net
|
||||
from utils import add_dropout, init_network
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS Without Training')
|
||||
@ -56,11 +57,22 @@ def get_batch_jacobian(net, x, target, device, args=None):
|
||||
jacob = x.grad.detach()
|
||||
return jacob, target.detach(), y.detach(), out.detach()
|
||||
|
||||
def get_nasbench201_idx_score(idx, train_loader, searchspace, args):
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device):
|
||||
op_type = {
|
||||
'input': 0,
|
||||
'nor_conv_1x1': 1,
|
||||
'nor_conv_3x3': 2,
|
||||
'avg_pool_3x3': 3,
|
||||
'skip_connect': 4,
|
||||
'none': 5,
|
||||
'output': 6,
|
||||
}
|
||||
|
||||
def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device):
|
||||
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
# searchspace = nasspace.get_search_space(args)
|
||||
if 'valid' in args.dataset:
|
||||
args.dataset = args.dataset.replace('-valid', '')
|
||||
# if 'valid' in args.dataset:
|
||||
# args.dataset = args.dataset.replace('-valid', '')
|
||||
|
||||
# train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
||||
# os.makedirs(args.save_loc, exist_ok=True)
|
||||
@ -182,17 +194,17 @@ train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, arg
|
||||
print('start to get score')
|
||||
print('5374')
|
||||
start_time = time.time()
|
||||
print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args))
|
||||
print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
|
||||
end_time = time.time()
|
||||
print(f'5374 time: {end_time - start_time}')
|
||||
print('5375')
|
||||
start_time = time.time()
|
||||
print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args))
|
||||
print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
|
||||
end_time = time.time()
|
||||
print(f'5375 time: {end_time - start_time}')
|
||||
print('5376')
|
||||
start_time = time.time()
|
||||
print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args))
|
||||
print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")))
|
||||
end_time = time.time()
|
||||
print(f'5376 time: {end_time - start_time}')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user