Compare commits

2 Commits: 33452adc3b...4df5615380

	| Author | SHA1 | Date | |
|---|---|---|---|
| 4df5615380 | |||
| 968157b657 | |||
| analyze.py (20 changes) | |||||||
| @@ -4,16 +4,26 @@ from scipy import stats | |||||||
| import pandas as pd | import pandas as pd | ||||||
| import argparse | import argparse | ||||||
|  |  | ||||||
| def plot(l,filename): | def plot(l, thousands, filename): | ||||||
|     lenth = len(l) |     lenth = len(l) | ||||||
|     threshold = [0, 10000, 20000, 30000, 40000, 50000, 60000, 70000] |     threshold = [0, 10000, 20000, 30000, 40000, 50000, 60000, 70000] | ||||||
|     labels = ['0-10k', '10k-20k,', '20k-30k', '30k-40k', '40k-50k', '50k-60k', '60k-70k'] |     labels = ['0-10k', '10k-20k,', '20k-30k', '30k-40k', '40k-50k', '50k-60k', '60k-70k'] | ||||||
|     l = [i/15625 for i in l] |     l = [i/lenth for i in l] | ||||||
|     l = l[:7] |     l = l[:7] | ||||||
|  |     thousands = thousands[60:] | ||||||
|  |     thousands_labels = [str(i) + 'k' for i in range(60, 70)] | ||||||
|  |     plt.figure(figsize=(8, 6)) | ||||||
|  |     plt.subplots_adjust(top=0.85) | ||||||
|  |     plt.title('Distribution of Swap Scores over 60k') | ||||||
|  |     plt.bar(thousands_labels, thousands) | ||||||
|  |     for i, v in enumerate(thousands): | ||||||
|  |         plt.text(i, v + 0.01, str(v), ha='center', va='bottom') | ||||||
|  |     plt.savefig(filename + '_60k.png')  | ||||||
|  |  | ||||||
|     datasets = filename.split('_')[-1].split('.')[0] |     datasets = filename.split('_')[-1].split('.')[0] | ||||||
|     plt.figure(figsize=(8, 6)) |     plt.figure(figsize=(8, 6)) | ||||||
|     plt.subplots_adjust(top=0.85) |     plt.subplots_adjust(top=0.85) | ||||||
|     plt.ylim(0,0.3) |     # plt.ylim(0,0.3) | ||||||
|     plt.title('Distribution of Swap Scores in ' + datasets) |     plt.title('Distribution of Swap Scores in ' + datasets) | ||||||
|     plt.bar(labels, l) |     plt.bar(labels, l) | ||||||
|     for i, v in enumerate(l): |     for i, v in enumerate(l): | ||||||
| @@ -29,6 +39,7 @@ def analyse(filename): | |||||||
|         reader = csv.reader(file) |         reader = csv.reader(file) | ||||||
|         header = next(reader) |         header = next(reader) | ||||||
|         data = [row for row in reader] |         data = [row for row in reader] | ||||||
|  |         thousands = [0 for i in range(70)] | ||||||
|          |          | ||||||
|         for row in data: |         for row in data: | ||||||
|             score = row[0] |             score = row[0] | ||||||
| @@ -37,6 +48,7 @@ def analyse(filename): | |||||||
|             ind = float(score) // 10000 |             ind = float(score) // 10000 | ||||||
|             ind = int(ind) |             ind = int(ind) | ||||||
|             l[ind] += 1 |             l[ind] += 1 | ||||||
|  |             thousands[int(float(score) // 1000)] += 1 | ||||||
|             acc = row[1] |             acc = row[1] | ||||||
|             index = row[2] |             index = row[2] | ||||||
|             datas = list(zip(score, acc, index)) |             datas = list(zip(score, acc, index)) | ||||||
| @@ -45,7 +57,7 @@ def analyse(filename): | |||||||
|     results = pd.DataFrame(datas, columns=['swap_score', 'valid_acc', 'index']) |     results = pd.DataFrame(datas, columns=['swap_score', 'valid_acc', 'index']) | ||||||
|     print(results['swap_score'].max()) |     print(results['swap_score'].max()) | ||||||
|     print(best_value) |     print(best_value) | ||||||
|     plot(l, filename + '.png') |     plot(l, thousands, filename + '.png') | ||||||
|     return stats.spearmanr(results.swap_score, results.valid_acc)[0] |     return stats.spearmanr(results.swap_score, results.valid_acc)[0] | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ | |||||||
|  |  | ||||||
| # # 加载CIFAR-10数据集 | # # 加载CIFAR-10数据集 | ||||||
| # transform = transforms.Compose([transforms.ToTensor()]) | # transform = transforms.Compose([transforms.ToTensor()]) | ||||||
| # trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) | # trainset = torchvision.datasets.CIFAR10(root='./datasets', train=True, download=True, transform=transform) | ||||||
| # trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2) | # trainloader = torch.utils.data.DataLoader(trainset, batch_size=10000, shuffle=False, num_workers=2) | ||||||
|  |  | ||||||
| # # 将所有数据加载到内存中 | # # 将所有数据加载到内存中 | ||||||
| @@ -18,6 +18,10 @@ | |||||||
| # print(f'Mean: {mean}') | # print(f'Mean: {mean}') | ||||||
| # print(f'Std: {std}') | # print(f'Std: {std}') | ||||||
|  |  | ||||||
|  | # results: | ||||||
|  | # Mean: tensor([0.4935, 0.4834, 0.4472]) | ||||||
|  | # Std: tensor([0.2476, 0.2446, 0.2626])   | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
| from torchvision import datasets, transforms | from torchvision import datasets, transforms | ||||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||||
| @@ -35,6 +39,7 @@ dataset_name = args.dataset | |||||||
|  |  | ||||||
| # 设置数据集的transform(这里只使用了ToTensor) | # 设置数据集的transform(这里只使用了ToTensor) | ||||||
| transform = transforms.Compose([ | transform = transforms.Compose([ | ||||||
|  |     transforms.Resize((224, 224)), | ||||||
|     transforms.ToTensor() |     transforms.ToTensor() | ||||||
| ]) | ]) | ||||||
|  |  | ||||||
| @@ -47,7 +52,10 @@ mean = torch.zeros(3) | |||||||
| std = torch.zeros(3) | std = torch.zeros(3) | ||||||
| nb_samples = 0 | nb_samples = 0 | ||||||
|  |  | ||||||
|  | count = 0 | ||||||
| for data in dataloader: | for data in dataloader: | ||||||
|  |     count += 1 | ||||||
|  |     print(f'Processing batch {count}/{len(dataloader)}', end='\r') | ||||||
|     batch_samples = data[0].size(0) |     batch_samples = data[0].size(0) | ||||||
|     data = data[0].view(batch_samples, data[0].size(1), -1) |     data = data[0].view(batch_samples, data[0].size(1), -1) | ||||||
|     mean += data.mean(2).sum(0) |     mean += data.mean(2).sum(0) | ||||||
|   | |||||||
| @@ -40,9 +40,9 @@ parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup | |||||||
| parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric') | parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric') | ||||||
| parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric') | parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric') | ||||||
| parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets') | parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets') | ||||||
|  | parser.add_argument('--start_index', default=0, type=int, help='start index of the networks to evaluate') | ||||||
|  |  | ||||||
| args = parser.parse_args() | args = parser.parse_args() | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|      |      | ||||||
|     device = torch.device(args.device) |     device = torch.device(args.device) | ||||||
| @@ -58,18 +58,21 @@ if __name__ == "__main__": | |||||||
|  |  | ||||||
|     # nasbench_len = 15625 |     # nasbench_len = 15625 | ||||||
|     nasbench_len = 15625 |     nasbench_len = 15625 | ||||||
|  |     filename = f'output/swap_results_{args.datasets}.csv' | ||||||
|  |     if args.datasets == 'aircraft': | ||||||
|  |         api_datasets = 'cifar10' | ||||||
|      |      | ||||||
|     # for index, i in arch_info.iterrows(): |     # for index, i in arch_info.iterrows(): | ||||||
|     for ind in range(nasbench_len): |     for ind in range(args.start_index,nasbench_len): | ||||||
|         # print(f'Evaluating network: {index}') |         # print(f'Evaluating network: {index}') | ||||||
|         print(f'Evaluating network: {ind}') |         print(f'Evaluating network: {ind}') | ||||||
|  |         config = api.get_net_config(ind, api_datasets) | ||||||
|         config = api.get_net_config(ind, args.datasets) |  | ||||||
|         network = get_cell_based_tiny_net(config) |         network = get_cell_based_tiny_net(config) | ||||||
|         # nas_results = api.query_by_index(i, 'cifar10') |         # nas_results = api.query_by_index(i, 'cifar10') | ||||||
|         # acc = nas_results[111].get_eval('ori-test') |         # acc = nas_results[111].get_eval('ori-test') | ||||||
|         nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False) |         # nas_results = api.get_more_info(ind, api_datasets, None, hp=200, is_random=False) | ||||||
|         acc = nas_results['test-accuracy'] |         # acc = nas_results['test-accuracy'] | ||||||
|  |         acc = 99 | ||||||
|  |  | ||||||
|         # print(type(network)) |         # print(type(network)) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
| @@ -98,6 +101,8 @@ if __name__ == "__main__": | |||||||
|         print(f'Elapsed time: {end_time - start_time:.2f} seconds') |         print(f'Elapsed time: {end_time - start_time:.2f} seconds') | ||||||
|  |  | ||||||
|         results.append([np.mean(swap_score), acc, ind]) |         results.append([np.mean(swap_score), acc, ind]) | ||||||
|  |         with open(filename, 'a') as f: | ||||||
|  |             f.write(f'{np.mean(swap_score)},{acc},{ind}\n') | ||||||
|  |  | ||||||
|     results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) |     results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index']) | ||||||
|     results.to_csv('output/swap_results.csv', float_format='%.4f', index=False) |     results.to_csv('output/swap_results.csv', float_format='%.4f', index=False) | ||||||
|   | |||||||
| @@ -3,21 +3,28 @@ import shutil | |||||||
|  |  | ||||||
| # 数据集路径 | # 数据集路径 | ||||||
| dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images' | dataset_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images' | ||||||
| output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/sorted_images' | test_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/test_sorted_images' | ||||||
|  | train_output_path = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/train_sorted_images' | ||||||
|  |  | ||||||
| # 类别文件,例如 'images_variant_trainval.txt' | # 类别文件,例如 'images_variant_trainval.txt' | ||||||
| labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt' | # 有两个文件,一个是训练集和验证集,一个是测试集 | ||||||
|  | test_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_test.txt' | ||||||
|  | train_labels_file = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data/images_variant_train.txt' | ||||||
|  |  | ||||||
| # 创建输出文件夹 | # 创建输出文件夹 | ||||||
| if not os.path.exists(output_path): | if not os.path.exists(test_output_path): | ||||||
|     os.makedirs(output_path) |     os.makedirs(test_output_path) | ||||||
|  | if not os.path.exists(train_output_path): | ||||||
|  |     os.makedirs(train_output_path) | ||||||
|  |  | ||||||
| # 读取类别文件 | # 读取类别文件 | ||||||
| with open(labels_file, 'r') as f: | with open(test_labels_file, 'r') as f: | ||||||
|     lines = f.readlines() |     test_lines = f.readlines() | ||||||
|  | with open(train_labels_file, 'r') as f: | ||||||
|  |     train_lines = f.readlines() | ||||||
|  |  | ||||||
|  | def sort_images(lines, output_path): | ||||||
|     count = 0 |     count = 0 | ||||||
|  |  | ||||||
|     for line in lines: |     for line in lines: | ||||||
|         count += 1 |         count += 1 | ||||||
|         print(f'Processing image {count}/{len(lines)}', end='\r') |         print(f'Processing image {count}/{len(lines)}', end='\r') | ||||||
| @@ -38,4 +45,9 @@ for line in lines: | |||||||
|         else: |         else: | ||||||
|             print(f'Image {image_name} not found!') |             print(f'Image {image_name} not found!') | ||||||
|  |  | ||||||
|  | print("Sorting test images into folders by category...") | ||||||
|  | sort_images(test_lines, test_output_path) | ||||||
|  | print("Sorting train images into folders by category...") | ||||||
|  | sort_images(train_lines, train_output_path) | ||||||
|  |  | ||||||
| print("Images have been sorted into folders by category.") | print("Images have been sorted into folders by category.") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user