| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | ############################################################### | 
					
						
							| 
									
										
										
										
											2021-01-25 21:48:14 +08:00
										 |  |  | # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  # | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | ############################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | 
					
						
							|  |  |  | ############################################################### | 
					
						
							|  |  |  | # Usage: python exps/NATS-Bench/draw-correlations.py          # | 
					
						
							|  |  |  | ############################################################### | 
					
						
							|  |  |  | import os, gc, sys, time, scipy, torch, argparse | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | from typing import List, Text, Dict, Any | 
					
						
							|  |  |  | from shutil import copyfile | 
					
						
							|  |  |  | from collections import defaultdict, OrderedDict | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | from copy import deepcopy | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | from pathlib import Path | 
					
						
							|  |  |  | import matplotlib | 
					
						
							|  |  |  | import seaborn as sns | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | matplotlib.use("agg") | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | import matplotlib.pyplot as plt | 
					
						
							|  |  |  | import matplotlib.ticker as ticker | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: | 
					
						
							|  |  |  |     sys.path.insert(0, str(lib_dir)) | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | from config_utils import dict2config, load_config | 
					
						
							|  |  |  | from nats_bench import create | 
					
						
							|  |  |  | from log_utils import time_string | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_valid_test_acc(api, arch, dataset): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     is_size_space = api.search_space_name == "size" | 
					
						
							|  |  |  |     if dataset == "cifar10": | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         xinfo = api.get_more_info( | 
					
						
							|  |  |  |             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |         test_acc = xinfo["test-accuracy"] | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         xinfo = api.get_more_info( | 
					
						
							|  |  |  |             arch, | 
					
						
							|  |  |  |             dataset="cifar10-valid", | 
					
						
							|  |  |  |             hp=90 if is_size_space else 200, | 
					
						
							|  |  |  |             is_random=False, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |         valid_acc = xinfo["valid-accuracy"] | 
					
						
							|  |  |  |     else: | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         xinfo = api.get_more_info( | 
					
						
							|  |  |  |             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |         valid_acc = xinfo["valid-accuracy"] | 
					
						
							|  |  |  |         test_acc = xinfo["test-accuracy"] | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     return ( | 
					
						
							|  |  |  |         valid_acc, | 
					
						
							|  |  |  |         test_acc, | 
					
						
							|  |  |  |         "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc), | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def compute_kendalltau(vectori, vectorj): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     # indexes = list(range(len(vectori))) | 
					
						
							|  |  |  |     # rank_1 = sorted(indexes, key=lambda i: vectori[i]) | 
					
						
							|  |  |  |     # rank_2 = sorted(indexes, key=lambda i: vectorj[i]) | 
					
						
							|  |  |  |     # import pdb; pdb.set_trace() | 
					
						
							|  |  |  |     coef, p = scipy.stats.kendalltau(vectori, vectorj) | 
					
						
							|  |  |  |     return coef | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def compute_spearmanr(vectori, vectorj): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     coef, p = scipy.stats.spearmanr(vectori, vectorj) | 
					
						
							|  |  |  |     return coef | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser( | 
					
						
							|  |  |  |         description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size", | 
					
						
							|  |  |  |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         "--save_dir", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="output/vis-nas-bench/nas-algos", | 
					
						
							|  |  |  |         help="Folder to save checkpoints and log.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--search_space", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         choices=["tss", "sss"], | 
					
						
							|  |  |  |         help="Choose the search space.", | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  |     args = parser.parse_args() | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     save_dir = Path(args.save_dir) | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     api = create(None, "tss", fast_mode=True, verbose=False) | 
					
						
							|  |  |  |     indexes = list(range(1, 10000, 300)) | 
					
						
							|  |  |  |     scores_1 = [] | 
					
						
							|  |  |  |     scores_2 = [] | 
					
						
							|  |  |  |     for index in indexes: | 
					
						
							|  |  |  |         valid_acc, test_acc, _ = get_valid_test_acc(api, index, "cifar10") | 
					
						
							|  |  |  |         scores_1.append(valid_acc) | 
					
						
							|  |  |  |         scores_2.append(test_acc) | 
					
						
							|  |  |  |     correlation = compute_kendalltau(scores_1, scores_2) | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     print( | 
					
						
							|  |  |  |         "The kendall tau correlation of {:} samples : {:}".format( | 
					
						
							|  |  |  |             len(indexes), correlation | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     correlation = compute_spearmanr(scores_1, scores_2) | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     print( | 
					
						
							|  |  |  |         "The spearmanr correlation of {:} samples : {:}".format( | 
					
						
							|  |  |  |             len(indexes), correlation | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     # scores_1 = ['{:.2f}'.format(x) for x in scores_1] | 
					
						
							|  |  |  |     # scores_2 = ['{:.2f}'.format(x) for x in scores_2] | 
					
						
							|  |  |  |     # print(', '.join(scores_1)) | 
					
						
							|  |  |  |     # print(', '.join(scores_2)) | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     dpi, width, height = 250, 1000, 1000 | 
					
						
							|  |  |  |     figsize = width / float(dpi), height / float(dpi) | 
					
						
							|  |  |  |     LabelSize, LegendFontsize = 14, 14 | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     fig, ax = plt.subplots(1, 1, figsize=figsize) | 
					
						
							|  |  |  |     ax.scatter(scores_1, scores_2, marker="^", s=0.5, c="tab:green", alpha=0.8) | 
					
						
							| 
									
										
										
										
											2020-11-30 00:48:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     save_path = "/Users/xuanyidong/Desktop/test-temp-rank.png" | 
					
						
							|  |  |  |     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | 
					
						
							|  |  |  |     plt.close("all") |