| 
									
										
										
										
											2020-02-23 10:30:37 +11:00
										 |  |  | ##################################################### | 
					
						
							|  |  |  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 # | 
					
						
							|  |  |  | ##################################################### | 
					
						
							| 
									
										
										
										
											2020-03-10 19:08:56 +11:00
# Usage example: python exps/NAS-Bench-201/check.py --base_str C16-N5-LESS
					
						
							| 
									
										
										
										
											2020-02-23 10:30:37 +11:00
										 |  |  | ##################################################### | 
					
						
							| 
									
										
										
										
											2020-03-09 19:38:00 +11:00
										 |  |  | import sys, time, argparse, collections | 
					
						
							| 
									
										
										
										
											2019-12-26 23:29:36 +11:00
										 |  |  | import torch | 
					
						
							|  |  |  | from pathlib import Path | 
					
						
							|  |  |  | from collections import defaultdict | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | 
					
						
							|  |  |  | if str(lib_dir) not in sys.path: | 
					
						
							|  |  |  |     sys.path.insert(0, str(lib_dir)) | 
					
						
							|  |  |  | from log_utils import AverageMeter, time_string, convert_secs2time | 
					
						
							| 
									
										
										
										
											2019-12-26 23:29:36 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def check_files(save_dir, meta_file, basestr): | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     meta_infos = torch.load(meta_file, map_location="cpu") | 
					
						
							|  |  |  |     meta_archs = meta_infos["archs"] | 
					
						
							|  |  |  |     meta_num_archs = meta_infos["total"] | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     assert meta_num_archs == len( | 
					
						
							|  |  |  |         meta_archs | 
					
						
							|  |  |  |     ), "invalid number of archs : {:} vs {:}".format(meta_num_archs, len(meta_archs)) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     sub_model_dirs = sorted(list(save_dir.glob("*-*-{:}".format(basestr)))) | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |     print( | 
					
						
							|  |  |  |         "{:} find {:} directories used to save checkpoints".format( | 
					
						
							|  |  |  |             time_string(), len(sub_model_dirs) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2019-12-26 23:29:36 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0 | 
					
						
							|  |  |  |     num_seeds = defaultdict(lambda: 0) | 
					
						
							|  |  |  |     for index, sub_dir in enumerate(sub_model_dirs): | 
					
						
							|  |  |  |         xcheckpoints = list(sub_dir.glob("arch-*-seed-*.pth")) | 
					
						
							|  |  |  |         # xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth')) | 
					
						
							|  |  |  |         arch_indexes = set() | 
					
						
							|  |  |  |         for checkpoint in xcheckpoints: | 
					
						
							|  |  |  |             temp_names = checkpoint.name.split("-") | 
					
						
							|  |  |  |             assert ( | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |                 len(temp_names) == 4 | 
					
						
							|  |  |  |                 and temp_names[0] == "arch" | 
					
						
							|  |  |  |                 and temp_names[2] == "seed" | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |             ), "invalid checkpoint name : {:}".format(checkpoint.name) | 
					
						
							|  |  |  |             arch_indexes.add(temp_names[1]) | 
					
						
							|  |  |  |         subdir2archs[sub_dir] = sorted(list(arch_indexes)) | 
					
						
							|  |  |  |         num_evaluated_arch += len(arch_indexes) | 
					
						
							|  |  |  |         # count number of seeds for each architecture | 
					
						
							|  |  |  |         for arch_index in arch_indexes: | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |             num_seeds[ | 
					
						
							|  |  |  |                 len(list(sub_dir.glob("arch-{:}-seed-*.pth".format(arch_index)))) | 
					
						
							|  |  |  |             ] += 1 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     print( | 
					
						
							|  |  |  |         "There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).".format( | 
					
						
							|  |  |  |             num_evaluated_arch, meta_num_archs, sum(k * v for k, v in num_seeds.items()) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     for key in sorted(list(num_seeds.keys())): | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         print( | 
					
						
							|  |  |  |             "There are {:5d} architectures that are evaluated {:} times.".format( | 
					
						
							|  |  |  |                 num_seeds[key], key | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2019-12-26 23:29:36 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     dir2ckps, dir2ckp_exists = dict(), dict() | 
					
						
							|  |  |  |     start_time, epoch_time = time.time(), AverageMeter() | 
					
						
							|  |  |  |     for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()): | 
					
						
							|  |  |  |         if basestr == "C16-N5": | 
					
						
							|  |  |  |             seeds = [777, 888, 999] | 
					
						
							|  |  |  |         elif basestr == "C16-N5-LESS": | 
					
						
							|  |  |  |             seeds = [111, 777] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise ValueError("Invalid base str : {:}".format(basestr)) | 
					
						
							|  |  |  |         numrs = defaultdict(lambda: 0) | 
					
						
							|  |  |  |         all_checkpoints, all_ckp_exists = [], [] | 
					
						
							|  |  |  |         for arch_index in arch_indexes: | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |             checkpoints = [ | 
					
						
							|  |  |  |                 "arch-{:}-seed-{:04d}.pth".format(arch_index, seed) for seed in seeds | 
					
						
							|  |  |  |             ] | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |             ckp_exists = [(sub_dir / x).exists() for x in checkpoints] | 
					
						
							|  |  |  |             arch_index = int(arch_index) | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |             assert ( | 
					
						
							|  |  |  |                 0 <= arch_index < len(meta_archs) | 
					
						
							|  |  |  |             ), "invalid arch-index {:} (not found in meta_archs)".format(arch_index) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |             all_checkpoints += checkpoints | 
					
						
							|  |  |  |             all_ckp_exists += ckp_exists | 
					
						
							|  |  |  |             numrs[sum(ckp_exists)] += 1 | 
					
						
							|  |  |  |         dir2ckps[str(sub_dir)] = all_checkpoints | 
					
						
							|  |  |  |         dir2ckp_exists[str(sub_dir)] = all_ckp_exists | 
					
						
							|  |  |  |         # measure time | 
					
						
							|  |  |  |         epoch_time.update(time.time() - start_time) | 
					
						
							|  |  |  |         start_time = time.time() | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         numrstr = ", ".join( | 
					
						
							|  |  |  |             ["{:}: {:03d}".format(x, numrs[x]) for x in sorted(numrs.keys())] | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |         print( | 
					
						
							|  |  |  |             "{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}".format( | 
					
						
							|  |  |  |                 time_string(), | 
					
						
							|  |  |  |                 IDX + 1, | 
					
						
							|  |  |  |                 len(subdir2archs), | 
					
						
							|  |  |  |                 len(arch_indexes), | 
					
						
							|  |  |  |                 len(all_checkpoints), | 
					
						
							|  |  |  |                 sum(all_ckp_exists), | 
					
						
							|  |  |  |                 sub_dir, | 
					
						
							|  |  |  |                 convert_secs2time(epoch_time.avg * (len(subdir2archs) - IDX - 1), True), | 
					
						
							|  |  |  |                 numrstr, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2019-12-26 23:29:36 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2019-12-26 23:29:36 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     parser = argparse.ArgumentParser( | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         description="NAS Benchmark 201", | 
					
						
							|  |  |  |         formatter_class=argparse.ArgumentDefaultsHelpFormatter, | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--base_save_dir", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="./output/NAS-BENCH-201-4", | 
					
						
							|  |  |  |         help="The base-name of folder to save checkpoints and log.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							| 
									
										
										
										
											2021-03-18 16:02:55 +08:00
										 |  |  |         "--meta_path", | 
					
						
							|  |  |  |         type=str, | 
					
						
							|  |  |  |         default="./output/NAS-BENCH-201-4/meta-node-4.pth", | 
					
						
							|  |  |  |         help="The meta file path.", | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |         "--base_str", type=str, default="C16-N5", help="The basic string." | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  |     args = parser.parse_args() | 
					
						
							| 
									
										
										
										
											2020-03-10 19:08:56 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     save_dir = Path(args.base_save_dir) | 
					
						
							|  |  |  |     meta_path = Path(args.meta_path) | 
					
						
							|  |  |  |     assert save_dir.exists(), "invalid save dir path : {:}".format(save_dir) | 
					
						
							|  |  |  |     assert meta_path.exists(), "invalid saved meta path : {:}".format(meta_path) | 
					
						
							|  |  |  |     print("check NAS-Bench-201 in {:}".format(save_dir)) | 
					
						
							| 
									
										
										
										
											2019-12-26 23:29:36 +11:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-17 09:25:58 +00:00
										 |  |  |     check_files(save_dir, meta_path, args.base_str) |