update the script to use nasbench-201 api
This commit is contained in:
		
							
								
								
									
										
											BIN
										
									
								
								graph_dit/exp_201/barplog.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								graph_dit/exp_201/barplog.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 30 KiB | 
| @@ -2,44 +2,45 @@ | ||||
| import matplotlib.pyplot as plt | ||||
| import pandas as pd | ||||
| from nas_201_api import NASBench201API as API | ||||
| from naswot.score_networks import get_nasbench201_idx_score | ||||
| from naswot import datasets as dt | ||||
| from naswot import nasspace | ||||
| # from naswot.score_networks import get_nasbench201_idx_score | ||||
| # from naswot import datasets as dt | ||||
| # from naswot import nasspace | ||||
|  | ||||
| class Args(): | ||||
|     pass | ||||
| args = Args() | ||||
| args.trainval = True | ||||
| args.augtype = 'none' | ||||
| args.repeat = 1 | ||||
| args.score = 'hook_logdet' | ||||
| args.sigma = 0.05 | ||||
| args.nasspace = 'nasbench201' | ||||
| args.batch_size = 128 | ||||
| args.GPU = '0' | ||||
| args.dataset = 'cifar10' | ||||
| args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| args.data_loc = '../cifardata/' | ||||
| args.seed = 777 | ||||
| args.init = '' | ||||
| args.save_loc = 'results' | ||||
| args.save_string = 'naswot' | ||||
| args.dropout = False | ||||
| args.maxofn = 1 | ||||
| args.n_samples = 100 | ||||
| args.n_runs = 500 | ||||
| args.stem_out_channels = 16 | ||||
| args.num_stacks = 3 | ||||
| args.num_modules_per_stack = 3 | ||||
| args.num_labels = 1 | ||||
| searchspace = nasspace.get_search_space(args) | ||||
| train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
| device = torch.device('cuda:2') | ||||
| # class Args(): | ||||
| #     pass | ||||
| # args = Args() | ||||
| # args.trainval = True | ||||
| # args.augtype = 'none' | ||||
| # args.repeat = 1 | ||||
| # args.score = 'hook_logdet' | ||||
| # args.sigma = 0.05 | ||||
| # args.nasspace = 'nasbench201' | ||||
| # args.batch_size = 128 | ||||
| # args.GPU = '0' | ||||
| # args.dataset = 'cifar10' | ||||
| # args.api_loc = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| # args.data_loc = '../cifardata/' | ||||
| # args.seed = 777 | ||||
| # args.init = '' | ||||
| # args.save_loc = 'results' | ||||
| # args.save_string = 'naswot' | ||||
| # args.dropout = False | ||||
| # args.maxofn = 1 | ||||
| # args.n_samples = 100 | ||||
| # args.n_runs = 500 | ||||
| # args.stem_out_channels = 16 | ||||
| # args.num_stacks = 3 | ||||
| # args.num_modules_per_stack = 3 | ||||
| # args.num_labels = 1 | ||||
| # searchspace = nasspace.get_search_space(args) | ||||
| # train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args) | ||||
| # device = torch.device('cuda:2') | ||||
|  | ||||
|  | ||||
| source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| api = API(source) | ||||
|  | ||||
|  | ||||
|           | ||||
| # source = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth' | ||||
| # api = API(source) | ||||
|  | ||||
|  | ||||
|  | ||||
| @@ -50,8 +51,10 @@ percentages = [] | ||||
| len_201 = 15625 | ||||
|  | ||||
| for i in range(len_201): | ||||
|     percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) | ||||
|     percentages.append(percentage) | ||||
|     # percentage = get_nasbench201_idx_score(i, train_loader, searchspace, args, device) | ||||
|     results = api.query_by_index(i, 'cifar10') | ||||
|     result = results[111].get_eval('ori-test') | ||||
|     percentages.append(result) | ||||
|  | ||||
| # 定义10%区间 | ||||
| bins = [i for i in range(0, 101, 10)] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user