add autodl
This commit is contained in:
		
							
								
								
									
										53
									
								
								AutoDL-Projects/exps/NATS-Bench/Analyze-time.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								AutoDL-Projects/exps/NATS-Bench/Analyze-time.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          # | ||||
| ############################################################################## | ||||
| # python ./exps/NATS-Bench/Analyze-time.py                                   # | ||||
| ############################################################################## | ||||
| import os, sys, time, tqdm, argparse | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.datasets import get_datasets | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def show_time(api, epoch=12): | ||||
|     print("Show the time for {:} with {:}-epoch-training".format(api, epoch)) | ||||
|     all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0 | ||||
|     for index in tqdm.tqdm(range(len(api))): | ||||
|         info = api.get_more_info(index, "ImageNet16-120", hp=epoch) | ||||
|         imagenet_time = info["train-all-time"] | ||||
|         info = api.get_more_info(index, "cifar10-valid", hp=epoch) | ||||
|         cifar10_time = info["train-all-time"] | ||||
|         info = api.get_more_info(index, "cifar100", hp=epoch) | ||||
|         cifar100_time = info["train-all-time"] | ||||
|         # accumulate the time | ||||
|         all_cifar10_time += cifar10_time | ||||
|         all_cifar100_time += cifar100_time | ||||
|         all_imagenet_time += imagenet_time | ||||
|     print( | ||||
|         "The total training time for CIFAR-10        (held-out train set) is {:} seconds".format( | ||||
|             all_cifar10_time | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "The total training time for CIFAR-100       (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format( | ||||
|             all_cifar100_time, all_cifar100_time / all_cifar10_time | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format( | ||||
|             all_imagenet_time, all_imagenet_time / all_cifar10_time | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     api_nats_tss = create(None, "tss", fast_mode=True, verbose=False) | ||||
|     show_time(api_nats_tss, 12) | ||||
|  | ||||
|     api_nats_sss = create(None, "sss", fast_mode=True, verbose=False) | ||||
|     show_time(api_nats_sss, 12) | ||||
							
								
								
									
										123
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-correlations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-correlations.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,123 @@ | ||||
| ############################################################### | ||||
| # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/NATS-Bench/draw-correlations.py          # | ||||
| ############################################################### | ||||
| import os, gc, sys, time, scipy, torch, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict, OrderedDict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.log_utils import time_string | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def get_valid_test_acc(api, arch, dataset): | ||||
|     is_size_space = api.search_space_name == "size" | ||||
|     if dataset == "cifar10": | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, | ||||
|             dataset="cifar10-valid", | ||||
|             hp=90 if is_size_space else 200, | ||||
|             is_random=False, | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|     else: | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|     return ( | ||||
|         valid_acc, | ||||
|         test_acc, | ||||
|         "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc), | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def compute_kendalltau(vectori, vectorj): | ||||
|     # indexes = list(range(len(vectori))) | ||||
|     # rank_1 = sorted(indexes, key=lambda i: vectori[i]) | ||||
|     # rank_2 = sorted(indexes, key=lambda i: vectorj[i]) | ||||
|     # import pdb; pdb.set_trace() | ||||
|     coef, p = scipy.stats.kendalltau(vectori, vectorj) | ||||
|     return coef | ||||
|  | ||||
|  | ||||
| def compute_spearmanr(vectori, vectorj): | ||||
|     coef, p = scipy.stats.spearmanr(vectori, vectorj) | ||||
|     return coef | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench/nas-algos", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         choices=["tss", "sss"], | ||||
|         help="Choose the search space.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     save_dir = Path(args.save_dir) | ||||
|  | ||||
|     api = create(None, "tss", fast_mode=True, verbose=False) | ||||
|     indexes = list(range(1, 10000, 300)) | ||||
|     scores_1 = [] | ||||
|     scores_2 = [] | ||||
|     for index in indexes: | ||||
|         valid_acc, test_acc, _ = get_valid_test_acc(api, index, "cifar10") | ||||
|         scores_1.append(valid_acc) | ||||
|         scores_2.append(test_acc) | ||||
|     correlation = compute_kendalltau(scores_1, scores_2) | ||||
|     print( | ||||
|         "The kendall tau correlation of {:} samples : {:}".format( | ||||
|             len(indexes), correlation | ||||
|         ) | ||||
|     ) | ||||
|     correlation = compute_spearmanr(scores_1, scores_2) | ||||
|     print( | ||||
|         "The spearmanr correlation of {:} samples : {:}".format( | ||||
|             len(indexes), correlation | ||||
|         ) | ||||
|     ) | ||||
|     # scores_1 = ['{:.2f}'.format(x) for x in scores_1] | ||||
|     # scores_2 = ['{:.2f}'.format(x) for x in scores_2] | ||||
|     # print(', '.join(scores_1)) | ||||
|     # print(', '.join(scores_2)) | ||||
|  | ||||
|     dpi, width, height = 250, 1000, 1000 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 14, 14 | ||||
|  | ||||
|     fig, ax = plt.subplots(1, 1, figsize=figsize) | ||||
|     ax.scatter(scores_1, scores_2, marker="^", s=0.5, c="tab:green", alpha=0.8) | ||||
|  | ||||
|     save_path = "/Users/xuanyidong/Desktop/test-temp-rank.png" | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     plt.close("all") | ||||
							
								
								
									
										651
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-fig2_5.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										651
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-fig2_5.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,651 @@ | ||||
| ############################################################### | ||||
| # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  # | ||||
| # The code to draw Figure 2 / 3 / 4 / 5 in our paper.         # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/NATS-Bench/draw-fig2_5.py                # | ||||
| ############################################################### | ||||
| import os, sys, time, torch, argparse | ||||
| import scipy | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.models import get_cell_based_tiny_net | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def visualize_relative_info(api, vis_save_dir, indicator): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar10", indicator | ||||
|     ) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar100", indicator | ||||
|     ) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "ImageNet16-120", indicator | ||||
|     ) | ||||
|     cifar010_info = torch.load(cifar010_cache_path) | ||||
|     cifar100_info = torch.load(cifar100_cache_path) | ||||
|     imagenet_info = torch.load(imagenet_cache_path) | ||||
|     indexes = list(range(len(cifar010_info["params"]))) | ||||
|  | ||||
|     print("{:} start to visualize relative ranking".format(time_string())) | ||||
|  | ||||
|     cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info["test_accs"][i]) | ||||
|     cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info["test_accs"][i]) | ||||
|     imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info["test_accs"][i]) | ||||
|  | ||||
|     cifar100_labels, imagenet_labels = [], [] | ||||
|     for idx in cifar010_ord_indexes: | ||||
|         cifar100_labels.append(cifar100_ord_indexes.index(idx)) | ||||
|         imagenet_labels.append(imagenet_ord_indexes.index(idx)) | ||||
|     print("{:} prepare data done.".format(time_string())) | ||||
|  | ||||
|     dpi, width, height = 200, 1400, 800 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 18, 12 | ||||
|     resnet_scale, resnet_alpha = 120, 0.5 | ||||
|  | ||||
|     fig = plt.figure(figsize=figsize) | ||||
|     ax = fig.add_subplot(111) | ||||
|     plt.xlim(min(indexes), max(indexes)) | ||||
|     plt.ylim(min(indexes), max(indexes)) | ||||
|     # plt.ylabel('y').set_rotation(30) | ||||
|     plt.yticks( | ||||
|         np.arange(min(indexes), max(indexes), max(indexes) // 3), | ||||
|         fontsize=LegendFontsize, | ||||
|         rotation="vertical", | ||||
|     ) | ||||
|     plt.xticks( | ||||
|         np.arange(min(indexes), max(indexes), max(indexes) // 5), | ||||
|         fontsize=LegendFontsize, | ||||
|     ) | ||||
|     ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | ||||
|     ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8) | ||||
|     ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | ||||
|     ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10") | ||||
|     ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100") | ||||
|     ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120") | ||||
|     plt.grid(zorder=0) | ||||
|     ax.set_axisbelow(True) | ||||
|     plt.legend(loc=0, fontsize=LegendFontsize) | ||||
|     ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize) | ||||
|     ax.set_ylabel("architecture ranking", fontsize=LabelSize) | ||||
|     save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|     save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|  | ||||
|  | ||||
| def visualize_sss_info(api, dataset, vis_save_dir): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     print("{:} start to visualize {:} information".format(time_string(), dataset)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset) | ||||
|     if not cache_file_path.exists(): | ||||
|         print("Do not find cache file : {:}".format(cache_file_path)) | ||||
|         params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] | ||||
|         for index in range(len(api)): | ||||
|             cost_info = api.get_cost_info(index, dataset, hp="90") | ||||
|             params.append(cost_info["params"]) | ||||
|             flops.append(cost_info["flops"]) | ||||
|             # accuracy | ||||
|             info = api.get_more_info(index, dataset, hp="90", is_random=False) | ||||
|             train_accs.append(info["train-accuracy"]) | ||||
|             test_accs.append(info["test-accuracy"]) | ||||
|             if dataset == "cifar10": | ||||
|                 info = api.get_more_info( | ||||
|                     index, "cifar10-valid", hp="90", is_random=False | ||||
|                 ) | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|             else: | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|         info = { | ||||
|             "params": params, | ||||
|             "flops": flops, | ||||
|             "train_accs": train_accs, | ||||
|             "valid_accs": valid_accs, | ||||
|             "test_accs": test_accs, | ||||
|         } | ||||
|         torch.save(info, cache_file_path) | ||||
|     else: | ||||
|         print("Find cache file : {:}".format(cache_file_path)) | ||||
|         info = torch.load(cache_file_path) | ||||
|         params, flops, train_accs, valid_accs, test_accs = ( | ||||
|             info["params"], | ||||
|             info["flops"], | ||||
|             info["train_accs"], | ||||
|             info["valid_accs"], | ||||
|             info["test_accs"], | ||||
|         ) | ||||
|     print("{:} collect data done.".format(time_string())) | ||||
|  | ||||
|     # pyramid = ['8:16:32:48:64', '8:8:16:32:48', '8:8:16:16:32', '8:8:16:16:48', '8:8:16:16:64', '16:16:32:32:64', '32:32:64:64:64'] | ||||
|     pyramid = ["8:16:24:32:40", "8:16:32:48:64", "32:40:48:56:64"] | ||||
|     pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid] | ||||
|     largest_indexes = [api.query_index_by_arch("64:64:64:64:64")] | ||||
|  | ||||
|     indexes = list(range(len(params))) | ||||
|     dpi, width, height = 250, 8500, 1300 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 24, 24 | ||||
|     # resnet_scale, resnet_alpha = 120, 0.5 | ||||
|     xscale, xalpha = 120, 0.8 | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 4, figsize=figsize) | ||||
|     # ax1, ax2, ax3, ax4, ax5 = axs | ||||
|     for ax in axs: | ||||
|         for tick in ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|         ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) | ||||
|         for tick in ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|     ax1, ax2, ax3, ax4 = axs | ||||
|  | ||||
|     ax1.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax1.scatter( | ||||
|         [params[x] for x in pyramid_indexes], | ||||
|         [train_accs[x] for x in pyramid_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="Pyramid Structure", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax1.scatter( | ||||
|         [params[x] for x in largest_indexes], | ||||
|         [train_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax1.set_xlabel("#parameters (MB)", fontsize=LabelSize) | ||||
|     ax1.set_ylabel("train accuracy (%)", fontsize=LabelSize) | ||||
|     ax1.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax2.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax2.scatter( | ||||
|         [flops[x] for x in pyramid_indexes], | ||||
|         [train_accs[x] for x in pyramid_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="Pyramid Structure", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax2.scatter( | ||||
|         [flops[x] for x in largest_indexes], | ||||
|         [train_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax2.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | ||||
|     # ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) | ||||
|     ax2.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax3.scatter( | ||||
|         [params[x] for x in pyramid_indexes], | ||||
|         [test_accs[x] for x in pyramid_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="Pyramid Structure", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax3.scatter( | ||||
|         [params[x] for x in largest_indexes], | ||||
|         [test_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) | ||||
|     ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) | ||||
|     ax3.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax4.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax4.scatter( | ||||
|         [flops[x] for x in pyramid_indexes], | ||||
|         [test_accs[x] for x in pyramid_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="Pyramid Structure", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax4.scatter( | ||||
|         [flops[x] for x in largest_indexes], | ||||
|         [test_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | ||||
|     # ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize) | ||||
|     ax4.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     save_path = vis_save_dir / "sss-{:}.png".format(dataset.lower()) | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| def visualize_tss_info(api, dataset, vis_save_dir): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     print("{:} start to visualize {:} information".format(time_string(), dataset)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset) | ||||
|     if not cache_file_path.exists(): | ||||
|         print("Do not find cache file : {:}".format(cache_file_path)) | ||||
|         params, flops, train_accs, valid_accs, test_accs = [], [], [], [], [] | ||||
|         for index in range(len(api)): | ||||
|             cost_info = api.get_cost_info(index, dataset, hp="12") | ||||
|             params.append(cost_info["params"]) | ||||
|             flops.append(cost_info["flops"]) | ||||
|             # accuracy | ||||
|             info = api.get_more_info(index, dataset, hp="200", is_random=False) | ||||
|             train_accs.append(info["train-accuracy"]) | ||||
|             test_accs.append(info["test-accuracy"]) | ||||
|             if dataset == "cifar10": | ||||
|                 info = api.get_more_info( | ||||
|                     index, "cifar10-valid", hp="200", is_random=False | ||||
|                 ) | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|             else: | ||||
|                 valid_accs.append(info["valid-accuracy"]) | ||||
|             print("") | ||||
|         info = { | ||||
|             "params": params, | ||||
|             "flops": flops, | ||||
|             "train_accs": train_accs, | ||||
|             "valid_accs": valid_accs, | ||||
|             "test_accs": test_accs, | ||||
|         } | ||||
|         torch.save(info, cache_file_path) | ||||
|     else: | ||||
|         print("Find cache file : {:}".format(cache_file_path)) | ||||
|         info = torch.load(cache_file_path) | ||||
|         params, flops, train_accs, valid_accs, test_accs = ( | ||||
|             info["params"], | ||||
|             info["flops"], | ||||
|             info["train_accs"], | ||||
|             info["valid_accs"], | ||||
|             info["test_accs"], | ||||
|         ) | ||||
|     print("{:} collect data done.".format(time_string())) | ||||
|  | ||||
|     resnet = [ | ||||
|         "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|" | ||||
|     ] | ||||
|     resnet_indexes = [api.query_index_by_arch(x) for x in resnet] | ||||
|     largest_indexes = [ | ||||
|         api.query_index_by_arch( | ||||
|             "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|" | ||||
|         ) | ||||
|     ] | ||||
|  | ||||
|     indexes = list(range(len(params))) | ||||
|     dpi, width, height = 250, 8500, 1300 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 24, 24 | ||||
|     # resnet_scale, resnet_alpha = 120, 0.5 | ||||
|     xscale, xalpha = 120, 0.8 | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 4, figsize=figsize) | ||||
|     for ax in axs: | ||||
|         for tick in ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|         ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f")) | ||||
|         for tick in ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|     ax1, ax2, ax3, ax4 = axs | ||||
|  | ||||
|     ax1.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax1.scatter( | ||||
|         [params[x] for x in resnet_indexes], | ||||
|         [train_accs[x] for x in resnet_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="ResNet", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax1.scatter( | ||||
|         [params[x] for x in largest_indexes], | ||||
|         [train_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax1.set_xlabel("#parameters (MB)", fontsize=LabelSize) | ||||
|     ax1.set_ylabel("train accuracy (%)", fontsize=LabelSize) | ||||
|     ax1.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax2.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax2.scatter( | ||||
|         [flops[x] for x in resnet_indexes], | ||||
|         [train_accs[x] for x in resnet_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="ResNet", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax2.scatter( | ||||
|         [flops[x] for x in largest_indexes], | ||||
|         [train_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax2.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | ||||
|     # ax2.set_ylabel('train accuracy (%)', fontsize=LabelSize) | ||||
|     ax2.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax3.scatter( | ||||
|         [params[x] for x in resnet_indexes], | ||||
|         [test_accs[x] for x in resnet_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="ResNet", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax3.scatter( | ||||
|         [params[x] for x in largest_indexes], | ||||
|         [test_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize) | ||||
|     ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize) | ||||
|     ax3.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     ax4.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue") | ||||
|     ax4.scatter( | ||||
|         [flops[x] for x in resnet_indexes], | ||||
|         [test_accs[x] for x in resnet_indexes], | ||||
|         marker="*", | ||||
|         s=xscale, | ||||
|         c="tab:orange", | ||||
|         label="ResNet", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax4.scatter( | ||||
|         [flops[x] for x in largest_indexes], | ||||
|         [test_accs[x] for x in largest_indexes], | ||||
|         marker="x", | ||||
|         s=xscale, | ||||
|         c="tab:green", | ||||
|         label="Largest Candidate", | ||||
|         alpha=xalpha, | ||||
|     ) | ||||
|     ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize) | ||||
|     # ax4.set_ylabel('test accuracy (%)', fontsize=LabelSize) | ||||
|     ax4.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     save_path = vis_save_dir / "tss-{:}.png".format(dataset.lower()) | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| def visualize_rank_info(api, vis_save_dir, indicator): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar10", indicator | ||||
|     ) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar100", indicator | ||||
|     ) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "ImageNet16-120", indicator | ||||
|     ) | ||||
|     cifar010_info = torch.load(cifar010_cache_path) | ||||
|     cifar100_info = torch.load(cifar100_cache_path) | ||||
|     imagenet_info = torch.load(imagenet_cache_path) | ||||
|     indexes = list(range(len(cifar010_info["params"]))) | ||||
|  | ||||
|     print("{:} start to visualize relative ranking".format(time_string())) | ||||
|  | ||||
|     dpi, width, height = 250, 3800, 1200 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 14, 14 | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||
|     ax1, ax2, ax3 = axs | ||||
|  | ||||
|     def get_labels(info): | ||||
|         ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i]) | ||||
|         ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i]) | ||||
|         labels = [] | ||||
|         for idx in ord_test_indexes: | ||||
|             labels.append(ord_valid_indexes.index(idx)) | ||||
|         return labels | ||||
|  | ||||
|     def plot_ax(labels, ax, name): | ||||
|         for tick in ax.xaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|         for tick in ax.yaxis.get_major_ticks(): | ||||
|             tick.label.set_fontsize(LabelSize) | ||||
|             tick.label.set_rotation(90) | ||||
|         ax.set_xlim(min(indexes), max(indexes)) | ||||
|         ax.set_ylim(min(indexes), max(indexes)) | ||||
|         ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 3)) | ||||
|         ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) | ||||
|         ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | ||||
|         ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | ||||
|         ax.scatter( | ||||
|             [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name) | ||||
|         ) | ||||
|         ax.scatter( | ||||
|             [-1], | ||||
|             [-1], | ||||
|             marker="o", | ||||
|             s=100, | ||||
|             c="tab:blue", | ||||
|             label="{:} validation".format(name), | ||||
|         ) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|         ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize) | ||||
|         ax.set_ylabel("architecture ranking", fontsize=LabelSize) | ||||
|  | ||||
|     labels = get_labels(cifar010_info) | ||||
|     plot_ax(labels, ax1, "CIFAR-10") | ||||
|     labels = get_labels(cifar100_info) | ||||
|     plot_ax(labels, ax2, "CIFAR-100") | ||||
|     labels = get_labels(imagenet_info) | ||||
|     plot_ax(labels, ax3, "ImageNet-16-120") | ||||
|  | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-same-relative-rank.png".format(indicator) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| def compute_kendalltau(vectori, vectorj): | ||||
|     # indexes = list(range(len(vectori))) | ||||
|     # rank_1 = sorted(indexes, key=lambda i: vectori[i]) | ||||
|     # rank_2 = sorted(indexes, key=lambda i: vectorj[i]) | ||||
|     return scipy.stats.kendalltau(vectori, vectorj).correlation | ||||
|  | ||||
|  | ||||
| def calculate_correlation(*vectors): | ||||
|     matrix = [] | ||||
|     for i, vectori in enumerate(vectors): | ||||
|         x = [] | ||||
|         for j, vectorj in enumerate(vectors): | ||||
|             # x.append(np.corrcoef(vectori, vectorj)[0,1]) | ||||
|             x.append(compute_kendalltau(vectori, vectorj)) | ||||
|         matrix.append(x) | ||||
|     return np.array(matrix) | ||||
|  | ||||
|  | ||||
| def visualize_all_rank_info(api, vis_save_dir, indicator): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     # print ('{:} start to visualize {:} information'.format(time_string(), api)) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar10", indicator | ||||
|     ) | ||||
|     cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "cifar100", indicator | ||||
|     ) | ||||
|     imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format( | ||||
|         "ImageNet16-120", indicator | ||||
|     ) | ||||
|     cifar010_info = torch.load(cifar010_cache_path) | ||||
|     cifar100_info = torch.load(cifar100_cache_path) | ||||
|     imagenet_info = torch.load(imagenet_cache_path) | ||||
|     indexes = list(range(len(cifar010_info["params"]))) | ||||
|  | ||||
|     print("{:} start to visualize relative ranking".format(time_string())) | ||||
|  | ||||
|     dpi, width, height = 250, 3200, 1400 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 14, 14 | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 2, figsize=figsize) | ||||
|     ax1, ax2 = axs | ||||
|  | ||||
|     sns_size, xformat = 15, ".2f" | ||||
|     CoRelMatrix = calculate_correlation( | ||||
|         cifar010_info["valid_accs"], | ||||
|         cifar010_info["test_accs"], | ||||
|         cifar100_info["valid_accs"], | ||||
|         cifar100_info["test_accs"], | ||||
|         imagenet_info["valid_accs"], | ||||
|         imagenet_info["test_accs"], | ||||
|     ) | ||||
|  | ||||
|     sns.heatmap( | ||||
|         CoRelMatrix, | ||||
|         annot=True, | ||||
|         annot_kws={"size": sns_size}, | ||||
|         fmt=xformat, | ||||
|         linewidths=0.5, | ||||
|         ax=ax1, | ||||
|         xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | ||||
|         yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | ||||
|     ) | ||||
|  | ||||
|     selected_indexes, acc_bar = [], 92 | ||||
|     for i, acc in enumerate(cifar010_info["test_accs"]): | ||||
|         if acc > acc_bar: | ||||
|             selected_indexes.append(i) | ||||
|     cifar010_valid_accs = np.array(cifar010_info["valid_accs"])[selected_indexes] | ||||
|     cifar010_test_accs = np.array(cifar010_info["test_accs"])[selected_indexes] | ||||
|     cifar100_valid_accs = np.array(cifar100_info["valid_accs"])[selected_indexes] | ||||
|     cifar100_test_accs = np.array(cifar100_info["test_accs"])[selected_indexes] | ||||
|     imagenet_valid_accs = np.array(imagenet_info["valid_accs"])[selected_indexes] | ||||
|     imagenet_test_accs = np.array(imagenet_info["test_accs"])[selected_indexes] | ||||
|     CoRelMatrix = calculate_correlation( | ||||
|         cifar010_valid_accs, | ||||
|         cifar010_test_accs, | ||||
|         cifar100_valid_accs, | ||||
|         cifar100_test_accs, | ||||
|         imagenet_valid_accs, | ||||
|         imagenet_test_accs, | ||||
|     ) | ||||
|  | ||||
|     sns.heatmap( | ||||
|         CoRelMatrix, | ||||
|         annot=True, | ||||
|         annot_kws={"size": sns_size}, | ||||
|         fmt=xformat, | ||||
|         linewidths=0.5, | ||||
|         ax=ax2, | ||||
|         xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | ||||
|         yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"], | ||||
|     ) | ||||
|     ax1.set_title("Correlation coefficient over ALL candidates") | ||||
|     ax2.set_title( | ||||
|         "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar) | ||||
|     ) | ||||
|     save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     # use for train the model | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     to_save_dir = Path(args.save_dir) | ||||
|  | ||||
|     datasets = ["cifar10", "cifar100", "ImageNet16-120"] | ||||
|     # Figure 3 (a-c) | ||||
|     api_tss = create(None, "tss", verbose=True) | ||||
|     for xdata in datasets: | ||||
|         visualize_tss_info(api_tss, xdata, to_save_dir) | ||||
|     # Figure 3 (d-f) | ||||
|     api_sss = create(None, "size", verbose=True) | ||||
|     for xdata in datasets: | ||||
|         visualize_sss_info(api_sss, xdata, to_save_dir) | ||||
|  | ||||
|     # Figure 2 | ||||
|     visualize_relative_info(None, to_save_dir, "tss") | ||||
|     visualize_relative_info(None, to_save_dir, "sss") | ||||
|  | ||||
|     # Figure 4 | ||||
|     visualize_rank_info(None, to_save_dir, "tss") | ||||
|     visualize_rank_info(None, to_save_dir, "sss") | ||||
|  | ||||
|     # Figure 5 | ||||
|     visualize_all_rank_info(None, to_save_dir, "tss") | ||||
|     visualize_all_rank_info(None, to_save_dir, "sss") | ||||
							
								
								
									
										225
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-fig6.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-fig6.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,225 @@ | ||||
| ############################################################### | ||||
| # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  # | ||||
| # The code to draw Figure 6 in our paper.                     # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/NATS-Bench/draw-fig6.py --search_space tss | ||||
| # Usage: python exps/NATS-Bench/draw-fig6.py --search_space sss | ||||
| ############################################################### | ||||
| import os, gc, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict, OrderedDict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.log_utils import time_string | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): | ||||
|     ss_dir = "{:}-{:}".format(root_dir, search_space) | ||||
|     alg2name, alg2path = OrderedDict(), OrderedDict() | ||||
|     alg2name["REA"] = "R-EA-SS3" | ||||
|     alg2name["REINFORCE"] = "REINFORCE-0.01" | ||||
|     alg2name["RANDOM"] = "RANDOM" | ||||
|     alg2name["BOHB"] = "BOHB" | ||||
|     for alg, name in alg2name.items(): | ||||
|         alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth") | ||||
|         assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg]) | ||||
|     alg2data = OrderedDict() | ||||
|     for alg, path in alg2path.items(): | ||||
|         data = torch.load(path) | ||||
|         for index, info in data.items(): | ||||
|             info["time_w_arch"] = [ | ||||
|                 (x, y) for x, y in zip(info["all_total_times"], info["all_archs"]) | ||||
|             ] | ||||
|             for j, arch in enumerate(info["all_archs"]): | ||||
|                 assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( | ||||
|                     alg, search_space, dataset, index, j | ||||
|                 ) | ||||
|         alg2data[alg] = data | ||||
|     return alg2data | ||||
|  | ||||
|  | ||||
| def query_performance(api, data, dataset, ticket): | ||||
|     results, is_size_space = [], api.search_space_name == "size" | ||||
|     for i, info in data.items(): | ||||
|         time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) | ||||
|         time_a, arch_a = time_w_arch[0] | ||||
|         time_b, arch_b = time_w_arch[1] | ||||
|         info_a = api.get_more_info( | ||||
|             arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         info_b = api.get_more_info( | ||||
|             arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"] | ||||
|         interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + ( | ||||
|             ticket - time_a | ||||
|         ) / (time_b - time_a) * accuracy_b | ||||
|         results.append(interplate) | ||||
|     # return sum(results) / len(results) | ||||
|     return np.mean(results), np.std(results) | ||||
|  | ||||
|  | ||||
| def show_valid_test(api, data, dataset): | ||||
|     valid_accs, test_accs, is_size_space = [], [], api.search_space_name == "size" | ||||
|     for i, info in data.items(): | ||||
|         time, arch = info["time_w_arch"][-1] | ||||
|         if dataset == "cifar10": | ||||
|             xinfo = api.get_more_info( | ||||
|                 arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|             ) | ||||
|             test_accs.append(xinfo["test-accuracy"]) | ||||
|             xinfo = api.get_more_info( | ||||
|                 arch, | ||||
|                 dataset="cifar10-valid", | ||||
|                 hp=90 if is_size_space else 200, | ||||
|                 is_random=False, | ||||
|             ) | ||||
|             valid_accs.append(xinfo["valid-accuracy"]) | ||||
|         else: | ||||
|             xinfo = api.get_more_info( | ||||
|                 arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|             ) | ||||
|             valid_accs.append(xinfo["valid-accuracy"]) | ||||
|             test_accs.append(xinfo["test-accuracy"]) | ||||
|     valid_str = "{:.2f}$\pm${:.2f}".format(np.mean(valid_accs), np.std(valid_accs)) | ||||
|     test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs)) | ||||
|     return valid_str, test_str | ||||
|  | ||||
|  | ||||
| y_min_s = { | ||||
|     ("cifar10", "tss"): 90, | ||||
|     ("cifar10", "sss"): 92, | ||||
|     ("cifar100", "tss"): 65, | ||||
|     ("cifar100", "sss"): 65, | ||||
|     ("ImageNet16-120", "tss"): 36, | ||||
|     ("ImageNet16-120", "sss"): 40, | ||||
| } | ||||
|  | ||||
| y_max_s = { | ||||
|     ("cifar10", "tss"): 94.3, | ||||
|     ("cifar10", "sss"): 93.3, | ||||
|     ("cifar100", "tss"): 72.5, | ||||
|     ("cifar100", "sss"): 70.5, | ||||
|     ("ImageNet16-120", "tss"): 46, | ||||
|     ("ImageNet16-120", "sss"): 46, | ||||
| } | ||||
|  | ||||
| x_axis_s = { | ||||
|     ("cifar10", "tss"): 200, | ||||
|     ("cifar10", "sss"): 200, | ||||
|     ("cifar100", "tss"): 400, | ||||
|     ("cifar100", "sss"): 400, | ||||
|     ("ImageNet16-120", "tss"): 1200, | ||||
|     ("ImageNet16-120", "sss"): 600, | ||||
| } | ||||
|  | ||||
| name2label = { | ||||
|     "cifar10": "CIFAR-10", | ||||
|     "cifar100": "CIFAR-100", | ||||
|     "ImageNet16-120": "ImageNet-16-120", | ||||
| } | ||||
|  | ||||
|  | ||||
| def visualize_curve(api, vis_save_dir, search_space): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dpi, width, height = 250, 5200, 1400 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 16, 16 | ||||
|  | ||||
|     def sub_plot_fn(ax, dataset): | ||||
|         xdataset, max_time = dataset.split("-T") | ||||
|         alg2data = fetch_data(search_space=search_space, dataset=dataset) | ||||
|         alg2accuracies = OrderedDict() | ||||
|         total_tickets = 150 | ||||
|         time_tickets = [ | ||||
|             float(i) / total_tickets * int(max_time) for i in range(total_tickets) | ||||
|         ] | ||||
|         colors = ["b", "g", "c", "m", "y"] | ||||
|         ax.set_xlim(0, x_axis_s[(xdataset, search_space)]) | ||||
|         ax.set_ylim( | ||||
|             y_min_s[(xdataset, search_space)], y_max_s[(xdataset, search_space)] | ||||
|         ) | ||||
|         for idx, (alg, data) in enumerate(alg2data.items()): | ||||
|             accuracies = [] | ||||
|             for ticket in time_tickets: | ||||
|                 accuracy, accuracy_std = query_performance(api, data, xdataset, ticket) | ||||
|                 accuracies.append(accuracy) | ||||
|             valid_str, test_str = show_valid_test(api, data, xdataset) | ||||
|             # print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std)) | ||||
|             print( | ||||
|                 "{:} plot alg : {:10s}  | validation = {:} | test = {:}".format( | ||||
|                     time_string(), alg, valid_str, test_str | ||||
|                 ) | ||||
|             ) | ||||
|             alg2accuracies[alg] = accuracies | ||||
|             ax.plot( | ||||
|                 [x / 100 for x in time_tickets], | ||||
|                 accuracies, | ||||
|                 c=colors[idx], | ||||
|                 label="{:}".format(alg), | ||||
|             ) | ||||
|             ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize) | ||||
|             ax.set_ylabel( | ||||
|                 "Test accuracy on {:}".format(name2label[xdataset]), fontsize=LabelSize | ||||
|             ) | ||||
|             ax.set_title( | ||||
|                 "Searching results on {:}".format(name2label[xdataset]), | ||||
|                 fontsize=LabelSize + 4, | ||||
|             ) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||
|     # datasets = ['cifar10', 'cifar100', 'ImageNet16-120'] | ||||
|     if search_space == "tss": | ||||
|         datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T120000"] | ||||
|     elif search_space == "sss": | ||||
|         datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T60000"] | ||||
|     else: | ||||
|         raise ValueError("Unknown search space: {:}".format(search_space)) | ||||
|     for dataset, ax in zip(datasets, axs): | ||||
|         sub_plot_fn(ax, dataset) | ||||
|         print("sub-plot {:} on {:} done.".format(dataset, search_space)) | ||||
|     save_path = (vis_save_dir / "{:}-curve.png".format(search_space)).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench/nas-algos", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--search_space", | ||||
|         type=str, | ||||
|         choices=["tss", "sss"], | ||||
|         help="Choose the search space.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     save_dir = Path(args.save_dir) | ||||
|  | ||||
|     api = create(None, args.search_space, fast_mode=True, verbose=False) | ||||
|     visualize_curve(api, save_dir, args.search_space) | ||||
							
								
								
									
										250
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-fig7.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										250
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-fig7.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,250 @@ | ||||
| ############################################################### | ||||
| # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  # | ||||
| # The code to draw Figure 7 in our paper.                     # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/NATS-Bench/draw-fig7.py                  # | ||||
| ############################################################### | ||||
| import os, gc, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict, OrderedDict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.log_utils import time_string | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def get_valid_test_acc(api, arch, dataset): | ||||
|     is_size_space = api.search_space_name == "size" | ||||
|     if dataset == "cifar10": | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, | ||||
|             dataset="cifar10-valid", | ||||
|             hp=90 if is_size_space else 200, | ||||
|             is_random=False, | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|     else: | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|     return ( | ||||
|         valid_acc, | ||||
|         test_acc, | ||||
|         "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc), | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def fetch_data( | ||||
|     root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3" | ||||
| ): | ||||
|     ss_dir = "{:}-{:}".format(root_dir, search_space) | ||||
|     alg2name, alg2path = OrderedDict(), OrderedDict() | ||||
|     seeds = [777, 888, 999] | ||||
|     print("\n[fetch data] from {:} on {:}".format(search_space, dataset)) | ||||
|     if search_space == "tss": | ||||
|         alg2name["GDAS"] = "gdas-affine0_BN0-None" | ||||
|         alg2name["RSPS"] = "random-affine0_BN0-None" | ||||
|         alg2name["DARTS (1st)"] = "darts-v1-affine0_BN0-None" | ||||
|         alg2name["DARTS (2nd)"] = "darts-v2-affine0_BN0-None" | ||||
|         alg2name["ENAS"] = "enas-affine0_BN0-None" | ||||
|         alg2name["SETN"] = "setn-affine0_BN0-None" | ||||
|     else: | ||||
|         alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format( | ||||
|             suffix | ||||
|         ) | ||||
|         alg2name[ | ||||
|             "masking + Gumbel-Softmax" | ||||
|         ] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix) | ||||
|         alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix) | ||||
|     for alg, name in alg2name.items(): | ||||
|         alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth") | ||||
|     alg2data = OrderedDict() | ||||
|     for alg, path in alg2path.items(): | ||||
|         alg2data[alg], ok_num = [], 0 | ||||
|         for seed in seeds: | ||||
|             xpath = path.format(seed) | ||||
|             if os.path.isfile(xpath): | ||||
|                 ok_num += 1 | ||||
|             else: | ||||
|                 print("This is an invalid path : {:}".format(xpath)) | ||||
|                 continue | ||||
|             data = torch.load(xpath, map_location=torch.device("cpu")) | ||||
|             try: | ||||
|                 data = torch.load( | ||||
|                     data["last_checkpoint"], map_location=torch.device("cpu") | ||||
|                 ) | ||||
|             except: | ||||
|                 xpath = str(data["last_checkpoint"]).split("E100-") | ||||
|                 if len(xpath) == 2 and os.path.isfile(xpath[0] + xpath[1]): | ||||
|                     xpath = xpath[0] + xpath[1] | ||||
|                 elif "fbv2" in str(data["last_checkpoint"]): | ||||
|                     xpath = str(data["last_checkpoint"]).replace("fbv2", "mask_gumbel") | ||||
|                 elif "tunas" in str(data["last_checkpoint"]): | ||||
|                     xpath = str(data["last_checkpoint"]).replace("tunas", "mask_rl") | ||||
|                 else: | ||||
|                     raise ValueError( | ||||
|                         "Invalid path: {:}".format(data["last_checkpoint"]) | ||||
|                     ) | ||||
|                 data = torch.load(xpath, map_location=torch.device("cpu")) | ||||
|             alg2data[alg].append(data["genotypes"]) | ||||
|         print("This algorithm : {:} has {:} valid ckps.".format(alg, ok_num)) | ||||
|         assert ok_num > 0, "Must have at least 1 valid ckps." | ||||
|     return alg2data | ||||
|  | ||||
|  | ||||
| y_min_s = { | ||||
|     ("cifar10", "tss"): 90, | ||||
|     ("cifar10", "sss"): 92, | ||||
|     ("cifar100", "tss"): 65, | ||||
|     ("cifar100", "sss"): 65, | ||||
|     ("ImageNet16-120", "tss"): 36, | ||||
|     ("ImageNet16-120", "sss"): 40, | ||||
| } | ||||
|  | ||||
| y_max_s = { | ||||
|     ("cifar10", "tss"): 94.5, | ||||
|     ("cifar10", "sss"): 93.3, | ||||
|     ("cifar100", "tss"): 72, | ||||
|     ("cifar100", "sss"): 70, | ||||
|     ("ImageNet16-120", "tss"): 44, | ||||
|     ("ImageNet16-120", "sss"): 46, | ||||
| } | ||||
|  | ||||
| name2label = { | ||||
|     "cifar10": "CIFAR-10", | ||||
|     "cifar100": "CIFAR-100", | ||||
|     "ImageNet16-120": "ImageNet-16-120", | ||||
| } | ||||
|  | ||||
| name2suffix = { | ||||
|     ("sss", "warm"): "-WARM0.3", | ||||
|     ("sss", "none"): "-WARMNone", | ||||
|     ("tss", "none"): None, | ||||
|     ("tss", None): None, | ||||
| } | ||||
|  | ||||
|  | ||||
| def visualize_curve(api, vis_save_dir, search_space, suffix): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dpi, width, height = 250, 5200, 1400 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 16, 16 | ||||
|  | ||||
|     def sub_plot_fn(ax, dataset): | ||||
|         print("{:} plot {:10s}".format(time_string(), dataset)) | ||||
|         alg2data = fetch_data( | ||||
|             search_space=search_space, | ||||
|             dataset=dataset, | ||||
|             suffix=name2suffix[(search_space, suffix)], | ||||
|         ) | ||||
|         alg2accuracies = OrderedDict() | ||||
|         epochs = 100 | ||||
|         colors = ["b", "g", "c", "m", "y", "r"] | ||||
|         ax.set_xlim(0, epochs) | ||||
|         # ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) | ||||
|         for idx, (alg, data) in enumerate(alg2data.items()): | ||||
|             xs, accuracies = [], [] | ||||
|             for iepoch in range(epochs + 1): | ||||
|                 try: | ||||
|                     structures, accs = [_[iepoch - 1] for _ in data], [] | ||||
|                 except: | ||||
|                     raise ValueError( | ||||
|                         "This alg {:} on {:} has invalid checkpoints.".format( | ||||
|                             alg, dataset | ||||
|                         ) | ||||
|                     ) | ||||
|                 for structure in structures: | ||||
|                     info = api.get_more_info( | ||||
|                         structure, | ||||
|                         dataset=dataset, | ||||
|                         hp=90 if api.search_space_name == "size" else 200, | ||||
|                         is_random=False, | ||||
|                     ) | ||||
|                     accs.append(info["test-accuracy"]) | ||||
|                 accuracies.append(sum(accs) / len(accs)) | ||||
|                 xs.append(iepoch) | ||||
|             alg2accuracies[alg] = accuracies | ||||
|             ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg)) | ||||
|             ax.set_xlabel("The searching epoch", fontsize=LabelSize) | ||||
|             ax.set_ylabel( | ||||
|                 "Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize | ||||
|             ) | ||||
|             ax.set_title( | ||||
|                 "Searching results on {:}".format(name2label[dataset]), | ||||
|                 fontsize=LabelSize + 4, | ||||
|             ) | ||||
|             structures, valid_accs, test_accs = [_[epochs - 1] for _ in data], [], [] | ||||
|             print( | ||||
|                 "{:} plot alg : {:} -- final {:} architectures.".format( | ||||
|                     time_string(), alg, len(structures) | ||||
|                 ) | ||||
|             ) | ||||
|             for arch in structures: | ||||
|                 valid_acc, test_acc, _ = get_valid_test_acc(api, arch, dataset) | ||||
|                 test_accs.append(test_acc) | ||||
|                 valid_accs.append(valid_acc) | ||||
|             print( | ||||
|                 "{:} plot alg : {:} -- validation: {:.2f}$\pm${:.2f} -- test: {:.2f}$\pm${:.2f}".format( | ||||
|                     time_string(), | ||||
|                     alg, | ||||
|                     np.mean(valid_accs), | ||||
|                     np.std(valid_accs), | ||||
|                     np.mean(test_accs), | ||||
|                     np.std(test_accs), | ||||
|                 ) | ||||
|             ) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||
|     datasets = ["cifar10", "cifar100", "ImageNet16-120"] | ||||
|     for dataset, ax in zip(datasets, axs): | ||||
|         sub_plot_fn(ax, dataset) | ||||
|         print("sub-plot {:} on {:} done.".format(dataset, search_space)) | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-ws-{:}-curve.png".format(search_space, suffix) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench/nas-algos", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     save_dir = Path(args.save_dir) | ||||
|  | ||||
|     api_tss = create(None, "tss", fast_mode=True, verbose=False) | ||||
|     visualize_curve(api_tss, save_dir, "tss", None) | ||||
|  | ||||
|     api_sss = create(None, "sss", fast_mode=True, verbose=False) | ||||
|     visualize_curve(api_sss, save_dir, "sss", "warm") | ||||
|     visualize_curve(api_sss, save_dir, "sss", "none") | ||||
							
								
								
									
										232
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-fig8.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										232
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-fig8.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,232 @@ | ||||
| ############################################################### | ||||
| # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  # | ||||
| # The code to draw Figure 6 in our paper.                     # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/NATS-Bench/draw-fig8.py                  # | ||||
| ############################################################### | ||||
| import os, gc, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict, OrderedDict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.log_utils import time_string | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| plt.rcParams.update( | ||||
|     {"text.usetex": True, "font.family": "sans-serif", "font.sans-serif": ["Helvetica"]} | ||||
| ) | ||||
| ## for Palatino and other serif fonts use: | ||||
| plt.rcParams.update( | ||||
|     { | ||||
|         "text.usetex": True, | ||||
|         "font.family": "serif", | ||||
|         "font.serif": ["Palatino"], | ||||
|     } | ||||
| ) | ||||
|  | ||||
|  | ||||
| def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): | ||||
|     ss_dir = "{:}-{:}".format(root_dir, search_space) | ||||
|     alg2all = OrderedDict() | ||||
|     # alg2name['REINFORCE'] = 'REINFORCE-0.01' | ||||
|     # alg2name['RANDOM'] = 'RANDOM' | ||||
|     # alg2name['BOHB'] = 'BOHB' | ||||
|     if search_space == "tss": | ||||
|         hp = "$\mathcal{H}^{1}$" | ||||
|         if dataset == "cifar10": | ||||
|             suffixes = ["-T1200000", "-T1200000-FULL"] | ||||
|     elif search_space == "sss": | ||||
|         hp = "$\mathcal{H}^{2}$" | ||||
|         if dataset == "cifar10": | ||||
|             suffixes = ["-T200000", "-T200000-FULL"] | ||||
|     else: | ||||
|         raise ValueError("Unkonwn search space: {:}".format(search_space)) | ||||
|  | ||||
|     alg2all[r"REA ($\mathcal{H}^{0}$)"] = dict( | ||||
|         path=os.path.join(ss_dir, dataset + suffixes[0], "R-EA-SS3", "results.pth"), | ||||
|         color="b", | ||||
|         linestyle="-", | ||||
|     ) | ||||
|     alg2all[r"REA ({:})".format(hp)] = dict( | ||||
|         path=os.path.join(ss_dir, dataset + suffixes[1], "R-EA-SS3", "results.pth"), | ||||
|         color="b", | ||||
|         linestyle="--", | ||||
|     ) | ||||
|  | ||||
|     for alg, xdata in alg2all.items(): | ||||
|         data = torch.load(xdata["path"]) | ||||
|         for index, info in data.items(): | ||||
|             info["time_w_arch"] = [ | ||||
|                 (x, y) for x, y in zip(info["all_total_times"], info["all_archs"]) | ||||
|             ] | ||||
|             for j, arch in enumerate(info["all_archs"]): | ||||
|                 assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( | ||||
|                     alg, search_space, dataset, index, j | ||||
|                 ) | ||||
|         xdata["data"] = data | ||||
|     return alg2all | ||||
|  | ||||
|  | ||||
| def query_performance(api, data, dataset, ticket): | ||||
|     results, is_size_space = [], api.search_space_name == "size" | ||||
|     for i, info in data.items(): | ||||
|         time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) | ||||
|         time_a, arch_a = time_w_arch[0] | ||||
|         time_b, arch_b = time_w_arch[1] | ||||
|         info_a = api.get_more_info( | ||||
|             arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         info_b = api.get_more_info( | ||||
|             arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"] | ||||
|         interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + ( | ||||
|             ticket - time_a | ||||
|         ) / (time_b - time_a) * accuracy_b | ||||
|         results.append(interplate) | ||||
|     # return sum(results) / len(results) | ||||
|     return np.mean(results), np.std(results) | ||||
|  | ||||
|  | ||||
| y_min_s = { | ||||
|     ("cifar10", "tss"): 91, | ||||
|     ("cifar10", "sss"): 91, | ||||
|     ("cifar100", "tss"): 65, | ||||
|     ("cifar100", "sss"): 65, | ||||
|     ("ImageNet16-120", "tss"): 36, | ||||
|     ("ImageNet16-120", "sss"): 40, | ||||
| } | ||||
|  | ||||
| y_max_s = { | ||||
|     ("cifar10", "tss"): 94.5, | ||||
|     ("cifar10", "sss"): 93.5, | ||||
|     ("cifar100", "tss"): 72.5, | ||||
|     ("cifar100", "sss"): 70.5, | ||||
|     ("ImageNet16-120", "tss"): 46, | ||||
|     ("ImageNet16-120", "sss"): 46, | ||||
| } | ||||
|  | ||||
| x_axis_s = { | ||||
|     ("cifar10", "tss"): 1200000, | ||||
|     ("cifar10", "sss"): 200000, | ||||
|     ("cifar100", "tss"): 400, | ||||
|     ("cifar100", "sss"): 400, | ||||
|     ("ImageNet16-120", "tss"): 1200, | ||||
|     ("ImageNet16-120", "sss"): 600, | ||||
| } | ||||
|  | ||||
| name2label = { | ||||
|     "cifar10": "CIFAR-10", | ||||
|     "cifar100": "CIFAR-100", | ||||
|     "ImageNet16-120": "ImageNet-16-120", | ||||
| } | ||||
|  | ||||
| spaces2latex = { | ||||
|     "tss": r"$\mathcal{S}_{t}$", | ||||
|     "sss": r"$\mathcal{S}_{s}$", | ||||
| } | ||||
|  | ||||
|  | ||||
| # FuncFormatter can be used as a decorator | ||||
| @ticker.FuncFormatter | ||||
| def major_formatter(x, pos): | ||||
|     if x == 0: | ||||
|         return "0" | ||||
|     else: | ||||
|         return "{:.2f}e5".format(x / 1e5) | ||||
|  | ||||
|  | ||||
| def visualize_curve(api_dict, vis_save_dir): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|     dpi, width, height = 250, 5000, 2000 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 28, 28 | ||||
|  | ||||
|     def sub_plot_fn(ax, search_space, dataset): | ||||
|         max_time = x_axis_s[(dataset, search_space)] | ||||
|         alg2data = fetch_data(search_space=search_space, dataset=dataset) | ||||
|         alg2accuracies = OrderedDict() | ||||
|         total_tickets = 200 | ||||
|         time_tickets = [ | ||||
|             float(i) / total_tickets * int(max_time) for i in range(total_tickets) | ||||
|         ] | ||||
|         ax.set_xlim(0, x_axis_s[(dataset, search_space)]) | ||||
|         ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)]) | ||||
|         for tick in ax.get_xticklabels(): | ||||
|             tick.set_rotation(25) | ||||
|             tick.set_fontsize(LabelSize - 6) | ||||
|         for tick in ax.get_yticklabels(): | ||||
|             tick.set_fontsize(LabelSize - 6) | ||||
|         ax.xaxis.set_major_formatter(major_formatter) | ||||
|         for idx, (alg, xdata) in enumerate(alg2data.items()): | ||||
|             accuracies = [] | ||||
|             for ticket in time_tickets: | ||||
|                 # import pdb; pdb.set_trace() | ||||
|                 accuracy, accuracy_std = query_performance( | ||||
|                     api_dict[search_space], xdata["data"], dataset, ticket | ||||
|                 ) | ||||
|                 accuracies.append(accuracy) | ||||
|             # print('{:} plot alg : {:10s}, final accuracy = {:.2f}$\pm${:.2f}'.format(time_string(), alg, accuracy, accuracy_std)) | ||||
|             print( | ||||
|                 "{:} plot alg : {:10s} on {:}".format(time_string(), alg, search_space) | ||||
|             ) | ||||
|             alg2accuracies[alg] = accuracies | ||||
|             ax.plot( | ||||
|                 time_tickets, | ||||
|                 accuracies, | ||||
|                 c=xdata["color"], | ||||
|                 linestyle=xdata["linestyle"], | ||||
|                 label="{:}".format(alg), | ||||
|             ) | ||||
|             ax.set_xlabel("Estimated wall-clock time", fontsize=LabelSize) | ||||
|             ax.set_ylabel("Test accuracy", fontsize=LabelSize) | ||||
|             ax.set_title( | ||||
|                 r"Results on {:} over {:}".format( | ||||
|                     name2label[dataset], spaces2latex[search_space] | ||||
|                 ), | ||||
|                 fontsize=LabelSize, | ||||
|             ) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 2, figsize=figsize) | ||||
|     sub_plot_fn(axs[0], "tss", "cifar10") | ||||
|     sub_plot_fn(axs[1], "sss", "cifar10") | ||||
|     save_path = (vis_save_dir / "full-curve.png").resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("{:} save into {:}".format(time_string(), save_path)) | ||||
|     plt.close("all") | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench/nas-algos-vs-h", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     save_dir = Path(args.save_dir) | ||||
|  | ||||
|     api_tss = create(None, "tss", fast_mode=True, verbose=False) | ||||
|     api_sss = create(None, "sss", fast_mode=True, verbose=False) | ||||
|     visualize_curve(dict(tss=api_tss, sss=api_sss), save_dir) | ||||
							
								
								
									
										185
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-ranks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										185
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-ranks.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,185 @@ | ||||
| ############################################################### | ||||
| # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  # | ||||
| # The code to draw Figure 2 / 3 / 4 / 5 in our paper.         # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/NATS-Bench/draw-ranks.py                 # | ||||
| ############################################################### | ||||
| import os, sys, time, torch, argparse | ||||
| import scipy | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict, OrderedDict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.models import get_cell_based_tiny_net | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| name2label = { | ||||
|     "cifar10": "CIFAR-10", | ||||
|     "cifar100": "CIFAR-100", | ||||
|     "ImageNet16-120": "ImageNet-16-120", | ||||
| } | ||||
|  | ||||
|  | ||||
| def visualize_relative_info(vis_save_dir, search_space, indicator, topk): | ||||
|     vis_save_dir = vis_save_dir.resolve() | ||||
|     print( | ||||
|         "{:} start to visualize {:} with top-{:} information".format( | ||||
|             time_string(), search_space, topk | ||||
|         ) | ||||
|     ) | ||||
|     vis_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     cache_file_path = vis_save_dir / "cache-{:}-info.pth".format(search_space) | ||||
|     datasets = ["cifar10", "cifar100", "ImageNet16-120"] | ||||
|     if not cache_file_path.exists(): | ||||
|         api = create(None, search_space, fast_mode=False, verbose=False) | ||||
|         all_infos = OrderedDict() | ||||
|         for index in range(len(api)): | ||||
|             all_info = OrderedDict() | ||||
|             for dataset in datasets: | ||||
|                 info_less = api.get_more_info(index, dataset, hp="12", is_random=False) | ||||
|                 info_more = api.get_more_info( | ||||
|                     index, dataset, hp=api.full_train_epochs, is_random=False | ||||
|                 ) | ||||
|                 all_info[dataset] = dict( | ||||
|                     less=info_less["test-accuracy"], more=info_more["test-accuracy"] | ||||
|                 ) | ||||
|             all_infos[index] = all_info | ||||
|         torch.save(all_infos, cache_file_path) | ||||
|         print("{:} save all cache data into {:}".format(time_string(), cache_file_path)) | ||||
|     else: | ||||
|         api = create(None, search_space, fast_mode=True, verbose=False) | ||||
|         all_infos = torch.load(cache_file_path) | ||||
|  | ||||
|     dpi, width, height = 250, 5000, 1300 | ||||
|     figsize = width / float(dpi), height / float(dpi) | ||||
|     LabelSize, LegendFontsize = 16, 16 | ||||
|  | ||||
|     fig, axs = plt.subplots(1, 3, figsize=figsize) | ||||
|     datasets = ["cifar10", "cifar100", "ImageNet16-120"] | ||||
|  | ||||
|     def sub_plot_fn(ax, dataset, indicator): | ||||
|         performances = [] | ||||
|         # pickup top 10% architectures | ||||
|         for _index in range(len(api)): | ||||
|             performances.append((all_infos[_index][dataset][indicator], _index)) | ||||
|         performances = sorted(performances, reverse=True) | ||||
|         performances = performances[: int(len(api) * topk * 0.01)] | ||||
|         selected_indexes = [x[1] for x in performances] | ||||
|         print( | ||||
|             "{:} plot {:10s} with {:}, {:} architectures".format( | ||||
|                 time_string(), dataset, indicator, len(selected_indexes) | ||||
|             ) | ||||
|         ) | ||||
|         standard_scores = [] | ||||
|         random_scores = [] | ||||
|         for idx in selected_indexes: | ||||
|             standard_scores.append( | ||||
|                 api.get_more_info( | ||||
|                     idx, | ||||
|                     dataset, | ||||
|                     hp=api.full_train_epochs if indicator == "more" else "12", | ||||
|                     is_random=False, | ||||
|                 )["test-accuracy"] | ||||
|             ) | ||||
|             random_scores.append( | ||||
|                 api.get_more_info( | ||||
|                     idx, | ||||
|                     dataset, | ||||
|                     hp=api.full_train_epochs if indicator == "more" else "12", | ||||
|                     is_random=True, | ||||
|                 )["test-accuracy"] | ||||
|             ) | ||||
|         indexes = list(range(len(selected_indexes))) | ||||
|         standard_indexes = sorted(indexes, key=lambda i: standard_scores[i]) | ||||
|         random_indexes = sorted(indexes, key=lambda i: random_scores[i]) | ||||
|         random_labels = [] | ||||
|         for idx in standard_indexes: | ||||
|             random_labels.append(random_indexes.index(idx)) | ||||
|         for tick in ax.get_xticklabels(): | ||||
|             tick.set_fontsize(LabelSize - 3) | ||||
|         for tick in ax.get_yticklabels(): | ||||
|             tick.set_rotation(25) | ||||
|             tick.set_fontsize(LabelSize - 3) | ||||
|         ax.set_xlim(0, len(indexes)) | ||||
|         ax.set_ylim(0, len(indexes)) | ||||
|         ax.set_yticks(np.arange(min(indexes), max(indexes), max(indexes) // 3)) | ||||
|         ax.set_xticks(np.arange(min(indexes), max(indexes), max(indexes) // 5)) | ||||
|         ax.scatter(indexes, random_labels, marker="^", s=0.5, c="tab:green", alpha=0.8) | ||||
|         ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8) | ||||
|         ax.scatter( | ||||
|             [-1], | ||||
|             [-1], | ||||
|             marker="o", | ||||
|             s=100, | ||||
|             c="tab:blue", | ||||
|             label="Average Over Multi-Trials", | ||||
|         ) | ||||
|         ax.scatter( | ||||
|             [-1], | ||||
|             [-1], | ||||
|             marker="^", | ||||
|             s=100, | ||||
|             c="tab:green", | ||||
|             label="Randomly Selected Trial", | ||||
|         ) | ||||
|  | ||||
|         coef, p = scipy.stats.kendalltau(standard_scores, random_scores) | ||||
|         ax.set_xlabel( | ||||
|             "architecture ranking in {:}".format(name2label[dataset]), | ||||
|             fontsize=LabelSize, | ||||
|         ) | ||||
|         if dataset == "cifar10": | ||||
|             ax.set_ylabel("architecture ranking", fontsize=LabelSize) | ||||
|         ax.legend(loc=4, fontsize=LegendFontsize) | ||||
|         return coef | ||||
|  | ||||
|     for dataset, ax in zip(datasets, axs): | ||||
|         rank_coef = sub_plot_fn(ax, dataset, indicator) | ||||
|         print( | ||||
|             "sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format( | ||||
|                 dataset, search_space, rank_coef | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-rank-{:}-top{:}.pdf".format(search_space, indicator, topk) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf") | ||||
|     save_path = ( | ||||
|         vis_save_dir / "{:}-rank-{:}-top{:}.png".format(search_space, indicator, topk) | ||||
|     ).resolve() | ||||
|     fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png") | ||||
|     print("Save into {:}".format(save_path)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/vis-nas-bench/rank-stability", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|     to_save_dir = Path(args.save_dir) | ||||
|  | ||||
|     for topk in [1, 5, 10, 20]: | ||||
|         visualize_relative_info(to_save_dir, "tss", "more", topk) | ||||
|         visualize_relative_info(to_save_dir, "sss", "less", topk) | ||||
|     print("{:} : complete running this file : {:}".format(time_string(), __file__)) | ||||
							
								
								
									
										191
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-table.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								AutoDL-Projects/exps/NATS-Bench/draw-table.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,191 @@ | ||||
| ############################################################### | ||||
| # NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  # | ||||
| # The code to draw some results in Table 4 in our paper.      # | ||||
| ############################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           # | ||||
| ############################################################### | ||||
| # Usage: python exps/NATS-Bench/draw-table.py                 # | ||||
| ############################################################### | ||||
| import os, gc, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict, OrderedDict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.log_utils import time_string | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def fetch_data(root_dir="./output/search", search_space="tss", dataset=None): | ||||
|     ss_dir = "{:}-{:}".format(root_dir, search_space) | ||||
|     alg2name, alg2path = OrderedDict(), OrderedDict() | ||||
|     alg2name["REA"] = "R-EA-SS3" | ||||
|     alg2name["REINFORCE"] = "REINFORCE-0.01" | ||||
|     alg2name["RANDOM"] = "RANDOM" | ||||
|     alg2name["BOHB"] = "BOHB" | ||||
|     for alg, name in alg2name.items(): | ||||
|         alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth") | ||||
|         assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg]) | ||||
|     alg2data = OrderedDict() | ||||
|     for alg, path in alg2path.items(): | ||||
|         data = torch.load(path) | ||||
|         for index, info in data.items(): | ||||
|             info["time_w_arch"] = [ | ||||
|                 (x, y) for x, y in zip(info["all_total_times"], info["all_archs"]) | ||||
|             ] | ||||
|             for j, arch in enumerate(info["all_archs"]): | ||||
|                 assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format( | ||||
|                     alg, search_space, dataset, index, j | ||||
|                 ) | ||||
|         alg2data[alg] = data | ||||
|     return alg2data | ||||
|  | ||||
|  | ||||
| def get_valid_test_acc(api, arch, dataset): | ||||
|     is_size_space = api.search_space_name == "size" | ||||
|     if dataset == "cifar10": | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, | ||||
|             dataset="cifar10-valid", | ||||
|             hp=90 if is_size_space else 200, | ||||
|             is_random=False, | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|     else: | ||||
|         xinfo = api.get_more_info( | ||||
|             arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False | ||||
|         ) | ||||
|         valid_acc = xinfo["valid-accuracy"] | ||||
|         test_acc = xinfo["test-accuracy"] | ||||
|     return ( | ||||
|         valid_acc, | ||||
|         test_acc, | ||||
|         "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc), | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def show_valid_test(api, arch): | ||||
|     is_size_space = api.search_space_name == "size" | ||||
|     final_str = "" | ||||
|     for dataset in ["cifar10", "cifar100", "ImageNet16-120"]: | ||||
|         valid_acc, test_acc, perf_str = get_valid_test_acc(api, arch, dataset) | ||||
|         final_str += "{:} : {:}\n".format(dataset, perf_str) | ||||
|     return final_str | ||||
|  | ||||
|  | ||||
| def find_best_valid(api, dataset): | ||||
|     all_valid_accs, all_test_accs = [], [] | ||||
|     for index, arch in enumerate(api): | ||||
|         valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset) | ||||
|         all_valid_accs.append((index, valid_acc)) | ||||
|         all_test_accs.append((index, test_acc)) | ||||
|     best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0] | ||||
|     best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0] | ||||
|  | ||||
|     print("-" * 50 + "{:10s}".format(dataset) + "-" * 50) | ||||
|     print( | ||||
|         "Best ({:}) architecture on validation: {:}".format( | ||||
|             best_valid_index, api[best_valid_index] | ||||
|         ) | ||||
|     ) | ||||
|     print( | ||||
|         "Best ({:}) architecture on       test: {:}".format( | ||||
|             best_test_index, api[best_test_index] | ||||
|         ) | ||||
|     ) | ||||
|     _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset) | ||||
|     print("using validation ::: {:}".format(perf_str)) | ||||
|     _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset) | ||||
|     print("using test       ::: {:}".format(perf_str)) | ||||
|  | ||||
|  | ||||
| def interplate_fn(xpair1, xpair2, x): | ||||
|     (x1, y1) = xpair1 | ||||
|     (x2, y2) = xpair2 | ||||
|     return (x2 - x) / (x2 - x1) * y1 + (x - x1) / (x2 - x1) * y2 | ||||
|  | ||||
|  | ||||
| def query_performance(api, info, dataset, ticket): | ||||
|     info = deepcopy(info) | ||||
|     results, is_size_space = [], api.search_space_name == "size" | ||||
|     time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket)) | ||||
|     time_a, arch_a = time_w_arch[0] | ||||
|     time_b, arch_b = time_w_arch[1] | ||||
|  | ||||
|     v_acc_a, t_acc_a, _ = get_valid_test_acc(api, arch_a, dataset) | ||||
|     v_acc_b, t_acc_b, _ = get_valid_test_acc(api, arch_b, dataset) | ||||
|     v_acc = interplate_fn((time_a, v_acc_a), (time_b, v_acc_b), ticket) | ||||
|     t_acc = interplate_fn((time_a, t_acc_a), (time_b, t_acc_b), ticket) | ||||
|     # if True: | ||||
|     #   interplate = (time_b-ticket) / (time_b-time_a) * accuracy_a + (ticket-time_a) / (time_b-time_a) * accuracy_b | ||||
|     #   results.append(interplate) | ||||
|     # return sum(results) / len(results) | ||||
|     return v_acc, t_acc | ||||
|  | ||||
|  | ||||
| def show_multi_trial(search_space): | ||||
|     api = create(None, search_space, fast_mode=True, verbose=False) | ||||
|  | ||||
|     def show(dataset): | ||||
|         print("show {:} on {:} done.".format(dataset, search_space)) | ||||
|         xdataset, max_time = dataset.split("-T") | ||||
|         alg2data = fetch_data(search_space=search_space, dataset=dataset) | ||||
|         for idx, (alg, data) in enumerate(alg2data.items()): | ||||
|  | ||||
|             valid_accs, test_accs = [], [] | ||||
|             for _, x in data.items(): | ||||
|                 v_acc, t_acc = query_performance(api, x, xdataset, float(max_time)) | ||||
|                 valid_accs.append(v_acc) | ||||
|                 test_accs.append(t_acc) | ||||
|             valid_str = "{:.2f}$\pm${:.2f}".format( | ||||
|                 np.mean(valid_accs), np.std(valid_accs) | ||||
|             ) | ||||
|             test_str = "{:.2f}$\pm${:.2f}".format(np.mean(test_accs), np.std(test_accs)) | ||||
|             print( | ||||
|                 "{:} plot alg : {:10s}  | validation = {:} | test = {:}".format( | ||||
|                     time_string(), alg, valid_str, test_str | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|     if search_space == "tss": | ||||
|         datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T120000"] | ||||
|     elif search_space == "sss": | ||||
|         datasets = ["cifar10-T20000", "cifar100-T40000", "ImageNet16-120-T60000"] | ||||
|     else: | ||||
|         raise ValueError("Unknown search space: {:}".format(search_space)) | ||||
|     for dataset in datasets: | ||||
|         show(dataset) | ||||
|     print("{:} complete show multi-trial results.\n".format(time_string())) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     show_multi_trial("tss") | ||||
|     show_multi_trial("sss") | ||||
|  | ||||
|     api_tss = create(None, "tss", fast_mode=False, verbose=False) | ||||
|     resnet = "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|" | ||||
|     resnet_index = api_tss.query_index_by_arch(resnet) | ||||
|     print(show_valid_test(api_tss, resnet_index)) | ||||
|  | ||||
|     for dataset in ["cifar10", "cifar100", "ImageNet16-120"]: | ||||
|         find_best_valid(api_tss, dataset) | ||||
|  | ||||
|     largest = "64:64:64:64:64" | ||||
|     largest_index = api_sss.query_index_by_arch(largest) | ||||
|     print(show_valid_test(api_sss, largest_index)) | ||||
|     for dataset in ["cifar10", "cifar100", "ImageNet16-120"]: | ||||
|         find_best_valid(api_sss, dataset) | ||||
							
								
								
									
										486
									
								
								AutoDL-Projects/exps/NATS-Bench/main-sss.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										486
									
								
								AutoDL-Projects/exps/NATS-Bench/main-sss.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,486 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          # | ||||
| ############################################################################## | ||||
| # This file is used to train (all) architecture candidate in the size search # | ||||
| # space in NATS-Bench (sss) with different hyper-parameters.                 # | ||||
| # When use mode=new, it will automatically detect whether the checkpoint of  # | ||||
| # a trial exists, if so, it will skip this trial. When use mode=cover, it    # | ||||
| # will ignore the (possible) existing checkpoint, run each trial, and save.  # | ||||
| # (NOTE): the topology for all candidates in sss is fixed as:                ###################### | ||||
| # |nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2| # | ||||
| ################################################################################################### | ||||
| # Please use the script of scripts/NATS-Bench/train-shapes.sh to run.        # | ||||
| ############################################################################## | ||||
| import os, sys, time, torch, argparse | ||||
| from typing import List, Text, Dict, Any | ||||
| from PIL import ImageFile | ||||
|  | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.procedures import bench_evaluate_for_seed | ||||
| from xautodl.procedures import get_machine_info | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
|  | ||||
| def evaluate_all_datasets( | ||||
|     channels: Text, | ||||
|     datasets: List[Text], | ||||
|     xpaths: List[Text], | ||||
|     splits: List[Text], | ||||
|     config_path: Text, | ||||
|     seed: int, | ||||
|     workers: int, | ||||
|     logger, | ||||
| ): | ||||
|     machine_info = get_machine_info() | ||||
|     all_infos = {"info": machine_info} | ||||
|     all_dataset_keys = [] | ||||
|     # look all the dataset | ||||
|     for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||
|         # the train and valid data | ||||
|         train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||
|         # load the configuration | ||||
|         if dataset == "cifar10" or dataset == "cifar100": | ||||
|             split_info = load_config( | ||||
|                 "configs/nas-benchmark/cifar-split.txt", None, None | ||||
|             ) | ||||
|         elif dataset.startswith("ImageNet16"): | ||||
|             split_info = load_config( | ||||
|                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|         config = load_config( | ||||
|             config_path, dict(class_num=class_num, xshape=xshape), logger | ||||
|         ) | ||||
|         # check whether use the splitted validation set | ||||
|         if bool(split): | ||||
|             assert dataset == "cifar10" | ||||
|             ValLoaders = { | ||||
|                 "ori-test": torch.utils.data.DataLoader( | ||||
|                     valid_data, | ||||
|                     batch_size=config.batch_size, | ||||
|                     shuffle=False, | ||||
|                     num_workers=workers, | ||||
|                     pin_memory=True, | ||||
|                 ) | ||||
|             } | ||||
|             assert len(train_data) == len(split_info.train) + len( | ||||
|                 split_info.valid | ||||
|             ), "invalid length : {:} vs {:} + {:}".format( | ||||
|                 len(train_data), len(split_info.train), len(split_info.valid) | ||||
|             ) | ||||
|             train_data_v2 = deepcopy(train_data) | ||||
|             train_data_v2.transform = valid_data.transform | ||||
|             valid_data = train_data_v2 | ||||
|             # data loader | ||||
|             train_loader = torch.utils.data.DataLoader( | ||||
|                 train_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             valid_loader = torch.utils.data.DataLoader( | ||||
|                 valid_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             ValLoaders["x-valid"] = valid_loader | ||||
|         else: | ||||
|             # data loader | ||||
|             train_loader = torch.utils.data.DataLoader( | ||||
|                 train_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 shuffle=True, | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             valid_loader = torch.utils.data.DataLoader( | ||||
|                 valid_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 shuffle=False, | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             if dataset == "cifar10": | ||||
|                 ValLoaders = {"ori-test": valid_loader} | ||||
|             elif dataset == "cifar100": | ||||
|                 cifar100_splits = load_config( | ||||
|                     "configs/nas-benchmark/cifar100-test-split.txt", None, None | ||||
|                 ) | ||||
|                 ValLoaders = { | ||||
|                     "ori-test": valid_loader, | ||||
|                     "x-valid": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             cifar100_splits.xvalid | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                     "x-test": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             cifar100_splits.xtest | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                 } | ||||
|             elif dataset == "ImageNet16-120": | ||||
|                 imagenet16_splits = load_config( | ||||
|                     "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None | ||||
|                 ) | ||||
|                 ValLoaders = { | ||||
|                     "ori-test": valid_loader, | ||||
|                     "x-valid": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             imagenet16_splits.xvalid | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                     "x-test": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             imagenet16_splits.xtest | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                 } | ||||
|             else: | ||||
|                 raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|  | ||||
|         dataset_key = "{:}".format(dataset) | ||||
|         if bool(split): | ||||
|             dataset_key = dataset_key + "-valid" | ||||
|         logger.log( | ||||
|             "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|                 dataset_key, | ||||
|                 len(train_data), | ||||
|                 len(valid_data), | ||||
|                 len(train_loader), | ||||
|                 len(valid_loader), | ||||
|                 config.batch_size, | ||||
|             ) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config) | ||||
|         ) | ||||
|         for key, value in ValLoaders.items(): | ||||
|             logger.log( | ||||
|                 "Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)) | ||||
|             ) | ||||
|         # arch-index= 9930, arch=|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2| | ||||
|         # this genotype is the architecture with the highest accuracy on CIFAR-100 validation set | ||||
|         genotype = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|" | ||||
|         arch_config = dict2config( | ||||
|             dict( | ||||
|                 name="infer.shape.tiny", | ||||
|                 channels=channels, | ||||
|                 genotype=genotype, | ||||
|                 num_classes=class_num, | ||||
|             ), | ||||
|             None, | ||||
|         ) | ||||
|         results = bench_evaluate_for_seed( | ||||
|             arch_config, config, train_loader, ValLoaders, seed, logger | ||||
|         ) | ||||
|         all_infos[dataset_key] = results | ||||
|         all_dataset_keys.append(dataset_key) | ||||
|     all_infos["all_dataset_keys"] = all_dataset_keys | ||||
|     return all_infos | ||||
|  | ||||
|  | ||||
| def main( | ||||
|     save_dir: Path, | ||||
|     workers: int, | ||||
|     datasets: List[Text], | ||||
|     xpaths: List[Text], | ||||
|     splits: List[int], | ||||
|     seeds: List[int], | ||||
|     nets: List[str], | ||||
|     opt_config: Dict[Text, Any], | ||||
|     to_evaluate_indexes: tuple, | ||||
|     cover_mode: bool, | ||||
| ): | ||||
|  | ||||
|     log_dir = save_dir / "logs" | ||||
|     log_dir.mkdir(parents=True, exist_ok=True) | ||||
|     logger = Logger(str(log_dir), os.getpid(), False) | ||||
|  | ||||
|     logger.log("xargs : seeds      = {:}".format(seeds)) | ||||
|     logger.log("xargs : cover_mode = {:}".format(cover_mode)) | ||||
|     logger.log("-" * 100) | ||||
|     logger.log( | ||||
|         "Start evaluating range =: {:06d} - {:06d}".format( | ||||
|             min(to_evaluate_indexes), max(to_evaluate_indexes) | ||||
|         ) | ||||
|         + "({:} in total) / {:06d} with cover-mode={:}".format( | ||||
|             len(to_evaluate_indexes), len(nets), cover_mode | ||||
|         ) | ||||
|     ) | ||||
|     for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): | ||||
|         logger.log( | ||||
|             "--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}".format( | ||||
|                 i, len(datasets), dataset, xpath, split | ||||
|             ) | ||||
|         ) | ||||
|     logger.log("--->>> optimization config : {:}".format(opt_config)) | ||||
|  | ||||
|     start_time, epoch_time = time.time(), AverageMeter() | ||||
|     for i, index in enumerate(to_evaluate_indexes): | ||||
|         channelstr = nets[index] | ||||
|         logger.log( | ||||
|             "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format( | ||||
|                 time_string(), | ||||
|                 i, | ||||
|                 len(to_evaluate_indexes), | ||||
|                 index, | ||||
|                 len(nets), | ||||
|                 seeds, | ||||
|                 "-" * 15, | ||||
|             ) | ||||
|         ) | ||||
|         logger.log("{:} {:} {:}".format("-" * 15, channelstr, "-" * 15)) | ||||
|  | ||||
|         # test this arch on different datasets with different seeds | ||||
|         has_continue = False | ||||
|         for seed in seeds: | ||||
|             to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) | ||||
|             if to_save_name.exists(): | ||||
|                 if cover_mode: | ||||
|                     logger.log( | ||||
|                         "Find existing file : {:}, remove it before evaluation".format( | ||||
|                             to_save_name | ||||
|                         ) | ||||
|                     ) | ||||
|                     os.remove(str(to_save_name)) | ||||
|                 else: | ||||
|                     logger.log( | ||||
|                         "Find existing file : {:}, skip this evaluation".format( | ||||
|                             to_save_name | ||||
|                         ) | ||||
|                     ) | ||||
|                     has_continue = True | ||||
|                     continue | ||||
|             results = evaluate_all_datasets( | ||||
|                 channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger | ||||
|             ) | ||||
|             torch.save(results, to_save_name) | ||||
|             logger.log( | ||||
|                 "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format( | ||||
|                     time_string(), | ||||
|                     i, | ||||
|                     len(to_evaluate_indexes), | ||||
|                     index, | ||||
|                     len(nets), | ||||
|                     seeds, | ||||
|                     to_save_name, | ||||
|                 ) | ||||
|             ) | ||||
|         # measure elapsed time | ||||
|         if not has_continue: | ||||
|             epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "This arch costs : {:}".format(convert_secs2time(epoch_time.val, True)) | ||||
|         ) | ||||
|         logger.log("{:}".format("*" * 100)) | ||||
|         logger.log( | ||||
|             "{:}   {:74s}   {:}".format( | ||||
|                 "*" * 10, | ||||
|                 "{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}".format( | ||||
|                     i, len(to_evaluate_indexes), index, len(nets), need_time | ||||
|                 ), | ||||
|                 "*" * 10, | ||||
|             ) | ||||
|         ) | ||||
|         logger.log("{:}".format("*" * 100)) | ||||
|  | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| def traverse_net(candidates: List[int], N: int): | ||||
|     nets = [""] | ||||
|     for i in range(N): | ||||
|         new_nets = [] | ||||
|         for net in nets: | ||||
|             for C in candidates: | ||||
|                 new_nets.append(str(C) if net == "" else "{:}:{:}".format(net, C)) | ||||
|         nets = new_nets | ||||
|     return nets | ||||
|  | ||||
|  | ||||
| def filter_indexes(xlist, mode, save_dir, seeds): | ||||
|     all_indexes = [] | ||||
|     for index in xlist: | ||||
|         if mode == "cover": | ||||
|             all_indexes.append(index) | ||||
|         else: | ||||
|             for seed in seeds: | ||||
|                 temp_path = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) | ||||
|                 if not temp_path.exists(): | ||||
|                     all_indexes.append(index) | ||||
|                     break | ||||
|     print( | ||||
|         "{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total".format( | ||||
|             time_string(), len(all_indexes), len(xlist) | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     SLURM_PROCID, SLURM_NTASKS = "SLURM_PROCID", "SLURM_NTASKS" | ||||
|     if SLURM_PROCID in os.environ and SLURM_NTASKS in os.environ:  # run on the slurm | ||||
|         proc_id, ntasks = int(os.environ[SLURM_PROCID]), int(os.environ[SLURM_NTASKS]) | ||||
|         assert 0 <= proc_id < ntasks, "invalid proc_id {:} vs ntasks {:}".format( | ||||
|             proc_id, ntasks | ||||
|         ) | ||||
|         scales = [int(float(i) / ntasks * len(all_indexes)) for i in range(ntasks)] + [ | ||||
|             len(all_indexes) | ||||
|         ] | ||||
|         per_job = [] | ||||
|         for i in range(ntasks): | ||||
|             xs, xe = min(max(scales[i], 0), len(all_indexes) - 1), min( | ||||
|                 max(scales[i + 1] - 1, 0), len(all_indexes) - 1 | ||||
|             ) | ||||
|             per_job.append((xs, xe)) | ||||
|         for i, srange in enumerate(per_job): | ||||
|             print("  -->> {:2d}/{:02d} : {:}".format(i, ntasks, srange)) | ||||
|         current_range = per_job[proc_id] | ||||
|         all_indexes = [ | ||||
|             all_indexes[i] for i in range(current_range[0], current_range[1] + 1) | ||||
|         ] | ||||
|         # set the device id | ||||
|         device = proc_id % torch.cuda.device_count() | ||||
|         torch.cuda.set_device(device) | ||||
|         print("  set the device id = {:}".format(device)) | ||||
|     print( | ||||
|         "{:} [FILTER-INDEXES] : after filtering there are {:} architectures in total".format( | ||||
|             time_string(), len(all_indexes) | ||||
|         ) | ||||
|     ) | ||||
|     return all_indexes | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench (size search space)", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--mode", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         choices=["new", "cover"], | ||||
|         help="The script mode.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/NATS-Bench-size", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--candidateC", | ||||
|         type=int, | ||||
|         nargs="+", | ||||
|         default=[8, 16, 24, 32, 40, 48, 56, 64], | ||||
|         help=".", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_layers", type=int, default=5, help="The number of layers in a network." | ||||
|     ) | ||||
|     parser.add_argument("--check_N", type=int, default=32768, help="For safety.") | ||||
|     # use for train the model | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="The number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--srange", type=str, required=True, help="The range of models to be evaluated" | ||||
|     ) | ||||
|     parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") | ||||
|     parser.add_argument( | ||||
|         "--xpaths", type=str, nargs="+", help="The root path for this dataset." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--splits", type=int, nargs="+", help="The root path for this dataset." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hyper", | ||||
|         type=str, | ||||
|         default="12", | ||||
|         choices=["01", "12", "90"], | ||||
|         help="The tag for hyper-parameters.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--seeds", type=int, nargs="+", help="The range of models to be evaluated" | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     nets = traverse_net(args.candidateC, args.num_layers) | ||||
|     if len(nets) != args.check_N: | ||||
|         raise ValueError( | ||||
|             "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) | ||||
|         ) | ||||
|  | ||||
|     opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper) | ||||
|     if not os.path.isfile(opt_config): | ||||
|         raise ValueError("{:} is not a file.".format(opt_config)) | ||||
|     save_dir = Path(args.save_dir) / "raw-data-{:}".format(args.hyper) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) | ||||
|  | ||||
|     if not len(args.seeds): | ||||
|         raise ValueError("invalid length of seeds args: {:}".format(args.seeds)) | ||||
|     if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): | ||||
|         raise ValueError( | ||||
|             "invalid infos : {:} vs {:} vs {:}".format( | ||||
|                 len(args.datasets), len(args.xpaths), len(args.splits) | ||||
|             ) | ||||
|         ) | ||||
|     if args.workers <= 0: | ||||
|         raise ValueError("invalid number of workers : {:}".format(args.workers)) | ||||
|  | ||||
|     target_indexes = filter_indexes( | ||||
|         to_evaluate_indexes, args.mode, save_dir, args.seeds | ||||
|     ) | ||||
|  | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     # torch.set_num_threads(args.workers) | ||||
|  | ||||
|     main( | ||||
|         save_dir, | ||||
|         args.workers, | ||||
|         args.datasets, | ||||
|         args.xpaths, | ||||
|         args.splits, | ||||
|         tuple(args.seeds), | ||||
|         nets, | ||||
|         opt_config, | ||||
|         target_indexes, | ||||
|         args.mode == "cover", | ||||
|     ) | ||||
							
								
								
									
										696
									
								
								AutoDL-Projects/exps/NATS-Bench/main-tss.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										696
									
								
								AutoDL-Projects/exps/NATS-Bench/main-tss.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,696 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          # | ||||
| ############################################################################## | ||||
| # This file is used to train (all) architecture candidate in the topology    # | ||||
| # search space in NATS-Bench (tss) with different hyper-parameters.          # | ||||
| # When use mode=new, it will automatically detect whether the checkpoint of  # | ||||
| # a trial exists, if so, it will skip this trial. When use mode=cover, it    # | ||||
| # will ignore the (possible) existing checkpoint, run each trial, and save.  # | ||||
| ############################################################################## | ||||
| # Please use the script of scripts/NATS-Bench/train-topology.sh to run.      # | ||||
| # bash scripts/NATS-Bench/train-topology.sh 00000-15624 12 777               # | ||||
| # bash scripts/NATS-Bench/train-topology.sh 00000-15624 200 '777 888 999'    # | ||||
| #                                                                            # | ||||
| ################                                                             # | ||||
| # [Deprecated Function: Generate the meta information]                       # | ||||
| # python ./exps/NATS-Bench/main-tss.py --mode meta                           # | ||||
| ############################################################################## | ||||
| import os, sys, time, torch, random, argparse | ||||
| from typing import List, Text, Dict, Any | ||||
| from PIL import ImageFile | ||||
|  | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.procedures import bench_evaluate_for_seed | ||||
| from xautodl.procedures import get_machine_info | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.models import CellStructure, CellArchitectures, get_search_spaces | ||||
| from xautodl.utils import split_str2indexes | ||||
|  | ||||
|  | ||||
| def evaluate_all_datasets( | ||||
|     arch: Text, | ||||
|     datasets: List[Text], | ||||
|     xpaths: List[Text], | ||||
|     splits: List[Text], | ||||
|     config_path: Text, | ||||
|     seed: int, | ||||
|     raw_arch_config, | ||||
|     workers, | ||||
|     logger, | ||||
| ): | ||||
|     machine_info, raw_arch_config = get_machine_info(), deepcopy(raw_arch_config) | ||||
|     all_infos = {"info": machine_info} | ||||
|     all_dataset_keys = [] | ||||
|     # look all the datasets | ||||
|     for dataset, xpath, split in zip(datasets, xpaths, splits): | ||||
|         # train valid data | ||||
|         train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1) | ||||
|         # load the configuration | ||||
|         if dataset == "cifar10" or dataset == "cifar100": | ||||
|             split_info = load_config( | ||||
|                 "configs/nas-benchmark/cifar-split.txt", None, None | ||||
|             ) | ||||
|         elif dataset.startswith("ImageNet16"): | ||||
|             split_info = load_config( | ||||
|                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|         config = load_config( | ||||
|             config_path, dict(class_num=class_num, xshape=xshape), logger | ||||
|         ) | ||||
|         # check whether use splited validation set | ||||
|         if bool(split): | ||||
|             assert dataset == "cifar10" | ||||
|             ValLoaders = { | ||||
|                 "ori-test": torch.utils.data.DataLoader( | ||||
|                     valid_data, | ||||
|                     batch_size=config.batch_size, | ||||
|                     shuffle=False, | ||||
|                     num_workers=workers, | ||||
|                     pin_memory=True, | ||||
|                 ) | ||||
|             } | ||||
|             assert len(train_data) == len(split_info.train) + len( | ||||
|                 split_info.valid | ||||
|             ), "invalid length : {:} vs {:} + {:}".format( | ||||
|                 len(train_data), len(split_info.train), len(split_info.valid) | ||||
|             ) | ||||
|             train_data_v2 = deepcopy(train_data) | ||||
|             train_data_v2.transform = valid_data.transform | ||||
|             valid_data = train_data_v2 | ||||
|             # data loader | ||||
|             train_loader = torch.utils.data.DataLoader( | ||||
|                 train_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             valid_loader = torch.utils.data.DataLoader( | ||||
|                 valid_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             ValLoaders["x-valid"] = valid_loader | ||||
|         else: | ||||
|             # data loader | ||||
|             train_loader = torch.utils.data.DataLoader( | ||||
|                 train_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 shuffle=True, | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             valid_loader = torch.utils.data.DataLoader( | ||||
|                 valid_data, | ||||
|                 batch_size=config.batch_size, | ||||
|                 shuffle=False, | ||||
|                 num_workers=workers, | ||||
|                 pin_memory=True, | ||||
|             ) | ||||
|             if dataset == "cifar10": | ||||
|                 ValLoaders = {"ori-test": valid_loader} | ||||
|             elif dataset == "cifar100": | ||||
|                 cifar100_splits = load_config( | ||||
|                     "configs/nas-benchmark/cifar100-test-split.txt", None, None | ||||
|                 ) | ||||
|                 ValLoaders = { | ||||
|                     "ori-test": valid_loader, | ||||
|                     "x-valid": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             cifar100_splits.xvalid | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                     "x-test": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             cifar100_splits.xtest | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                 } | ||||
|             elif dataset == "ImageNet16-120": | ||||
|                 imagenet16_splits = load_config( | ||||
|                     "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None | ||||
|                 ) | ||||
|                 ValLoaders = { | ||||
|                     "ori-test": valid_loader, | ||||
|                     "x-valid": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             imagenet16_splits.xvalid | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                     "x-test": torch.utils.data.DataLoader( | ||||
|                         valid_data, | ||||
|                         batch_size=config.batch_size, | ||||
|                         sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                             imagenet16_splits.xtest | ||||
|                         ), | ||||
|                         num_workers=workers, | ||||
|                         pin_memory=True, | ||||
|                     ), | ||||
|                 } | ||||
|             else: | ||||
|                 raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|  | ||||
|         dataset_key = "{:}".format(dataset) | ||||
|         if bool(split): | ||||
|             dataset_key = dataset_key + "-valid" | ||||
|         logger.log( | ||||
|             "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format( | ||||
|                 dataset_key, | ||||
|                 len(train_data), | ||||
|                 len(valid_data), | ||||
|                 len(train_loader), | ||||
|                 len(valid_loader), | ||||
|                 config.batch_size, | ||||
|             ) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config) | ||||
|         ) | ||||
|         for key, value in ValLoaders.items(): | ||||
|             logger.log( | ||||
|                 "Evaluate ---->>>> {:10s} with {:} batchs".format(key, len(value)) | ||||
|             ) | ||||
|         arch_config = dict2config( | ||||
|             dict( | ||||
|                 name="infer.tiny", | ||||
|                 C=raw_arch_config["channel"], | ||||
|                 N=raw_arch_config["num_cells"], | ||||
|                 genotype=arch, | ||||
|                 num_classes=config.class_num, | ||||
|             ), | ||||
|             None, | ||||
|         ) | ||||
|         results = bench_evaluate_for_seed( | ||||
|             arch_config, config, train_loader, ValLoaders, seed, logger | ||||
|         ) | ||||
|         all_infos[dataset_key] = results | ||||
|         all_dataset_keys.append(dataset_key) | ||||
|     all_infos["all_dataset_keys"] = all_dataset_keys | ||||
|     return all_infos | ||||
|  | ||||
|  | ||||
| def main( | ||||
|     save_dir: Path, | ||||
|     workers: int, | ||||
|     datasets: List[Text], | ||||
|     xpaths: List[Text], | ||||
|     splits: List[int], | ||||
|     seeds: List[int], | ||||
|     nets: List[str], | ||||
|     opt_config: Dict[Text, Any], | ||||
|     to_evaluate_indexes: tuple, | ||||
|     cover_mode: bool, | ||||
|     arch_config: Dict[Text, Any], | ||||
| ): | ||||
|  | ||||
|     log_dir = save_dir / "logs" | ||||
|     log_dir.mkdir(parents=True, exist_ok=True) | ||||
|     logger = Logger(str(log_dir), os.getpid(), False) | ||||
|  | ||||
|     logger.log("xargs : seeds      = {:}".format(seeds)) | ||||
|     logger.log("xargs : cover_mode = {:}".format(cover_mode)) | ||||
|     logger.log("-" * 100) | ||||
|     logger.log( | ||||
|         "Start evaluating range =: {:06d} - {:06d}".format( | ||||
|             min(to_evaluate_indexes), max(to_evaluate_indexes) | ||||
|         ) | ||||
|         + "({:} in total) / {:06d} with cover-mode={:}".format( | ||||
|             len(to_evaluate_indexes), len(nets), cover_mode | ||||
|         ) | ||||
|     ) | ||||
|     for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)): | ||||
|         logger.log( | ||||
|             "--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}".format( | ||||
|                 i, len(datasets), dataset, xpath, split | ||||
|             ) | ||||
|         ) | ||||
|     logger.log("--->>> optimization config : {:}".format(opt_config)) | ||||
|  | ||||
|     start_time, epoch_time = time.time(), AverageMeter() | ||||
|     for i, index in enumerate(to_evaluate_indexes): | ||||
|         arch = nets[index] | ||||
|         logger.log( | ||||
|             "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}".format( | ||||
|                 time_string(), | ||||
|                 i, | ||||
|                 len(to_evaluate_indexes), | ||||
|                 index, | ||||
|                 len(nets), | ||||
|                 seeds, | ||||
|                 "-" * 15, | ||||
|             ) | ||||
|         ) | ||||
|         logger.log("{:} {:} {:}".format("-" * 15, arch, "-" * 15)) | ||||
|  | ||||
|         # test this arch on different datasets with different seeds | ||||
|         has_continue = False | ||||
|         for seed in seeds: | ||||
|             to_save_name = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) | ||||
|             if to_save_name.exists(): | ||||
|                 if cover_mode: | ||||
|                     logger.log( | ||||
|                         "Find existing file : {:}, remove it before evaluation".format( | ||||
|                             to_save_name | ||||
|                         ) | ||||
|                     ) | ||||
|                     os.remove(str(to_save_name)) | ||||
|                 else: | ||||
|                     logger.log( | ||||
|                         "Find existing file : {:}, skip this evaluation".format( | ||||
|                             to_save_name | ||||
|                         ) | ||||
|                     ) | ||||
|                     has_continue = True | ||||
|                     continue | ||||
|             results = evaluate_all_datasets( | ||||
|                 CellStructure.str2structure(arch), | ||||
|                 datasets, | ||||
|                 xpaths, | ||||
|                 splits, | ||||
|                 opt_config, | ||||
|                 seed, | ||||
|                 arch_config, | ||||
|                 workers, | ||||
|                 logger, | ||||
|             ) | ||||
|             torch.save(results, to_save_name) | ||||
|             logger.log( | ||||
|                 "\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] ===>>> {:}".format( | ||||
|                     time_string(), | ||||
|                     i, | ||||
|                     len(to_evaluate_indexes), | ||||
|                     index, | ||||
|                     len(nets), | ||||
|                     seeds, | ||||
|                     to_save_name, | ||||
|                 ) | ||||
|             ) | ||||
|         # measure elapsed time | ||||
|         if not has_continue: | ||||
|             epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "This arch costs : {:}".format(convert_secs2time(epoch_time.val, True)) | ||||
|         ) | ||||
|         logger.log("{:}".format("*" * 100)) | ||||
|         logger.log( | ||||
|             "{:}   {:74s}   {:}".format( | ||||
|                 "*" * 10, | ||||
|                 "{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}".format( | ||||
|                     i, len(to_evaluate_indexes), index, len(nets), need_time | ||||
|                 ), | ||||
|                 "*" * 10, | ||||
|             ) | ||||
|         ) | ||||
|         logger.log("{:}".format("*" * 100)) | ||||
|  | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| def train_single_model( | ||||
|     save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config | ||||
| ): | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     # torch.backends.cudnn.benchmark = True | ||||
|     # torch.set_num_threads(workers) | ||||
|  | ||||
|     save_dir = ( | ||||
|         Path(save_dir) | ||||
|         / "specifics" | ||||
|         / "{:}-{:}-{:}-{:}".format( | ||||
|             "LESS" if use_less else "FULL", | ||||
|             model_str, | ||||
|             arch_config["channel"], | ||||
|             arch_config["num_cells"], | ||||
|         ) | ||||
|     ) | ||||
|     logger = Logger(str(save_dir), 0, False) | ||||
|     if model_str in CellArchitectures: | ||||
|         arch = CellArchitectures[model_str] | ||||
|         logger.log( | ||||
|             "The model string is found in pre-defined architecture dict : {:}".format( | ||||
|                 model_str | ||||
|             ) | ||||
|         ) | ||||
|     else: | ||||
|         try: | ||||
|             arch = CellStructure.str2structure(model_str) | ||||
|         except: | ||||
|             raise ValueError( | ||||
|                 "Invalid model string : {:}. It can not be found or parsed.".format( | ||||
|                     model_str | ||||
|                 ) | ||||
|             ) | ||||
|     assert arch.check_valid_op( | ||||
|         get_search_spaces("cell", "full") | ||||
|     ), "{:} has the invalid op.".format(arch) | ||||
|     logger.log("Start train-evaluate {:}".format(arch.tostr())) | ||||
|     logger.log("arch_config : {:}".format(arch_config)) | ||||
|  | ||||
|     start_time, seed_time = time.time(), AverageMeter() | ||||
|     for _is, seed in enumerate(seeds): | ||||
|         logger.log( | ||||
|             "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format( | ||||
|                 _is, len(seeds), seed | ||||
|             ) | ||||
|         ) | ||||
|         to_save_name = save_dir / "seed-{:04d}.pth".format(seed) | ||||
|         if to_save_name.exists(): | ||||
|             logger.log( | ||||
|                 "Find the existing file {:}, directly load!".format(to_save_name) | ||||
|             ) | ||||
|             checkpoint = torch.load(to_save_name) | ||||
|         else: | ||||
|             logger.log( | ||||
|                 "Does not find the existing file {:}, train and evaluate!".format( | ||||
|                     to_save_name | ||||
|                 ) | ||||
|             ) | ||||
|             checkpoint = evaluate_all_datasets( | ||||
|                 arch, | ||||
|                 datasets, | ||||
|                 xpaths, | ||||
|                 splits, | ||||
|                 use_less, | ||||
|                 seed, | ||||
|                 arch_config, | ||||
|                 workers, | ||||
|                 logger, | ||||
|             ) | ||||
|             torch.save(checkpoint, to_save_name) | ||||
|         # log information | ||||
|         logger.log("{:}".format(checkpoint["info"])) | ||||
|         all_dataset_keys = checkpoint["all_dataset_keys"] | ||||
|         for dataset_key in all_dataset_keys: | ||||
|             logger.log( | ||||
|                 "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15) | ||||
|             ) | ||||
|             dataset_info = checkpoint[dataset_key] | ||||
|             # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] )) | ||||
|             logger.log( | ||||
|                 "Flops = {:} MB, Params = {:} MB".format( | ||||
|                     dataset_info["flop"], dataset_info["param"] | ||||
|                 ) | ||||
|             ) | ||||
|             logger.log("config : {:}".format(dataset_info["config"])) | ||||
|             logger.log( | ||||
|                 "Training State (finish) = {:}".format(dataset_info["finish-train"]) | ||||
|             ) | ||||
|             last_epoch = dataset_info["total_epoch"] - 1 | ||||
|             train_acc1es, train_acc5es = ( | ||||
|                 dataset_info["train_acc1es"], | ||||
|                 dataset_info["train_acc5es"], | ||||
|             ) | ||||
|             valid_acc1es, valid_acc5es = ( | ||||
|                 dataset_info["valid_acc1es"], | ||||
|                 dataset_info["valid_acc5es"], | ||||
|             ) | ||||
|             logger.log( | ||||
|                 "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format( | ||||
|                     train_acc1es[last_epoch], | ||||
|                     train_acc5es[last_epoch], | ||||
|                     100 - train_acc1es[last_epoch], | ||||
|                     valid_acc1es[last_epoch], | ||||
|                     valid_acc5es[last_epoch], | ||||
|                     100 - valid_acc1es[last_epoch], | ||||
|                 ) | ||||
|             ) | ||||
|         # measure elapsed time | ||||
|         seed_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format( | ||||
|                 _is, len(seeds), seed, need_time | ||||
|             ) | ||||
|         ) | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| def generate_meta_info(save_dir, max_node, divide=40): | ||||
|     aa_nas_bench_ss = get_search_spaces("cell", "nas-bench-201") | ||||
|     archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) | ||||
|     print( | ||||
|         "There are {:} archs vs {:}.".format( | ||||
|             len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2) | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     random.seed(88)  # please do not change this line for reproducibility | ||||
|     random.shuffle(archs) | ||||
|     # to test fixed-random shuffle | ||||
|     # print ('arch [0] : {:}\n---->>>>   {:}'.format( archs[0], archs[0].tostr() )) | ||||
|     # print ('arch [9] : {:}\n---->>>>   {:}'.format( archs[9], archs[9].tostr() )) | ||||
|     assert ( | ||||
|         archs[0].tostr() | ||||
|         == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" | ||||
|     ), "please check the 0-th architecture : {:}".format(archs[0]) | ||||
|     assert ( | ||||
|         archs[9].tostr() | ||||
|         == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" | ||||
|     ), "please check the 9-th architecture : {:}".format(archs[9]) | ||||
|     assert ( | ||||
|         archs[123].tostr() | ||||
|         == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" | ||||
|     ), "please check the 123-th architecture : {:}".format(archs[123]) | ||||
|     total_arch = len(archs) | ||||
|  | ||||
|     num = 50000 | ||||
|     indexes_5W = list(range(num)) | ||||
|     random.seed(1021) | ||||
|     random.shuffle(indexes_5W) | ||||
|     train_split = sorted(list(set(indexes_5W[: num // 2]))) | ||||
|     valid_split = sorted(list(set(indexes_5W[num // 2 :]))) | ||||
|     assert len(train_split) + len(valid_split) == num | ||||
|     assert ( | ||||
|         train_split[0] == 0 | ||||
|         and train_split[10] == 26 | ||||
|         and train_split[111] == 203 | ||||
|         and valid_split[0] == 1 | ||||
|         and valid_split[10] == 18 | ||||
|         and valid_split[111] == 242 | ||||
|     ), "{:} {:} {:} - {:} {:} {:}".format( | ||||
|         train_split[0], | ||||
|         train_split[10], | ||||
|         train_split[111], | ||||
|         valid_split[0], | ||||
|         valid_split[10], | ||||
|         valid_split[111], | ||||
|     ) | ||||
|     splits = {num: {"train": train_split, "valid": valid_split}} | ||||
|  | ||||
|     info = { | ||||
|         "archs": [x.tostr() for x in archs], | ||||
|         "total": total_arch, | ||||
|         "max_node": max_node, | ||||
|         "splits": splits, | ||||
|     } | ||||
|  | ||||
|     save_dir = Path(save_dir) | ||||
|     save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     save_name = save_dir / "meta-node-{:}.pth".format(max_node) | ||||
|     assert not save_name.exists(), "{:} already exist".format(save_name) | ||||
|     torch.save(info, save_name) | ||||
|     print("save the meta file into {:}".format(save_name)) | ||||
|  | ||||
|  | ||||
| def traverse_net(max_node): | ||||
|     aa_nas_bench_ss = get_search_spaces("cell", "nats-bench") | ||||
|     archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) | ||||
|     print( | ||||
|         "There are {:} archs vs {:}.".format( | ||||
|             len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2) | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     random.seed(88)  # please do not change this line for reproducibility | ||||
|     random.shuffle(archs) | ||||
|     assert ( | ||||
|         archs[0].tostr() | ||||
|         == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" | ||||
|     ), "please check the 0-th architecture : {:}".format(archs[0]) | ||||
|     assert ( | ||||
|         archs[9].tostr() | ||||
|         == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" | ||||
|     ), "please check the 9-th architecture : {:}".format(archs[9]) | ||||
|     assert ( | ||||
|         archs[123].tostr() | ||||
|         == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" | ||||
|     ), "please check the 123-th architecture : {:}".format(archs[123]) | ||||
|     return [x.tostr() for x in archs] | ||||
|  | ||||
|  | ||||
| def filter_indexes(xlist, mode, save_dir, seeds): | ||||
|     all_indexes = [] | ||||
|     for index in xlist: | ||||
|         if mode == "cover": | ||||
|             all_indexes.append(index) | ||||
|         else: | ||||
|             for seed in seeds: | ||||
|                 temp_path = save_dir / "arch-{:06d}-seed-{:04d}.pth".format(index, seed) | ||||
|                 if not temp_path.exists(): | ||||
|                     all_indexes.append(index) | ||||
|                     break | ||||
|     print( | ||||
|         "{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total".format( | ||||
|             time_string(), len(all_indexes), len(xlist) | ||||
|         ) | ||||
|     ) | ||||
|     return all_indexes | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     # mode_choices = ['meta', 'new', 'cover'] + ['specific-{:}'.format(_) for _ in CellArchitectures.keys()] | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench (topology search space)", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument("--mode", type=str, required=True, help="The script mode.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/NATS-Bench-topology", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--max_node", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The maximum node in a cell (please do not change it).", | ||||
|     ) | ||||
|     # use for train the model | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="number of data loading workers (default: 2)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--srange", type=str, required=True, help="The range of models to be evaluated" | ||||
|     ) | ||||
|     parser.add_argument("--datasets", type=str, nargs="+", help="The applied datasets.") | ||||
|     parser.add_argument( | ||||
|         "--xpaths", type=str, nargs="+", help="The root path for this dataset." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--splits", type=int, nargs="+", help="The root path for this dataset." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hyper", | ||||
|         type=str, | ||||
|         default="12", | ||||
|         choices=["01", "12", "200"], | ||||
|         help="The tag for hyper-parameters.", | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--seeds", type=int, nargs="+", help="The range of models to be evaluated" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--channel", type=int, default=16, help="The number of channels." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, default=5, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument("--check_N", type=int, default=15625, help="For safety.") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     assert args.mode in ["meta", "new", "cover"] or args.mode.startswith( | ||||
|         "specific-" | ||||
|     ), "invalid mode : {:}".format(args.mode) | ||||
|  | ||||
|     if args.mode == "meta": | ||||
|         generate_meta_info(args.save_dir, args.max_node) | ||||
|     elif args.mode.startswith("specific"): | ||||
|         assert len(args.mode.split("-")) == 2, "invalid mode : {:}".format(args.mode) | ||||
|         model_str = args.mode.split("-")[1] | ||||
|         train_single_model( | ||||
|             args.save_dir, | ||||
|             args.workers, | ||||
|             args.datasets, | ||||
|             args.xpaths, | ||||
|             args.splits, | ||||
|             args.use_less > 0, | ||||
|             tuple(args.seeds), | ||||
|             model_str, | ||||
|             {"channel": args.channel, "num_cells": args.num_cells}, | ||||
|         ) | ||||
|     else: | ||||
|         nets = traverse_net(args.max_node) | ||||
|         if len(nets) != args.check_N: | ||||
|             raise ValueError( | ||||
|                 "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) | ||||
|             ) | ||||
|         opt_config = "./configs/nas-benchmark/hyper-opts/{:}E.config".format(args.hyper) | ||||
|         if not os.path.isfile(opt_config): | ||||
|             raise ValueError("{:} is not a file.".format(opt_config)) | ||||
|         save_dir = Path(args.save_dir) / "raw-data-{:}".format(args.hyper) | ||||
|         save_dir.mkdir(parents=True, exist_ok=True) | ||||
|         to_evaluate_indexes = split_str2indexes(args.srange, args.check_N, 5) | ||||
|         if not len(args.seeds): | ||||
|             raise ValueError("invalid length of seeds args: {:}".format(args.seeds)) | ||||
|         if not (len(args.datasets) == len(args.xpaths) == len(args.splits)): | ||||
|             raise ValueError( | ||||
|                 "invalid infos : {:} vs {:} vs {:}".format( | ||||
|                     len(args.datasets), len(args.xpaths), len(args.splits) | ||||
|                 ) | ||||
|             ) | ||||
|         if args.workers < 0: | ||||
|             raise ValueError("invalid number of workers : {:}".format(args.workers)) | ||||
|  | ||||
|         target_indexes = filter_indexes( | ||||
|             to_evaluate_indexes, args.mode, save_dir, args.seeds | ||||
|         ) | ||||
|  | ||||
|         assert torch.cuda.is_available(), "CUDA is not available." | ||||
|         torch.backends.cudnn.enabled = True | ||||
|         torch.backends.cudnn.deterministic = True | ||||
|         # torch.set_num_threads(args.workers if args.workers > 0 else 1) | ||||
|  | ||||
|         main( | ||||
|             save_dir, | ||||
|             args.workers, | ||||
|             args.datasets, | ||||
|             args.xpaths, | ||||
|             args.splits, | ||||
|             tuple(args.seeds), | ||||
|             nets, | ||||
|             opt_config, | ||||
|             target_indexes, | ||||
|             args.mode == "cover", | ||||
|             { | ||||
|                 "name": "infer.tiny", | ||||
|                 "channel": args.channel, | ||||
|                 "num_cells": args.num_cells, | ||||
|             }, | ||||
|         ) | ||||
							
								
								
									
										59
									
								
								AutoDL-Projects/exps/NATS-Bench/show-dataset.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								AutoDL-Projects/exps/NATS-Bench/show-dataset.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07                          # | ||||
| ############################################################################## | ||||
| # python ./exps/NATS-Bench/show-dataset.py                                   # | ||||
| ############################################################################## | ||||
| import os, sys, time, torch, random, argparse | ||||
| from typing import List, Text, Dict, Any | ||||
| from PIL import ImageFile | ||||
|  | ||||
| ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy import deepcopy | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.datasets import get_datasets | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def show_imagenet_16_120(dataset_dir=None): | ||||
|     if dataset_dir is None: | ||||
|         torch_home_dir = ( | ||||
|             os.environ["TORCH_HOME"] | ||||
|             if "TORCH_HOME" in os.environ | ||||
|             else os.path.join(os.environ["HOME"], ".torch") | ||||
|         ) | ||||
|         dataset_dir = os.path.join(torch_home_dir, "cifar.python", "ImageNet16") | ||||
|     train_data, valid_data, xshape, class_num = get_datasets( | ||||
|         "ImageNet16-120", dataset_dir, -1 | ||||
|     ) | ||||
|     split_info = load_config( | ||||
|         "configs/nas-benchmark/ImageNet16-120-split.txt", None, None | ||||
|     ) | ||||
|     print("=" * 10 + " ImageNet-16-120 " + "=" * 10) | ||||
|     print("Training Data: {:}".format(train_data)) | ||||
|     print("Evaluation Data: {:}".format(valid_data)) | ||||
|     print("Hold-out training: {:} images.".format(len(split_info.train))) | ||||
|     print("Hold-out valid   : {:} images.".format(len(split_info.valid))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     # show_imagenet_16_120() | ||||
|     api_nats_tss = create(None, "tss", fast_mode=True, verbose=True) | ||||
|  | ||||
|     valid_acc_12e = [] | ||||
|     test_acc_12e = [] | ||||
|     test_acc_200e = [] | ||||
|     for index in range(10000): | ||||
|         info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="12") | ||||
|         valid_acc_12e.append( | ||||
|             info["valid-accuracy"] | ||||
|         )  # the validation accuracy after training the model by 12 epochs | ||||
|         test_acc_12e.append( | ||||
|             info["test-accuracy"] | ||||
|         )  # the test accuracy after training the model by 12 epochs | ||||
|         info = api_nats_tss.get_more_info(index, "ImageNet16-120", hp="200") | ||||
|         test_acc_200e.append( | ||||
|             info["test-accuracy"] | ||||
|         )  # the test accuracy after training the model by 200 epochs (which I reported in the paper) | ||||
							
								
								
									
										389
									
								
								AutoDL-Projects/exps/NATS-Bench/sss-collect.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										389
									
								
								AutoDL-Projects/exps/NATS-Bench/sss-collect.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,389 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                          # | ||||
| ############################################################################## | ||||
| # This file is used to re-orangize all checkpoints (created by main-sss.py)  # | ||||
| # into a single benchmark file. Besides, for each trial, we will merge the   # | ||||
| # information of all its trials into a single file.                          # | ||||
| #                                                                            # | ||||
| # Usage:                                                                     # | ||||
| # python exps/NATS-Bench/sss-collect.py                                      # | ||||
| ############################################################################## | ||||
| import os, re, sys, time, shutil, argparse, collections | ||||
| import torch | ||||
| from tqdm import tqdm | ||||
| from pathlib import Path | ||||
| from collections import defaultdict, OrderedDict | ||||
| from typing import Dict, Any, Text, List | ||||
|  | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.config_utils import dict2config | ||||
| from xautodl.models import CellStructure, get_cell_based_tiny_net | ||||
| from xautodl.procedures import ( | ||||
|     bench_pure_evaluate as pure_evaluate, | ||||
|     get_nas_bench_loaders, | ||||
| ) | ||||
| from xautodl.utils import get_md5_file | ||||
|  | ||||
| from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount | ||||
|  | ||||
|  | ||||
| NATS_SSS_BASE_NAME = "NATS-sss-v1_0"  # 2020.08.28 | ||||
|  | ||||
|  | ||||
| def account_one_arch( | ||||
|     arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text] | ||||
| ) -> ArchResults: | ||||
|     information = ArchResults(arch_index, arch_str) | ||||
|  | ||||
|     for checkpoint_path in checkpoints: | ||||
|         try: | ||||
|             checkpoint = torch.load(checkpoint_path, map_location="cpu") | ||||
|         except: | ||||
|             raise ValueError( | ||||
|                 "This checkpoint failed to be loaded : {:}".format(checkpoint_path) | ||||
|             ) | ||||
|         used_seed = checkpoint_path.name.split("-")[-1].split(".")[0] | ||||
|         ok_dataset = 0 | ||||
|         for dataset in datasets: | ||||
|             if dataset not in checkpoint: | ||||
|                 print( | ||||
|                     "Can not find {:} in arch-{:} from {:}".format( | ||||
|                         dataset, arch_index, checkpoint_path | ||||
|                     ) | ||||
|                 ) | ||||
|                 continue | ||||
|             else: | ||||
|                 ok_dataset += 1 | ||||
|             results = checkpoint[dataset] | ||||
|             assert results[ | ||||
|                 "finish-train" | ||||
|             ], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( | ||||
|                 arch_index, used_seed, dataset, checkpoint_path | ||||
|             ) | ||||
|             arch_config = { | ||||
|                 "name": "infer.shape.tiny", | ||||
|                 "channels": arch_str, | ||||
|                 "arch_str": arch_str, | ||||
|                 "genotype": results["arch_config"]["genotype"], | ||||
|                 "class_num": results["arch_config"]["num_classes"], | ||||
|             } | ||||
|             xresult = ResultsCount( | ||||
|                 dataset, | ||||
|                 results["net_state_dict"], | ||||
|                 results["train_acc1es"], | ||||
|                 results["train_losses"], | ||||
|                 results["param"], | ||||
|                 results["flop"], | ||||
|                 arch_config, | ||||
|                 used_seed, | ||||
|                 results["total_epoch"], | ||||
|                 None, | ||||
|             ) | ||||
|             xresult.update_train_info( | ||||
|                 results["train_acc1es"], | ||||
|                 results["train_acc5es"], | ||||
|                 results["train_losses"], | ||||
|                 results["train_times"], | ||||
|             ) | ||||
|             xresult.update_eval( | ||||
|                 results["valid_acc1es"], results["valid_losses"], results["valid_times"] | ||||
|             ) | ||||
|             information.update(dataset, int(used_seed), xresult) | ||||
|         if ok_dataset < len(datasets): | ||||
|             raise ValueError( | ||||
|                 "{:} does find enought data : {:} vs {:}".format( | ||||
|                     checkpoint_path, ok_dataset, len(datasets) | ||||
|                 ) | ||||
|             ) | ||||
|     return information | ||||
|  | ||||
|  | ||||
| def correct_time_related_info(hp2info: Dict[Text, ArchResults]): | ||||
|     # calibrate the latency based on the number of epochs = 01, since they are trained on the same machine. | ||||
|     x1 = hp2info["01"].get_metrics("cifar10-valid", "x-valid")["all_time"] / 98 | ||||
|     x2 = hp2info["01"].get_metrics("cifar10-valid", "ori-test")["all_time"] / 40 | ||||
|     cifar010_latency = (x1 + x2) / 2 | ||||
|     for hp, arch_info in hp2info.items(): | ||||
|         arch_info.reset_latency("cifar10-valid", None, cifar010_latency) | ||||
|         arch_info.reset_latency("cifar10", None, cifar010_latency) | ||||
|     # hp2info['01'].get_latency('cifar10') | ||||
|  | ||||
|     x1 = hp2info["01"].get_metrics("cifar100", "ori-test")["all_time"] / 40 | ||||
|     x2 = hp2info["01"].get_metrics("cifar100", "x-test")["all_time"] / 20 | ||||
|     x3 = hp2info["01"].get_metrics("cifar100", "x-valid")["all_time"] / 20 | ||||
|     cifar100_latency = (x1 + x2 + x3) / 3 | ||||
|     for hp, arch_info in hp2info.items(): | ||||
|         arch_info.reset_latency("cifar100", None, cifar100_latency) | ||||
|  | ||||
|     x1 = hp2info["01"].get_metrics("ImageNet16-120", "ori-test")["all_time"] / 24 | ||||
|     x2 = hp2info["01"].get_metrics("ImageNet16-120", "x-test")["all_time"] / 12 | ||||
|     x3 = hp2info["01"].get_metrics("ImageNet16-120", "x-valid")["all_time"] / 12 | ||||
|     image_latency = (x1 + x2 + x3) / 3 | ||||
|     for hp, arch_info in hp2info.items(): | ||||
|         arch_info.reset_latency("ImageNet16-120", None, image_latency) | ||||
|  | ||||
|     # CIFAR10 VALID | ||||
|     train_per_epoch_time = list( | ||||
|         hp2info["01"].query("cifar10-valid", 777).train_times.values() | ||||
|     ) | ||||
|     train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) | ||||
|     eval_ori_test_time, eval_x_valid_time = [], [] | ||||
|     for key, value in hp2info["01"].query("cifar10-valid", 777).eval_times.items(): | ||||
|         if key.startswith("ori-test@"): | ||||
|             eval_ori_test_time.append(value) | ||||
|         elif key.startswith("x-valid@"): | ||||
|             eval_x_valid_time.append(value) | ||||
|         else: | ||||
|             raise ValueError("-- {:} --".format(key)) | ||||
|     eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) | ||||
|     eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) | ||||
|     for hp, arch_info in hp2info.items(): | ||||
|         arch_info.reset_pseudo_train_times("cifar10-valid", None, train_per_epoch_time) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar10-valid", None, "x-valid", eval_x_valid_time | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar10-valid", None, "ori-test", eval_ori_test_time | ||||
|         ) | ||||
|  | ||||
|     # CIFAR10 | ||||
|     train_per_epoch_time = list( | ||||
|         hp2info["01"].query("cifar10", 777).train_times.values() | ||||
|     ) | ||||
|     train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) | ||||
|     eval_ori_test_time = [] | ||||
|     for key, value in hp2info["01"].query("cifar10", 777).eval_times.items(): | ||||
|         if key.startswith("ori-test@"): | ||||
|             eval_ori_test_time.append(value) | ||||
|         else: | ||||
|             raise ValueError("-- {:} --".format(key)) | ||||
|     eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) | ||||
|     for hp, arch_info in hp2info.items(): | ||||
|         arch_info.reset_pseudo_train_times("cifar10", None, train_per_epoch_time) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar10", None, "ori-test", eval_ori_test_time | ||||
|         ) | ||||
|  | ||||
|     # CIFAR100 | ||||
|     train_per_epoch_time = list( | ||||
|         hp2info["01"].query("cifar100", 777).train_times.values() | ||||
|     ) | ||||
|     train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) | ||||
|     eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], [] | ||||
|     for key, value in hp2info["01"].query("cifar100", 777).eval_times.items(): | ||||
|         if key.startswith("ori-test@"): | ||||
|             eval_ori_test_time.append(value) | ||||
|         elif key.startswith("x-valid@"): | ||||
|             eval_x_valid_time.append(value) | ||||
|         elif key.startswith("x-test@"): | ||||
|             eval_x_test_time.append(value) | ||||
|         else: | ||||
|             raise ValueError("-- {:} --".format(key)) | ||||
|     eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) | ||||
|     eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) | ||||
|     eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time) | ||||
|     for hp, arch_info in hp2info.items(): | ||||
|         arch_info.reset_pseudo_train_times("cifar100", None, train_per_epoch_time) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar100", None, "x-valid", eval_x_valid_time | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times("cifar100", None, "x-test", eval_x_test_time) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar100", None, "ori-test", eval_ori_test_time | ||||
|         ) | ||||
|  | ||||
|     # ImageNet16-120 | ||||
|     train_per_epoch_time = list( | ||||
|         hp2info["01"].query("ImageNet16-120", 777).train_times.values() | ||||
|     ) | ||||
|     train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) | ||||
|     eval_ori_test_time, eval_x_valid_time, eval_x_test_time = [], [], [] | ||||
|     for key, value in hp2info["01"].query("ImageNet16-120", 777).eval_times.items(): | ||||
|         if key.startswith("ori-test@"): | ||||
|             eval_ori_test_time.append(value) | ||||
|         elif key.startswith("x-valid@"): | ||||
|             eval_x_valid_time.append(value) | ||||
|         elif key.startswith("x-test@"): | ||||
|             eval_x_test_time.append(value) | ||||
|         else: | ||||
|             raise ValueError("-- {:} --".format(key)) | ||||
|     eval_ori_test_time = sum(eval_ori_test_time) / len(eval_ori_test_time) | ||||
|     eval_x_valid_time = sum(eval_x_valid_time) / len(eval_x_valid_time) | ||||
|     eval_x_test_time = sum(eval_x_test_time) / len(eval_x_test_time) | ||||
|     for hp, arch_info in hp2info.items(): | ||||
|         arch_info.reset_pseudo_train_times("ImageNet16-120", None, train_per_epoch_time) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "ImageNet16-120", None, "x-valid", eval_x_valid_time | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "ImageNet16-120", None, "x-test", eval_x_test_time | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "ImageNet16-120", None, "ori-test", eval_ori_test_time | ||||
|         ) | ||||
|     return hp2info | ||||
|  | ||||
|  | ||||
| def simplify(save_dir, save_name, nets, total): | ||||
|  | ||||
|     hps, seeds = ["01", "12", "90"], set() | ||||
|     for hp in hps: | ||||
|         sub_save_dir = save_dir / "raw-data-{:}".format(hp) | ||||
|         ckps = sorted(list(sub_save_dir.glob("arch-*-seed-*.pth"))) | ||||
|         seed2names = defaultdict(list) | ||||
|         for ckp in ckps: | ||||
|             parts = re.split("-|\.", ckp.name) | ||||
|             seed2names[parts[3]].append(ckp.name) | ||||
|         print("DIR : {:}".format(sub_save_dir)) | ||||
|         nums = [] | ||||
|         for seed, xlist in seed2names.items(): | ||||
|             seeds.add(seed) | ||||
|             nums.append(len(xlist)) | ||||
|             print("  [seed={:}] there are {:} checkpoints.".format(seed, len(xlist))) | ||||
|         assert ( | ||||
|             len(nets) == total == max(nums) | ||||
|         ), "there are some missed files : {:} vs {:}".format(max(nums), total) | ||||
|     print("{:} start simplify the checkpoint.".format(time_string())) | ||||
|  | ||||
|     datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") | ||||
|  | ||||
|     # Create the directory to save the processed data | ||||
|     # full_save_dir contains all benchmark files with trained weights. | ||||
|     # simplify_save_dir contains all benchmark files without trained weights. | ||||
|     full_save_dir = save_dir / (save_name + "-FULL") | ||||
|     simple_save_dir = save_dir / (save_name + "-SIMPLIFY") | ||||
|     full_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     simple_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     # all data in memory | ||||
|     arch2infos, evaluated_indexes = dict(), set() | ||||
|     end_time, arch_time = time.time(), AverageMeter() | ||||
|  | ||||
|     for index in tqdm(range(total)): | ||||
|         arch_str = nets[index] | ||||
|         hp2info = OrderedDict() | ||||
|  | ||||
|         full_save_path = full_save_dir / "{:06d}.pickle".format(index) | ||||
|         simple_save_path = simple_save_dir / "{:06d}.pickle".format(index) | ||||
|  | ||||
|         for hp in hps: | ||||
|             sub_save_dir = save_dir / "raw-data-{:}".format(hp) | ||||
|             ckps = [ | ||||
|                 sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) | ||||
|                 for seed in seeds | ||||
|             ] | ||||
|             ckps = [x for x in ckps if x.exists()] | ||||
|             if len(ckps) == 0: | ||||
|                 raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp)) | ||||
|  | ||||
|             arch_info = account_one_arch(index, arch_str, ckps, datasets) | ||||
|             hp2info[hp] = arch_info | ||||
|  | ||||
|         hp2info = correct_time_related_info(hp2info) | ||||
|         evaluated_indexes.add(index) | ||||
|  | ||||
|         hp2info["01"].clear_params()  # to save some spaces... | ||||
|         to_save_data = OrderedDict( | ||||
|             { | ||||
|                 "01": hp2info["01"].state_dict(), | ||||
|                 "12": hp2info["12"].state_dict(), | ||||
|                 "90": hp2info["90"].state_dict(), | ||||
|             } | ||||
|         ) | ||||
|         pickle_save(to_save_data, str(full_save_path)) | ||||
|  | ||||
|         for hp in hps: | ||||
|             hp2info[hp].clear_params() | ||||
|         to_save_data = OrderedDict( | ||||
|             { | ||||
|                 "01": hp2info["01"].state_dict(), | ||||
|                 "12": hp2info["12"].state_dict(), | ||||
|                 "90": hp2info["90"].state_dict(), | ||||
|             } | ||||
|         ) | ||||
|         pickle_save(to_save_data, str(simple_save_path)) | ||||
|         arch2infos[index] = to_save_data | ||||
|         # measure elapsed time | ||||
|         arch_time.update(time.time() - end_time) | ||||
|         end_time = time.time() | ||||
|         need_time = "{:}".format( | ||||
|             convert_secs2time(arch_time.avg * (total - index - 1), True) | ||||
|         ) | ||||
|         # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) | ||||
|     print("{:} {:} done.".format(time_string(), save_name)) | ||||
|     final_infos = { | ||||
|         "meta_archs": nets, | ||||
|         "total_archs": total, | ||||
|         "arch2infos": arch2infos, | ||||
|         "evaluated_indexes": evaluated_indexes, | ||||
|     } | ||||
|     save_file_name = save_dir / "{:}.pickle".format(save_name) | ||||
|     pickle_save(final_infos, str(save_file_name)) | ||||
|     # move the benchmark file to a new path | ||||
|     hd5sum = get_md5_file(str(save_file_name) + ".pbz2") | ||||
|     hd5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(NATS_SSS_BASE_NAME, hd5sum) | ||||
|     shutil.move(str(save_file_name) + ".pbz2", hd5_file_name) | ||||
|     print( | ||||
|         "Save {:} / {:} architecture results into {:} -> {:}.".format( | ||||
|             len(evaluated_indexes), total, save_file_name, hd5_file_name | ||||
|         ) | ||||
|     ) | ||||
|     # move the directory to a new path | ||||
|     hd5_full_save_dir = save_dir / "{:}-{:}-full".format(NATS_SSS_BASE_NAME, hd5sum) | ||||
|     hd5_simple_save_dir = save_dir / "{:}-{:}-simple".format(NATS_SSS_BASE_NAME, hd5sum) | ||||
|     shutil.move(full_save_dir, hd5_full_save_dir) | ||||
|     shutil.move(simple_save_dir, hd5_simple_save_dir) | ||||
|     # save the meta information for simple and full | ||||
|     final_infos["arch2infos"] = None | ||||
|     final_infos["evaluated_indexes"] = set() | ||||
|     pickle_save(final_infos, str(hd5_full_save_dir / "meta.pickle")) | ||||
|     pickle_save(final_infos, str(hd5_simple_save_dir / "meta.pickle")) | ||||
|  | ||||
|  | ||||
| def traverse_net(candidates: List[int], N: int): | ||||
|     nets = [""] | ||||
|     for i in range(N): | ||||
|         new_nets = [] | ||||
|         for net in nets: | ||||
|             for C in candidates: | ||||
|                 new_nets.append(str(C) if net == "" else "{:}:{:}".format(net, C)) | ||||
|         nets = new_nets | ||||
|     return nets | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench (size search space)", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--base_save_dir", | ||||
|         type=str, | ||||
|         default="./output/NATS-Bench-size", | ||||
|         help="The base-name of folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--candidateC", | ||||
|         type=int, | ||||
|         nargs="+", | ||||
|         default=[8, 16, 24, 32, 40, 48, 56, 64], | ||||
|         help=".", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_layers", type=int, default=5, help="The number of layers in a network." | ||||
|     ) | ||||
|     parser.add_argument("--check_N", type=int, default=32768, help="For safety.") | ||||
|     parser.add_argument( | ||||
|         "--save_name", type=str, default="process", help="The save directory." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     nets = traverse_net(args.candidateC, args.num_layers) | ||||
|     if len(nets) != args.check_N: | ||||
|         raise ValueError( | ||||
|             "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) | ||||
|         ) | ||||
|  | ||||
|     save_dir = Path(args.base_save_dir) | ||||
|     simplify(save_dir, args.save_name, nets, args.check_N) | ||||
							
								
								
									
										103
									
								
								AutoDL-Projects/exps/NATS-Bench/sss-file-manager.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								AutoDL-Projects/exps/NATS-Bench/sss-file-manager.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,103 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                          # | ||||
| ############################################################################## | ||||
| # Usage: python exps/NATS-Bench/sss-file-manager.py --mode check             # | ||||
| ############################################################################## | ||||
| import os, sys, time, torch, argparse | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.procedures import bench_evaluate_for_seed | ||||
| from xautodl.procedures import get_machine_info | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def obtain_valid_ckp(save_dir: Text, total: int): | ||||
|     possible_seeds = [777, 888, 999] | ||||
|     seed2ckps = defaultdict(list) | ||||
|     miss2ckps = defaultdict(list) | ||||
|     for i in range(total): | ||||
|         for seed in possible_seeds: | ||||
|             path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed)) | ||||
|             if os.path.exists(path): | ||||
|                 seed2ckps[seed].append(i) | ||||
|             else: | ||||
|                 miss2ckps[seed].append(i) | ||||
|     for seed, xlist in seed2ckps.items(): | ||||
|         print( | ||||
|             "[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format( | ||||
|                 save_dir, seed, len(xlist), total, total - len(xlist), total | ||||
|             ) | ||||
|         ) | ||||
|     return dict(seed2ckps), dict(miss2ckps) | ||||
|  | ||||
|  | ||||
| def copy_data(source_dir, target_dir, meta_path): | ||||
|     target_dir = Path(target_dir) | ||||
|     target_dir.mkdir(parents=True, exist_ok=True) | ||||
|     miss2ckps = torch.load(meta_path)["miss2ckps"] | ||||
|     s2t = {} | ||||
|     for seed, xlist in miss2ckps.items(): | ||||
|         for i in xlist: | ||||
|             file_name = "arch-{:06d}-seed-{:04d}.pth".format(i, seed) | ||||
|             source_path = os.path.join(source_dir, file_name) | ||||
|             target_path = os.path.join(target_dir, file_name) | ||||
|             if os.path.exists(source_path): | ||||
|                 s2t[source_path] = target_path | ||||
|     print( | ||||
|         "Map from {:} to {:}, find {:} missed ckps.".format( | ||||
|             source_dir, target_dir, len(s2t) | ||||
|         ) | ||||
|     ) | ||||
|     for s, t in s2t.items(): | ||||
|         copyfile(s, t) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench (size search space) file manager.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--mode", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         choices=["check", "copy"], | ||||
|         help="The script mode.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/NATS-Bench-size", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument("--check_N", type=int, default=32768, help="For safety.") | ||||
|     # use for train the model | ||||
|     args = parser.parse_args() | ||||
|     possible_configs = ["01", "12", "90"] | ||||
|     if args.mode == "check": | ||||
|         for config in possible_configs: | ||||
|             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | ||||
|             seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N) | ||||
|             torch.save( | ||||
|                 dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), | ||||
|                 "{:}/meta-{:}.pth".format(args.save_dir, config), | ||||
|             ) | ||||
|     elif args.mode == "copy": | ||||
|         for config in possible_configs: | ||||
|             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | ||||
|             cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config) | ||||
|             cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config) | ||||
|             if os.path.exists(cur_meta_path): | ||||
|                 copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) | ||||
|             else: | ||||
|                 print("Do not find : {:}".format(cur_meta_path)) | ||||
|     else: | ||||
|         raise ValueError("invalid mode : {:}".format(args.mode)) | ||||
							
								
								
									
										111
									
								
								AutoDL-Projects/exps/NATS-Bench/test-nats-api.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								AutoDL-Projects/exps/NATS-Bench/test-nats-api.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,111 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                          # | ||||
| ############################################################################## | ||||
| # Usage: python exps/NATS-Bench/test-nats-api.py                             # | ||||
| ############################################################################## | ||||
| import os, gc, sys, time, torch, argparse | ||||
| import numpy as np | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
| import matplotlib | ||||
| import seaborn as sns | ||||
|  | ||||
| matplotlib.use("agg") | ||||
| import matplotlib.pyplot as plt | ||||
| import matplotlib.ticker as ticker | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.log_utils import time_string | ||||
| from xautodl.models import get_cell_based_tiny_net, CellStructure | ||||
| from nats_bench import create | ||||
|  | ||||
|  | ||||
| def test_api(api, sss_or_tss=True): | ||||
|     print("{:} start testing the api : {:}".format(time_string(), api)) | ||||
|     api.clear_params(12) | ||||
|     api.reload(index=12) | ||||
|  | ||||
|     # Query the informations of 1113-th architecture | ||||
|     info_strs = api.query_info_str_by_arch(1113) | ||||
|     print(info_strs) | ||||
|     info = api.query_by_index(113) | ||||
|     print("{:}\n".format(info)) | ||||
|     info = api.query_by_index(113, "cifar100") | ||||
|     print("{:}\n".format(info)) | ||||
|  | ||||
|     info = api.query_meta_info_by_index(115, "90" if sss_or_tss else "200") | ||||
|     print("{:}\n".format(info)) | ||||
|  | ||||
|     for dataset in ["cifar10", "cifar100", "ImageNet16-120"]: | ||||
|         for xset in ["train", "test", "valid"]: | ||||
|             best_index, highest_accuracy = api.find_best(dataset, xset) | ||||
|         print("") | ||||
|     params = api.get_net_param(12, "cifar10", None) | ||||
|  | ||||
|     # Obtain the config and create the network | ||||
|     config = api.get_net_config(12, "cifar10") | ||||
|     print("{:}\n".format(config)) | ||||
|     network = get_cell_based_tiny_net(config) | ||||
|     network.load_state_dict(next(iter(params.values()))) | ||||
|  | ||||
|     # Obtain the cost information | ||||
|     info = api.get_cost_info(12, "cifar10") | ||||
|     print("{:}\n".format(info)) | ||||
|     info = api.get_latency(12, "cifar10") | ||||
|     print("{:}\n".format(info)) | ||||
|     for index in [13, 15, 19, 200]: | ||||
|         info = api.get_latency(index, "cifar10") | ||||
|  | ||||
|     # Count the number of architectures | ||||
|     info = api.statistics("cifar100", "12") | ||||
|     print("{:} statistics results : {:}\n".format(time_string(), info)) | ||||
|  | ||||
|     # Show the information of the 123-th architecture | ||||
|     api.show(123) | ||||
|  | ||||
|     # Obtain both cost and performance information | ||||
|     info = api.get_more_info(1234, "cifar10") | ||||
|     print("{:}\n".format(info)) | ||||
|     print("{:} finish testing the api : {:}".format(time_string(), api)) | ||||
|  | ||||
|     if not sss_or_tss: | ||||
|         arch_str = "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|" | ||||
|         matrix = api.str2matrix(arch_str) | ||||
|         print("Compute the adjacency matrix of {:}".format(arch_str)) | ||||
|         print(matrix) | ||||
|     info = api.simulate_train_eval(123, "cifar10") | ||||
|     print("simulate_train_eval : {:}\n\n".format(info)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     # api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True) | ||||
|     for fast_mode in [True, False]: | ||||
|         for verbose in [True, False]: | ||||
|             api_nats_tss = create(None, "tss", fast_mode=fast_mode, verbose=True) | ||||
|             print( | ||||
|                 "{:} create with fast_mode={:} and verbose={:}".format( | ||||
|                     time_string(), fast_mode, verbose | ||||
|                 ) | ||||
|             ) | ||||
|             test_api(api_nats_tss, False) | ||||
|             del api_nats_tss | ||||
|             gc.collect() | ||||
|  | ||||
|     for fast_mode in [True, False]: | ||||
|         for verbose in [True, False]: | ||||
|             print( | ||||
|                 "{:} create with fast_mode={:} and verbose={:}".format( | ||||
|                     time_string(), fast_mode, verbose | ||||
|                 ) | ||||
|             ) | ||||
|             api_nats_sss = create(None, "size", fast_mode=fast_mode, verbose=True) | ||||
|             print("{:} --->>> {:}".format(time_string(), api_nats_sss)) | ||||
|             test_api(api_nats_sss, True) | ||||
|             del api_nats_sss | ||||
|             gc.collect() | ||||
							
								
								
									
										179
									
								
								AutoDL-Projects/exps/NATS-Bench/tss-collect-patcher.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										179
									
								
								AutoDL-Projects/exps/NATS-Bench/tss-collect-patcher.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,179 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                          # | ||||
| ############################################################################## | ||||
| # This file is used to re-orangize all checkpoints (created by main-tss.py)  # | ||||
| # into a single benchmark file. Besides, for each trial, we will merge the   # | ||||
| # information of all its trials into a single file.                          # | ||||
| #                                                                            # | ||||
| # Usage:                                                                     # | ||||
| # python exps/NATS-Bench/tss-collect-patcher.py                              # | ||||
| ############################################################################## | ||||
| import os, re, sys, time, shutil, random, argparse, collections | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| from tqdm import tqdm | ||||
| from pathlib import Path | ||||
| from collections import defaultdict, OrderedDict | ||||
| from typing import Dict, Any, Text, List | ||||
|  | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.config_utils import load_config, dict2config | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.models import CellStructure, get_cell_based_tiny_net, get_search_spaces | ||||
| from xautodl.procedures import ( | ||||
|     bench_pure_evaluate as pure_evaluate, | ||||
|     get_nas_bench_loaders, | ||||
| ) | ||||
| from xautodl.utils import get_md5_file | ||||
|  | ||||
| from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount | ||||
| from nas_201_api import NASBench201API | ||||
|  | ||||
|  | ||||
| NATS_TSS_BASE_NAME = "NATS-tss-v1_0"  # 2020.08.28 | ||||
|  | ||||
|  | ||||
| def simplify(save_dir, save_name, nets, total, sup_config): | ||||
|     hps, seeds = ["12", "200"], set() | ||||
|     for hp in hps: | ||||
|         sub_save_dir = save_dir / "raw-data-{:}".format(hp) | ||||
|         ckps = sorted(list(sub_save_dir.glob("arch-*-seed-*.pth"))) | ||||
|         seed2names = defaultdict(list) | ||||
|         for ckp in ckps: | ||||
|             parts = re.split("-|\.", ckp.name) | ||||
|             seed2names[parts[3]].append(ckp.name) | ||||
|         print("DIR : {:}".format(sub_save_dir)) | ||||
|         nums = [] | ||||
|         for seed, xlist in seed2names.items(): | ||||
|             seeds.add(seed) | ||||
|             nums.append(len(xlist)) | ||||
|             print("  [seed={:}] there are {:} checkpoints.".format(seed, len(xlist))) | ||||
|         assert ( | ||||
|             len(nets) == total == max(nums) | ||||
|         ), "there are some missed files : {:} vs {:}".format(max(nums), total) | ||||
|     print("{:} start simplify the checkpoint.".format(time_string())) | ||||
|  | ||||
|     datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") | ||||
|  | ||||
|     # Create the directory to save the processed data | ||||
|     # full_save_dir contains all benchmark files with trained weights. | ||||
|     # simplify_save_dir contains all benchmark files without trained weights. | ||||
|     full_save_dir = save_dir / (save_name + "-FULL") | ||||
|     simple_save_dir = save_dir / (save_name + "-SIMPLIFY") | ||||
|     full_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     simple_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     # all data in memory | ||||
|     arch2infos, evaluated_indexes = dict(), set() | ||||
|     end_time, arch_time = time.time(), AverageMeter() | ||||
|     # save the meta information | ||||
|     for index in tqdm(range(total)): | ||||
|         arch_str = nets[index] | ||||
|         hp2info = OrderedDict() | ||||
|  | ||||
|         simple_save_path = simple_save_dir / "{:06d}.pickle".format(index) | ||||
|  | ||||
|         arch2infos[index] = pickle_load(simple_save_path) | ||||
|         evaluated_indexes.add(index) | ||||
|  | ||||
|         # measure elapsed time | ||||
|         arch_time.update(time.time() - end_time) | ||||
|         end_time = time.time() | ||||
|         need_time = "{:}".format( | ||||
|             convert_secs2time(arch_time.avg * (total - index - 1), True) | ||||
|         ) | ||||
|         # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) | ||||
|     print("{:} {:} done.".format(time_string(), save_name)) | ||||
|     final_infos = { | ||||
|         "meta_archs": nets, | ||||
|         "total_archs": total, | ||||
|         "arch2infos": arch2infos, | ||||
|         "evaluated_indexes": evaluated_indexes, | ||||
|     } | ||||
|     save_file_name = save_dir / "{:}.pickle".format(save_name) | ||||
|     pickle_save(final_infos, str(save_file_name)) | ||||
|     # move the benchmark file to a new path | ||||
|     hd5sum = get_md5_file(str(save_file_name) + ".pbz2") | ||||
|     hd5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(NATS_TSS_BASE_NAME, hd5sum) | ||||
|     shutil.move(str(save_file_name) + ".pbz2", hd5_file_name) | ||||
|     print( | ||||
|         "Save {:} / {:} architecture results into {:} -> {:}.".format( | ||||
|             len(evaluated_indexes), total, save_file_name, hd5_file_name | ||||
|         ) | ||||
|     ) | ||||
|     # move the directory to a new path | ||||
|     hd5_full_save_dir = save_dir / "{:}-{:}-full".format(NATS_TSS_BASE_NAME, hd5sum) | ||||
|     hd5_simple_save_dir = save_dir / "{:}-{:}-simple".format(NATS_TSS_BASE_NAME, hd5sum) | ||||
|     shutil.move(full_save_dir, hd5_full_save_dir) | ||||
|     shutil.move(simple_save_dir, hd5_simple_save_dir) | ||||
|  | ||||
|  | ||||
| def traverse_net(max_node): | ||||
|     aa_nas_bench_ss = get_search_spaces("cell", "nats-bench") | ||||
|     archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) | ||||
|     print( | ||||
|         "There are {:} archs vs {:}.".format( | ||||
|             len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2) | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     random.seed(88)  # please do not change this line for reproducibility | ||||
|     random.shuffle(archs) | ||||
|     assert ( | ||||
|         archs[0].tostr() | ||||
|         == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" | ||||
|     ), "please check the 0-th architecture : {:}".format(archs[0]) | ||||
|     assert ( | ||||
|         archs[9].tostr() | ||||
|         == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" | ||||
|     ), "please check the 9-th architecture : {:}".format(archs[9]) | ||||
|     assert ( | ||||
|         archs[123].tostr() | ||||
|         == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" | ||||
|     ), "please check the 123-th architecture : {:}".format(archs[123]) | ||||
|     return [x.tostr() for x in archs] | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench (topology search space)", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--base_save_dir", | ||||
|         type=str, | ||||
|         default="./output/NATS-Bench-topology", | ||||
|         help="The base-name of folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--max_node", type=int, default=4, help="The maximum node in a cell." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--channel", type=int, default=16, help="The number of channels." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, default=5, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument("--check_N", type=int, default=15625, help="For safety.") | ||||
|     parser.add_argument( | ||||
|         "--save_name", type=str, default="process", help="The save directory." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     nets = traverse_net(args.max_node) | ||||
|     if len(nets) != args.check_N: | ||||
|         raise ValueError( | ||||
|             "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) | ||||
|         ) | ||||
|  | ||||
|     save_dir = Path(args.base_save_dir) | ||||
|     simplify( | ||||
|         save_dir, | ||||
|         args.save_name, | ||||
|         nets, | ||||
|         args.check_N, | ||||
|         {"name": "infer.tiny", "channel": args.channel, "num_cells": args.num_cells}, | ||||
|     ) | ||||
							
								
								
									
										461
									
								
								AutoDL-Projects/exps/NATS-Bench/tss-collect.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										461
									
								
								AutoDL-Projects/exps/NATS-Bench/tss-collect.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,461 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                          # | ||||
| ############################################################################## | ||||
| # This file is used to re-orangize all checkpoints (created by main-tss.py)  # | ||||
| # into a single benchmark file. Besides, for each trial, we will merge the   # | ||||
| # information of all its trials into a single file.                          # | ||||
| #                                                                            # | ||||
| # Usage:                                                                     # | ||||
| # python exps/NATS-Bench/tss-collect.py                                      # | ||||
| ############################################################################## | ||||
| import os, re, sys, time, shutil, random, argparse, collections | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| import torch | ||||
| from tqdm import tqdm | ||||
| from pathlib import Path | ||||
| from collections import defaultdict, OrderedDict | ||||
| from typing import Dict, Any, Text, List | ||||
|  | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.config_utils import load_config, dict2config | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.models import CellStructure, get_cell_based_tiny_net, get_search_spaces | ||||
| from xautodl.procedures import ( | ||||
|     bench_pure_evaluate as pure_evaluate, | ||||
|     get_nas_bench_loaders, | ||||
| ) | ||||
| from xautodl.utils import get_md5_file | ||||
| from nats_bench import pickle_save, pickle_load, ArchResults, ResultsCount | ||||
| from nas_201_api import NASBench201API | ||||
|  | ||||
|  | ||||
| api = NASBench201API( | ||||
|     "{:}/.torch/NAS-Bench-201-v1_0-e61699.pth".format(os.environ["HOME"]) | ||||
| ) | ||||
|  | ||||
| NATS_TSS_BASE_NAME = "NATS-tss-v1_0"  # 2020.08.28 | ||||
|  | ||||
|  | ||||
| def create_result_count( | ||||
|     used_seed: int, | ||||
|     dataset: Text, | ||||
|     arch_config: Dict[Text, Any], | ||||
|     results: Dict[Text, Any], | ||||
|     dataloader_dict: Dict[Text, Any], | ||||
| ) -> ResultsCount: | ||||
|     xresult = ResultsCount( | ||||
|         dataset, | ||||
|         results["net_state_dict"], | ||||
|         results["train_acc1es"], | ||||
|         results["train_losses"], | ||||
|         results["param"], | ||||
|         results["flop"], | ||||
|         arch_config, | ||||
|         used_seed, | ||||
|         results["total_epoch"], | ||||
|         None, | ||||
|     ) | ||||
|     net_config = dict2config( | ||||
|         { | ||||
|             "name": "infer.tiny", | ||||
|             "C": arch_config["channel"], | ||||
|             "N": arch_config["num_cells"], | ||||
|             "genotype": CellStructure.str2structure(arch_config["arch_str"]), | ||||
|             "num_classes": arch_config["class_num"], | ||||
|         }, | ||||
|         None, | ||||
|     ) | ||||
|     if "train_times" in results:  # new version | ||||
|         xresult.update_train_info( | ||||
|             results["train_acc1es"], | ||||
|             results["train_acc5es"], | ||||
|             results["train_losses"], | ||||
|             results["train_times"], | ||||
|         ) | ||||
|         xresult.update_eval( | ||||
|             results["valid_acc1es"], results["valid_losses"], results["valid_times"] | ||||
|         ) | ||||
|     else: | ||||
|         network = get_cell_based_tiny_net(net_config) | ||||
|         network.load_state_dict(xresult.get_net_param()) | ||||
|         if dataset == "cifar10-valid": | ||||
|             xresult.update_OLD_eval( | ||||
|                 "x-valid", results["valid_acc1es"], results["valid_losses"] | ||||
|             ) | ||||
|             loss, top1, top5, latencies = pure_evaluate( | ||||
|                 dataloader_dict["{:}@{:}".format("cifar10", "test")], network.cuda() | ||||
|             ) | ||||
|             xresult.update_OLD_eval( | ||||
|                 "ori-test", | ||||
|                 {results["total_epoch"] - 1: top1}, | ||||
|                 {results["total_epoch"] - 1: loss}, | ||||
|             ) | ||||
|             xresult.update_latency(latencies) | ||||
|         elif dataset == "cifar10": | ||||
|             xresult.update_OLD_eval( | ||||
|                 "ori-test", results["valid_acc1es"], results["valid_losses"] | ||||
|             ) | ||||
|             loss, top1, top5, latencies = pure_evaluate( | ||||
|                 dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() | ||||
|             ) | ||||
|             xresult.update_latency(latencies) | ||||
|         elif dataset == "cifar100" or dataset == "ImageNet16-120": | ||||
|             xresult.update_OLD_eval( | ||||
|                 "ori-test", results["valid_acc1es"], results["valid_losses"] | ||||
|             ) | ||||
|             loss, top1, top5, latencies = pure_evaluate( | ||||
|                 dataloader_dict["{:}@{:}".format(dataset, "valid")], network.cuda() | ||||
|             ) | ||||
|             xresult.update_OLD_eval( | ||||
|                 "x-valid", | ||||
|                 {results["total_epoch"] - 1: top1}, | ||||
|                 {results["total_epoch"] - 1: loss}, | ||||
|             ) | ||||
|             loss, top1, top5, latencies = pure_evaluate( | ||||
|                 dataloader_dict["{:}@{:}".format(dataset, "test")], network.cuda() | ||||
|             ) | ||||
|             xresult.update_OLD_eval( | ||||
|                 "x-test", | ||||
|                 {results["total_epoch"] - 1: top1}, | ||||
|                 {results["total_epoch"] - 1: loss}, | ||||
|             ) | ||||
|             xresult.update_latency(latencies) | ||||
|         else: | ||||
|             raise ValueError("invalid dataset name : {:}".format(dataset)) | ||||
|     return xresult | ||||
|  | ||||
|  | ||||
| def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dict): | ||||
|     information = ArchResults(arch_index, arch_str) | ||||
|  | ||||
|     for checkpoint_path in checkpoints: | ||||
|         checkpoint = torch.load(checkpoint_path, map_location="cpu") | ||||
|         used_seed = checkpoint_path.name.split("-")[-1].split(".")[0] | ||||
|         ok_dataset = 0 | ||||
|         for dataset in datasets: | ||||
|             if dataset not in checkpoint: | ||||
|                 print( | ||||
|                     "Can not find {:} in arch-{:} from {:}".format( | ||||
|                         dataset, arch_index, checkpoint_path | ||||
|                     ) | ||||
|                 ) | ||||
|                 continue | ||||
|             else: | ||||
|                 ok_dataset += 1 | ||||
|             results = checkpoint[dataset] | ||||
|             assert results[ | ||||
|                 "finish-train" | ||||
|             ], "This {:} arch seed={:} does not finish train on {:} ::: {:}".format( | ||||
|                 arch_index, used_seed, dataset, checkpoint_path | ||||
|             ) | ||||
|             arch_config = { | ||||
|                 "channel": results["channel"], | ||||
|                 "num_cells": results["num_cells"], | ||||
|                 "arch_str": arch_str, | ||||
|                 "class_num": results["config"]["class_num"], | ||||
|             } | ||||
|  | ||||
|             xresult = create_result_count( | ||||
|                 used_seed, dataset, arch_config, results, dataloader_dict | ||||
|             ) | ||||
|             information.update(dataset, int(used_seed), xresult) | ||||
|         if ok_dataset == 0: | ||||
|             raise ValueError("{:} does not find any data".format(checkpoint_path)) | ||||
|     return information | ||||
|  | ||||
|  | ||||
| def correct_time_related_info(arch_index: int, arch_infos: Dict[Text, ArchResults]): | ||||
|     # calibrate the latency based on NAS-Bench-201-v1_0-e61699.pth | ||||
|     cifar010_latency = ( | ||||
|         api.get_latency(arch_index, "cifar10-valid", hp="200") | ||||
|         + api.get_latency(arch_index, "cifar10", hp="200") | ||||
|     ) / 2 | ||||
|     cifar100_latency = api.get_latency(arch_index, "cifar100", hp="200") | ||||
|     image_latency = api.get_latency(arch_index, "ImageNet16-120", hp="200") | ||||
|     for hp, arch_info in arch_infos.items(): | ||||
|         arch_info.reset_latency("cifar10-valid", None, cifar010_latency) | ||||
|         arch_info.reset_latency("cifar10", None, cifar010_latency) | ||||
|         arch_info.reset_latency("cifar100", None, cifar100_latency) | ||||
|         arch_info.reset_latency("ImageNet16-120", None, image_latency) | ||||
|  | ||||
|     train_per_epoch_time = list( | ||||
|         arch_infos["12"].query("cifar10-valid", 777).train_times.values() | ||||
|     ) | ||||
|     train_per_epoch_time = sum(train_per_epoch_time) / len(train_per_epoch_time) | ||||
|     eval_ori_test_time, eval_x_valid_time = [], [] | ||||
|     for key, value in arch_infos["12"].query("cifar10-valid", 777).eval_times.items(): | ||||
|         if key.startswith("ori-test@"): | ||||
|             eval_ori_test_time.append(value) | ||||
|         elif key.startswith("x-valid@"): | ||||
|             eval_x_valid_time.append(value) | ||||
|         else: | ||||
|             raise ValueError("-- {:} --".format(key)) | ||||
|     eval_ori_test_time, eval_x_valid_time = float(np.mean(eval_ori_test_time)), float( | ||||
|         np.mean(eval_x_valid_time) | ||||
|     ) | ||||
|     nums = { | ||||
|         "ImageNet16-120-train": 151700, | ||||
|         "ImageNet16-120-valid": 3000, | ||||
|         "ImageNet16-120-test": 6000, | ||||
|         "cifar10-valid-train": 25000, | ||||
|         "cifar10-valid-valid": 25000, | ||||
|         "cifar10-train": 50000, | ||||
|         "cifar10-test": 10000, | ||||
|         "cifar100-train": 50000, | ||||
|         "cifar100-test": 10000, | ||||
|         "cifar100-valid": 5000, | ||||
|     } | ||||
|     eval_per_sample = (eval_ori_test_time + eval_x_valid_time) / ( | ||||
|         nums["cifar10-valid-valid"] + nums["cifar10-test"] | ||||
|     ) | ||||
|     for hp, arch_info in arch_infos.items(): | ||||
|         arch_info.reset_pseudo_train_times( | ||||
|             "cifar10-valid", | ||||
|             None, | ||||
|             train_per_epoch_time | ||||
|             / nums["cifar10-valid-train"] | ||||
|             * nums["cifar10-valid-train"], | ||||
|         ) | ||||
|         arch_info.reset_pseudo_train_times( | ||||
|             "cifar10", | ||||
|             None, | ||||
|             train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar10-train"], | ||||
|         ) | ||||
|         arch_info.reset_pseudo_train_times( | ||||
|             "cifar100", | ||||
|             None, | ||||
|             train_per_epoch_time / nums["cifar10-valid-train"] * nums["cifar100-train"], | ||||
|         ) | ||||
|         arch_info.reset_pseudo_train_times( | ||||
|             "ImageNet16-120", | ||||
|             None, | ||||
|             train_per_epoch_time | ||||
|             / nums["cifar10-valid-train"] | ||||
|             * nums["ImageNet16-120-train"], | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar10-valid", | ||||
|             None, | ||||
|             "x-valid", | ||||
|             eval_per_sample * nums["cifar10-valid-valid"], | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar10-valid", None, "ori-test", eval_per_sample * nums["cifar10-test"] | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar10", None, "ori-test", eval_per_sample * nums["cifar10-test"] | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar100", None, "x-valid", eval_per_sample * nums["cifar100-valid"] | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar100", None, "x-test", eval_per_sample * nums["cifar100-valid"] | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "cifar100", None, "ori-test", eval_per_sample * nums["cifar100-test"] | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "ImageNet16-120", | ||||
|             None, | ||||
|             "x-valid", | ||||
|             eval_per_sample * nums["ImageNet16-120-valid"], | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "ImageNet16-120", | ||||
|             None, | ||||
|             "x-test", | ||||
|             eval_per_sample * nums["ImageNet16-120-valid"], | ||||
|         ) | ||||
|         arch_info.reset_pseudo_eval_times( | ||||
|             "ImageNet16-120", | ||||
|             None, | ||||
|             "ori-test", | ||||
|             eval_per_sample * nums["ImageNet16-120-test"], | ||||
|         ) | ||||
|     return arch_infos | ||||
|  | ||||
|  | ||||
| def simplify(save_dir, save_name, nets, total, sup_config): | ||||
|     dataloader_dict = get_nas_bench_loaders(6) | ||||
|     hps, seeds = ["12", "200"], set() | ||||
|     for hp in hps: | ||||
|         sub_save_dir = save_dir / "raw-data-{:}".format(hp) | ||||
|         ckps = sorted(list(sub_save_dir.glob("arch-*-seed-*.pth"))) | ||||
|         seed2names = defaultdict(list) | ||||
|         for ckp in ckps: | ||||
|             parts = re.split("-|\.", ckp.name) | ||||
|             seed2names[parts[3]].append(ckp.name) | ||||
|         print("DIR : {:}".format(sub_save_dir)) | ||||
|         nums = [] | ||||
|         for seed, xlist in seed2names.items(): | ||||
|             seeds.add(seed) | ||||
|             nums.append(len(xlist)) | ||||
|             print("  [seed={:}] there are {:} checkpoints.".format(seed, len(xlist))) | ||||
|         assert ( | ||||
|             len(nets) == total == max(nums) | ||||
|         ), "there are some missed files : {:} vs {:}".format(max(nums), total) | ||||
|     print("{:} start simplify the checkpoint.".format(time_string())) | ||||
|  | ||||
|     datasets = ("cifar10-valid", "cifar10", "cifar100", "ImageNet16-120") | ||||
|  | ||||
|     # Create the directory to save the processed data | ||||
|     # full_save_dir contains all benchmark files with trained weights. | ||||
|     # simplify_save_dir contains all benchmark files without trained weights. | ||||
|     full_save_dir = save_dir / (save_name + "-FULL") | ||||
|     simple_save_dir = save_dir / (save_name + "-SIMPLIFY") | ||||
|     full_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     simple_save_dir.mkdir(parents=True, exist_ok=True) | ||||
|     # all data in memory | ||||
|     arch2infos, evaluated_indexes = dict(), set() | ||||
|     end_time, arch_time = time.time(), AverageMeter() | ||||
|     # save the meta information | ||||
|     temp_final_infos = { | ||||
|         "meta_archs": nets, | ||||
|         "total_archs": total, | ||||
|         "arch2infos": None, | ||||
|         "evaluated_indexes": set(), | ||||
|     } | ||||
|     pickle_save(temp_final_infos, str(full_save_dir / "meta.pickle")) | ||||
|     pickle_save(temp_final_infos, str(simple_save_dir / "meta.pickle")) | ||||
|  | ||||
|     for index in tqdm(range(total)): | ||||
|         arch_str = nets[index] | ||||
|         hp2info = OrderedDict() | ||||
|  | ||||
|         full_save_path = full_save_dir / "{:06d}.pickle".format(index) | ||||
|         simple_save_path = simple_save_dir / "{:06d}.pickle".format(index) | ||||
|         for hp in hps: | ||||
|             sub_save_dir = save_dir / "raw-data-{:}".format(hp) | ||||
|             ckps = [ | ||||
|                 sub_save_dir / "arch-{:06d}-seed-{:}.pth".format(index, seed) | ||||
|                 for seed in seeds | ||||
|             ] | ||||
|             ckps = [x for x in ckps if x.exists()] | ||||
|             if len(ckps) == 0: | ||||
|                 raise ValueError("Invalid data : index={:}, hp={:}".format(index, hp)) | ||||
|  | ||||
|             arch_info = account_one_arch( | ||||
|                 index, arch_str, ckps, datasets, dataloader_dict | ||||
|             ) | ||||
|             hp2info[hp] = arch_info | ||||
|  | ||||
|         hp2info = correct_time_related_info(index, hp2info) | ||||
|         evaluated_indexes.add(index) | ||||
|  | ||||
|         to_save_data = OrderedDict( | ||||
|             {"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()} | ||||
|         ) | ||||
|         pickle_save(to_save_data, str(full_save_path)) | ||||
|  | ||||
|         for hp in hps: | ||||
|             hp2info[hp].clear_params() | ||||
|         to_save_data = OrderedDict( | ||||
|             {"12": hp2info["12"].state_dict(), "200": hp2info["200"].state_dict()} | ||||
|         ) | ||||
|         pickle_save(to_save_data, str(simple_save_path)) | ||||
|         arch2infos[index] = to_save_data | ||||
|         # measure elapsed time | ||||
|         arch_time.update(time.time() - end_time) | ||||
|         end_time = time.time() | ||||
|         need_time = "{:}".format( | ||||
|             convert_secs2time(arch_time.avg * (total - index - 1), True) | ||||
|         ) | ||||
|         # print('{:} {:06d}/{:06d} : still need {:}'.format(time_string(), index, total, need_time)) | ||||
|     print("{:} {:} done.".format(time_string(), save_name)) | ||||
|     final_infos = { | ||||
|         "meta_archs": nets, | ||||
|         "total_archs": total, | ||||
|         "arch2infos": arch2infos, | ||||
|         "evaluated_indexes": evaluated_indexes, | ||||
|     } | ||||
|     save_file_name = save_dir / "{:}.pickle".format(save_name) | ||||
|     pickle_save(final_infos, str(save_file_name)) | ||||
|     # move the benchmark file to a new path | ||||
|     hd5sum = get_md5_file(str(save_file_name) + ".pbz2") | ||||
|     hd5_file_name = save_dir / "{:}-{:}.pickle.pbz2".format(NATS_TSS_BASE_NAME, hd5sum) | ||||
|     shutil.move(str(save_file_name) + ".pbz2", hd5_file_name) | ||||
|     print( | ||||
|         "Save {:} / {:} architecture results into {:} -> {:}.".format( | ||||
|             len(evaluated_indexes), total, save_file_name, hd5_file_name | ||||
|         ) | ||||
|     ) | ||||
|     # move the directory to a new path | ||||
|     hd5_full_save_dir = save_dir / "{:}-{:}-full".format(NATS_TSS_BASE_NAME, hd5sum) | ||||
|     hd5_simple_save_dir = save_dir / "{:}-{:}-simple".format(NATS_TSS_BASE_NAME, hd5sum) | ||||
|     shutil.move(full_save_dir, hd5_full_save_dir) | ||||
|     shutil.move(simple_save_dir, hd5_simple_save_dir) | ||||
|     # save the meta information for simple and full | ||||
|     # final_infos['arch2infos'] = None | ||||
|     # final_infos['evaluated_indexes'] = set() | ||||
|  | ||||
|  | ||||
| def traverse_net(max_node): | ||||
|     aa_nas_bench_ss = get_search_spaces("cell", "nats-bench") | ||||
|     archs = CellStructure.gen_all(aa_nas_bench_ss, max_node, False) | ||||
|     print( | ||||
|         "There are {:} archs vs {:}.".format( | ||||
|             len(archs), len(aa_nas_bench_ss) ** ((max_node - 1) * max_node / 2) | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     random.seed(88)  # please do not change this line for reproducibility | ||||
|     random.shuffle(archs) | ||||
|     assert ( | ||||
|         archs[0].tostr() | ||||
|         == "|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|" | ||||
|     ), "please check the 0-th architecture : {:}".format(archs[0]) | ||||
|     assert ( | ||||
|         archs[9].tostr() | ||||
|         == "|avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|" | ||||
|     ), "please check the 9-th architecture : {:}".format(archs[9]) | ||||
|     assert ( | ||||
|         archs[123].tostr() | ||||
|         == "|avg_pool_3x3~0|+|avg_pool_3x3~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|" | ||||
|     ), "please check the 123-th architecture : {:}".format(archs[123]) | ||||
|     return [x.tostr() for x in archs] | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench (topology search space)", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--base_save_dir", | ||||
|         type=str, | ||||
|         default="./output/NATS-Bench-topology", | ||||
|         help="The base-name of folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--max_node", type=int, default=4, help="The maximum node in a cell." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--channel", type=int, default=16, help="The number of channels." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_cells", type=int, default=5, help="The number of cells in one stage." | ||||
|     ) | ||||
|     parser.add_argument("--check_N", type=int, default=15625, help="For safety.") | ||||
|     parser.add_argument( | ||||
|         "--save_name", type=str, default="process", help="The save directory." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     nets = traverse_net(args.max_node) | ||||
|     if len(nets) != args.check_N: | ||||
|         raise ValueError( | ||||
|             "Pre-num-check failed : {:} vs {:}".format(len(nets), args.check_N) | ||||
|         ) | ||||
|  | ||||
|     save_dir = Path(args.base_save_dir) | ||||
|     simplify( | ||||
|         save_dir, | ||||
|         args.save_name, | ||||
|         nets, | ||||
|         args.check_N, | ||||
|         {"name": "infer.tiny", "channel": args.channel, "num_cells": args.num_cells}, | ||||
|     ) | ||||
							
								
								
									
										105
									
								
								AutoDL-Projects/exps/NATS-Bench/tss-file-manager.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								AutoDL-Projects/exps/NATS-Bench/tss-file-manager.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | ||||
| ############################################################################## | ||||
| # NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size # | ||||
| ############################################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08                          # | ||||
| ############################################################################## | ||||
| # Usage: python exps/NATS-Bench/tss-file-manager.py --mode check             # | ||||
| ############################################################################## | ||||
| import os, sys, time, torch, argparse | ||||
| from typing import List, Text, Dict, Any | ||||
| from shutil import copyfile | ||||
| from collections import defaultdict | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| from xautodl.config_utils import dict2config, load_config | ||||
| from xautodl.procedures import bench_evaluate_for_seed | ||||
| from xautodl.procedures import get_machine_info | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def obtain_valid_ckp(save_dir: Text, total: int, possible_seeds: List[int]): | ||||
|     seed2ckps = defaultdict(list) | ||||
|     miss2ckps = defaultdict(list) | ||||
|     for i in range(total): | ||||
|         for seed in possible_seeds: | ||||
|             path = os.path.join(save_dir, "arch-{:06d}-seed-{:04d}.pth".format(i, seed)) | ||||
|             if os.path.exists(path): | ||||
|                 seed2ckps[seed].append(i) | ||||
|             else: | ||||
|                 miss2ckps[seed].append(i) | ||||
|     for seed, xlist in seed2ckps.items(): | ||||
|         print( | ||||
|             "[{:}] [seed={:}] has {:5d}/{:5d} | miss {:5d}/{:5d}".format( | ||||
|                 save_dir, seed, len(xlist), total, total - len(xlist), total | ||||
|             ) | ||||
|         ) | ||||
|     return dict(seed2ckps), dict(miss2ckps) | ||||
|  | ||||
|  | ||||
| def copy_data(source_dir, target_dir, meta_path): | ||||
|     target_dir = Path(target_dir) | ||||
|     target_dir.mkdir(parents=True, exist_ok=True) | ||||
|     miss2ckps = torch.load(meta_path)["miss2ckps"] | ||||
|     s2t = {} | ||||
|     for seed, xlist in miss2ckps.items(): | ||||
|         for i in xlist: | ||||
|             file_name = "arch-{:06d}-seed-{:04d}.pth".format(i, seed) | ||||
|             source_path = os.path.join(source_dir, file_name) | ||||
|             target_path = os.path.join(target_dir, file_name) | ||||
|             if os.path.exists(source_path): | ||||
|                 s2t[source_path] = target_path | ||||
|     print( | ||||
|         "Map from {:} to {:}, find {:} missed ckps.".format( | ||||
|             source_dir, target_dir, len(s2t) | ||||
|         ) | ||||
|     ) | ||||
|     for s, t in s2t.items(): | ||||
|         copyfile(s, t) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="NATS-Bench (topology search space) file manager.", | ||||
|         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--mode", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         choices=["check", "copy"], | ||||
|         help="The script mode.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="output/NATS-Bench-topology", | ||||
|         help="Folder to save checkpoints and log.", | ||||
|     ) | ||||
|     parser.add_argument("--check_N", type=int, default=15625, help="For safety.") | ||||
|     # use for train the model | ||||
|     args = parser.parse_args() | ||||
|     possible_configs = ["12", "200"] | ||||
|     possible_seedss = [[111, 777], [777, 888, 999]] | ||||
|     if args.mode == "check": | ||||
|         for config, possible_seeds in zip(possible_configs, possible_seedss): | ||||
|             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | ||||
|             seed2ckps, miss2ckps = obtain_valid_ckp( | ||||
|                 cur_save_dir, args.check_N, possible_seeds | ||||
|             ) | ||||
|             torch.save( | ||||
|                 dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), | ||||
|                 "{:}/meta-{:}.pth".format(args.save_dir, config), | ||||
|             ) | ||||
|     elif args.mode == "copy": | ||||
|         for config in possible_configs: | ||||
|             cur_save_dir = "{:}/raw-data-{:}".format(args.save_dir, config) | ||||
|             cur_copy_dir = "{:}/copy-{:}".format(args.save_dir, config) | ||||
|             cur_meta_path = "{:}/meta-{:}.pth".format(args.save_dir, config) | ||||
|             if os.path.exists(cur_meta_path): | ||||
|                 copy_data(cur_save_dir, cur_copy_dir, cur_meta_path) | ||||
|             else: | ||||
|                 print("Do not find : {:}".format(cur_meta_path)) | ||||
|     else: | ||||
|         raise ValueError("invalid mode : {:}".format(args.mode)) | ||||
		Reference in New Issue
	
	Block a user