update the naswot code
This commit is contained in:
		
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -49,7 +49,7 @@ def get_data(dataset, data_loc, trainval, batch_size, augtype, repeat, args, pin | ||||
|         val_acc_type = 'x-valid' | ||||
|      | ||||
|     if trainval and 'cifar10' in dataset: | ||||
|         cifar_split = load_config('config_utils/cifar-split.txt', None, None) | ||||
|         cifar_split = load_config('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/naswot/naswot/config_utils/cifardata/cifar-split.txt', None, None) | ||||
|         train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|         if repeat > 0: | ||||
|             train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, | ||||
|   | ||||
							
								
								
									
										1
									
								
								graph_dit/naswot/naswot/nasbench
									
									
									
									
									
										Submodule
									
								
							
							
								
								
								
								
								
							
						
						
									
										1
									
								
								graph_dit/naswot/naswot/nasbench
									
									
									
									
									
										Submodule
									
								
							 Submodule graph_dit/naswot/naswot/nasbench added at b94247037e
									
								
							| @@ -12,41 +12,41 @@ import time | ||||
| from naswot.models import get_cell_based_tiny_net | ||||
| from naswot.utils import add_dropout, init_network  | ||||
|  | ||||
| parser = argparse.ArgumentParser(description='NAS Without Training') | ||||
| parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder') | ||||
| parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth', | ||||
|                     type=str, help='path to API') | ||||
| parser.add_argument('--save_loc', default='results', type=str, help='folder to save results') | ||||
| parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file') | ||||
| parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate') | ||||
| parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use') | ||||
| parser.add_argument('--batch_size', default=128, type=int) | ||||
| parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch') | ||||
| parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use') | ||||
| parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"') | ||||
| parser.add_argument('--GPU', default='0', type=str) | ||||
| parser.add_argument('--seed', default=1, type=int) | ||||
| parser.add_argument('--init', default='', type=str) | ||||
| parser.add_argument('--trainval', action='store_true') | ||||
| parser.add_argument('--dropout', action='store_true') | ||||
| parser.add_argument('--dataset', default='cifar10', type=str) | ||||
| parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network') | ||||
| parser.add_argument('--n_samples', default=100, type=int) | ||||
| parser.add_argument('--n_runs', default=500, type=int) | ||||
| parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)') | ||||
| parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)') | ||||
| parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)') | ||||
| parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)') | ||||
| # parser = argparse.ArgumentParser(description='NAS Without Training') | ||||
| # parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder') | ||||
| # parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth', | ||||
| #                     type=str, help='path to API') | ||||
| # parser.add_argument('--save_loc', default='results', type=str, help='folder to save results') | ||||
| # parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file') | ||||
| # parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate') | ||||
| # parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use') | ||||
| # parser.add_argument('--batch_size', default=128, type=int) | ||||
| # parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch') | ||||
| # parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use') | ||||
| # parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"') | ||||
| # parser.add_argument('--GPU', default='0', type=str) | ||||
| # parser.add_argument('--seed', default=1, type=int) | ||||
| # parser.add_argument('--init', default='', type=str) | ||||
| # parser.add_argument('--trainval', action='store_true') | ||||
| # parser.add_argument('--dropout', action='store_true') | ||||
| # parser.add_argument('--dataset', default='cifar10', type=str) | ||||
| # parser.add_argument('--maxofn', default=1, type=int, help='score is the max of this many evaluations of the network') | ||||
| # parser.add_argument('--n_samples', default=100, type=int) | ||||
| # parser.add_argument('--n_runs', default=500, type=int) | ||||
| # parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)') | ||||
| # parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)') | ||||
| # parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)') | ||||
| # parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)') | ||||
|  | ||||
| args = parser.parse_args() | ||||
| os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU | ||||
| # args = parser.parse_args() | ||||
| # os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU | ||||
|  | ||||
| # Reproducibility | ||||
| torch.backends.cudnn.deterministic = True | ||||
| torch.backends.cudnn.benchmark = False | ||||
| random.seed(args.seed) | ||||
| np.random.seed(args.seed) | ||||
| torch.manual_seed(args.seed) | ||||
| # # Reproducibility | ||||
| # torch.backends.cudnn.deterministic = True | ||||
| # torch.backends.cudnn.benchmark = False | ||||
| # random.seed(args.seed) | ||||
| # np.random.seed(args.seed) | ||||
| # torch.manual_seed(args.seed) | ||||
|  | ||||
|  | ||||
| def get_batch_jacobian(net, x, target, device, args=None): | ||||
| @@ -58,10 +58,16 @@ def get_batch_jacobian(net, x, target, device, args=None): | ||||
|     return jacob, target.detach(), y.detach(), out.detach() | ||||
|  | ||||
| def get_config_by_nodes(nodes): | ||||
|     num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
|     arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \ | ||||
|     # check if the nodes[0] is a number | ||||
|     if isinstance(nodes[0], int): | ||||
|         num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
|         arch_str = '|' + num_to_op[nodes[1]] + '~0|+|' + \ | ||||
|                 num_to_op[nodes[2]] + '~0|' + num_to_op[nodes[3]] + '~1|+|' + \ | ||||
|                 num_to_op[nodes[4]] + '~0|' + num_to_op[nodes[5]] + '~1|' + num_to_op[nodes[6]] + '~2|' | ||||
|     else: | ||||
|         arch_str = '|' + nodes[1] + '~0|+|' + \ | ||||
|                 nodes[2] + '~0|' + nodes[3] + '~1|+|' + \ | ||||
|                 nodes[4] + '~0|' + nodes[5] + '~1|' + nodes[6] + '~2|' | ||||
|     config = { | ||||
|         'name': 'infer.tiny', | ||||
|         'C': 16, | ||||
| @@ -234,64 +240,64 @@ def get_nasbench201_idx_score(idx, train_loader, searchspace, args, device): | ||||
|     print('final result') | ||||
|     return np.nan | ||||
|  | ||||
| class Args: | ||||
|     pass | ||||
| args = Args() | ||||
| args.trainval = True | ||||
| args.augtype = 'none' | ||||
| args.repeat = 1 | ||||
| args.score = 'hook_logdet' | ||||
| args.sigma = 0.05 | ||||
| args.nasspace = 'nasbench201' | ||||
| args.batch_size = 128 | ||||
| args.GPU = '0' | ||||
| args.dataset = 'cifar10-valid' | ||||
| args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| args.data_loc = '../cifardata/' | ||||
| args.seed = 777 | ||||
| args.init = '' | ||||
| args.save_loc = 'results' | ||||
| args.save_string = 'naswot' | ||||
| args.dropout = False | ||||
| args.maxofn = 1 | ||||
| args.n_samples = 100 | ||||
| args.n_runs = 500 | ||||
| args.stem_out_channels = 16 | ||||
| args.num_stacks = 3 | ||||
| args.num_modules_per_stack = 3 | ||||
| args.num_labels = 1 | ||||
| # class Args: | ||||
| #     pass | ||||
| # args = Args() | ||||
| # args.trainval = True | ||||
| # args.augtype = 'none' | ||||
| # args.repeat = 1 | ||||
| # args.score = 'hook_logdet' | ||||
| # args.sigma = 0.05 | ||||
| # args.nasspace = 'nasbench201' | ||||
| # args.batch_size = 128 | ||||
| # args.GPU = '0' | ||||
| # args.dataset = 'cifar10-valid' | ||||
| # args.api_loc = '/home/stud/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| # args.data_loc = '../cifardata/' | ||||
| # args.seed = 777 | ||||
| # args.init = '' | ||||
| # args.save_loc = 'results' | ||||
| # args.save_string = 'naswot' | ||||
| # args.dropout = False | ||||
| # args.maxofn = 1 | ||||
| # args.n_samples = 100 | ||||
| # args.n_runs = 500 | ||||
| # args.stem_out_channels = 16 | ||||
| # args.num_stacks = 3 | ||||
| # args.num_modules_per_stack = 3 | ||||
| # args.num_labels = 1 | ||||
|  | ||||
| if 'valid' in args.dataset: | ||||
|     args.dataset = args.dataset.replace('-valid', '') | ||||
| print('start to get search space') | ||||
| start_time = time.time() | ||||
| print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6])) | ||||
| end_time = time.time() | ||||
| start_time = time.time() | ||||
| searchspace = nasspace.get_search_space(args) | ||||
| end_time = time.time() | ||||
| print(f'search space time: {end_time - start_time}') | ||||
| train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
| print('start to get score') | ||||
| print('5374') | ||||
| num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
| start_time = time.time() | ||||
| print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| end_time = time.time() | ||||
| start_time = time.time() | ||||
| print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| end_time = time.time() | ||||
| print(f'5374 time: {end_time - start_time}') | ||||
| print('5375') | ||||
| start_time = time.time() | ||||
| print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| end_time = time.time() | ||||
| print(f'5375 time: {end_time - start_time}') | ||||
| print('5376') | ||||
| start_time = time.time() | ||||
| print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| end_time = time.time() | ||||
| print(f'5376 time: {end_time - start_time}') | ||||
| # if 'valid' in args.dataset: | ||||
| #     args.dataset = args.dataset.replace('-valid', '') | ||||
| # print('start to get search space') | ||||
| # start_time = time.time() | ||||
| # print(get_config_by_nodes(nodes=[0,2,2,3,4,2,4,6])) | ||||
| # end_time = time.time() | ||||
| # start_time = time.time() | ||||
| # searchspace = nasspace.get_search_space(args) | ||||
| # end_time = time.time() | ||||
| # print(f'search space time: {end_time - start_time}') | ||||
| # train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
| # print('start to get score') | ||||
| # print('5374') | ||||
| # num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output'] | ||||
| # start_time = time.time() | ||||
| # print(get_nasbench201_nodes_score(nodes=[0,2,2,3,4,2,4,6],train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| # end_time = time.time() | ||||
| # start_time = time.time() | ||||
| # print(get_nasbench201_idx_score(5374,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| # end_time = time.time() | ||||
| # print(f'5374 time: {end_time - start_time}') | ||||
| # print('5375') | ||||
| # start_time = time.time() | ||||
| # print(get_nasbench201_idx_score(5375,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| # end_time = time.time() | ||||
| # print(f'5375 time: {end_time - start_time}') | ||||
| # print('5376') | ||||
| # start_time = time.time() | ||||
| # print(get_nasbench201_idx_score(5376,train_loader=train_loader, searchspace=searchspace, args=args, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))) | ||||
| # end_time = time.time() | ||||
| # print(f'5376 time: {end_time - start_time}') | ||||
|  | ||||
| # device = "cuda:0" | ||||
| # dataset = dataset | ||||
|   | ||||
| @@ -3,5 +3,8 @@ from setuptools import setup, find_packages | ||||
| setup( | ||||
|     name='naswot', | ||||
|     version='0.1', | ||||
|     packages=find_packages() | ||||
|     packages=find_packages(), | ||||
|     package_data={ | ||||
|         'naswot': ['config_utils/cifardata/*'] | ||||
|     } | ||||
| ) | ||||
		Reference in New Issue
	
	Block a user