wrote the get_nasbench201_idx_score
This commit is contained in:
		| @@ -9,6 +9,7 @@ from scores import get_score_func | ||||
| from scipy import stats | ||||
| import time | ||||
| # from pycls.models.nas.nas import Cell | ||||
| from models import get_cell_based_tiny_net | ||||
| from utils import add_dropout, init_network  | ||||
|  | ||||
| parser = argparse.ArgumentParser(description='NAS Without Training') | ||||
| @@ -56,11 +57,22 @@ def get_batch_jacobian(net, x, target, device, args=None): | ||||
|     jacob = x.grad.detach() | ||||
|     return jacob, target.detach(), y.detach(), out.detach() | ||||
|  | ||||
| def get_nasbench201_idx_score(idx, train_loader, searchspace, args): | ||||
|     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||||
| def get_nasbench201_nodes_score(nodes, train_loader, searchspace, args, device): | ||||
|     op_type = { | ||||
|     'input': 0, | ||||
|     'nor_conv_1x1': 1, | ||||
|     'nor_conv_3x3': 2, | ||||
|     'avg_pool_3x3': 3, | ||||
|     'skip_connect': 4, | ||||
|     'none': 5, | ||||
|     'output': 6, | ||||
|     } | ||||
|  | ||||
| def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): | ||||
|     # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||||
|     # searchspace = nasspace.get_search_space(args) | ||||
|     if 'valid' in args.dataset: | ||||
|         args.dataset = args.dataset.replace('-valid', '') | ||||
|     # if 'valid' in args.dataset: | ||||
|     #     args.dataset = args.dataset.replace('-valid', '') | ||||
|  | ||||
|     # train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
|     # os.makedirs(args.save_loc, exist_ok=True) | ||||
| @@ -182,17 +194,17 @@ train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, arg | ||||
| print('start to get score') | ||||
| print('5374') | ||||
| start_time = time.time() | ||||
| print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args)) | ||||
| print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| end_time = time.time() | ||||
| print(f'5374 time: {end_time - start_time}') | ||||
| print('5375') | ||||
| start_time = time.time() | ||||
| print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args)) | ||||
| print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| end_time = time.time() | ||||
| print(f'5375 time: {end_time - start_time}') | ||||
| print('5376') | ||||
| start_time = time.time() | ||||
| print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args)) | ||||
| print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| end_time = time.time() | ||||
| print(f'5376 time: {end_time - start_time}') | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user