add a datasets option to specify the dataset you want; add a plot script
analyze.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import csv

import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

def plot(l):
    # Plot the fraction of networks falling into each 10k-wide score bin;
    # 15625 is the size of the NAS-Bench-201 search space.
    labels = ['0-10k', '10k-20k', '20k-30k', '30k-40k', '40k-50k', '50k-60k', '60k-70k']
    l = [i / 15625 for i in l[:7]]
    plt.bar(labels, l)
    plt.savefig('plot.png')

def analyse(filename):
    bins = [0] * 10  # score histogram, one bin per 10k-wide range
    rows = []
    best_value = -1.0
    with open(filename) as file:
        reader = csv.reader(file)
        next(reader)  # skip the header row
        for row in reader:
            score, acc, index = float(row[0]), float(row[1]), int(row[2])
            best_value = max(best_value, score)
            bins[int(score // 10000)] += 1
            rows.append((score, acc, index))
    results = pd.DataFrame(rows, columns=['swap_score', 'valid_acc', 'index'])
    print(results['swap_score'].max())
    print(best_value)
    plot(bins)
    # Rank correlation between the training-free score and validation accuracy.
    return stats.spearmanr(results.swap_score, results.valid_acc)[0]

if __name__ == '__main__':
    print(analyse('output/swap_results.csv'))
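For reference, analyse() expects a three-column CSV (score, validation accuracy, architecture index) with one header row. A minimal smoke test under that assumption, using made-up values and illustrative column names, run from a session that imports analyse:

import csv
import os

os.makedirs('output', exist_ok=True)
with open('output/swap_results.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['swap_score', 'valid_acc', 'index'])  # header row, skipped by analyse()
    writer.writerows([(41234.0, 91.2, 0), (38765.0, 90.4, 1), (55120.0, 93.1, 2)])
print(analyse('output/swap_results.csv'))  # Spearman rho over the three rows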
@@ -39,6 +39,7 @@ parser.add_argument('--seed', default=0, type=int, help='random seed')
 parser.add_argument('--device', default="cuda", type=str, nargs='?', help='setup device (cpu, mps or cuda)')
 parser.add_argument('--repeats', default=32, type=int, nargs='?', help='times of calculating the training-free metric')
 parser.add_argument('--input_samples', default=16, type=int, nargs='?', help='input batch size for training-free metric')
+parser.add_argument('--datasets', default='cifar10', type=str, help='input datasets')
 
 args = parser.parse_args()
 
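With the new flag, the metric script can be pointed at another dataset from the command line; the entry-point name below is assumed, since the diff view does not show which file is being patched:

python main.py --datasets cifar100 --device cuda --repeats 32 --input_samples 16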
@@ -48,7 +49,7 @@ if __name__ == "__main__":
 
     # arch_info = pd.read_csv(args.data_path+'/DARTS_archs_CIFAR10.csv', names=['genotype', 'valid_acc'], sep=',')
     
-    train_data, _, _ = get_datasets('cifar10', args.data_path, (args.input_samples, 3, 32, 32), -1)
+    train_data, _, _ = get_datasets(args.datasets, args.data_path, (args.input_samples, 3, 32, 32), -1)
     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.input_samples, num_workers=0, pin_memory=True)
     loader = iter(train_loader)
     inputs, _ = next(loader)
@@ -63,11 +64,11 @@ if __name__ == "__main__":
         # print(f'Evaluating network: {index}')
         print(f'Evaluating network: {ind}')
 
-        config = api.get_net_config(ind, 'cifar10')
+        config = api.get_net_config(ind, args.datasets)
         network = get_cell_based_tiny_net(config)
         # nas_results = api.query_by_index(i, 'cifar10')
         # acc = nas_results[111].get_eval('ori-test')
-        nas_results = api.get_more_info(ind, 'cifar10', None, hp=200, is_random=False)
+        nas_results = api.get_more_info(ind, args.datasets, None, hp=200, is_random=False)
         acc = nas_results['test-accuracy']
 
         # print(type(network))
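Note that args.datasets is forwarded unchanged to the NAS-Bench-201 API, so it should be a dataset name the benchmark defines ('cifar10', 'cifar100', 'ImageNet16-120', plus 'cifar10-valid' for the API calls), and the shape passed to get_datasets stays hard-coded at 3x32x32, which matches the CIFAR variants but not ImageNet16-120's 16x16 inputs. A hypothetical fail-fast guard, not part of this commit:

# Hypothetical guard: reject dataset names the benchmark does not know.
assert args.datasets in ('cifar10', 'cifar10-valid', 'cifar100', 'ImageNet16-120'), args.datasets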