simplify baselines

This commit is contained in:
D-X-Y
2019-12-31 22:02:11 +11:00
parent f8f44bfb31
commit 9ec25663f1
12 changed files with 338 additions and 124 deletions

View File

@@ -370,17 +370,17 @@ def write_video(save_dir):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visual', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
parser.add_argument('--save_dir', type=str, default='./output/search-cell-nas-bench-102/visuals', help='The base-name of folder to save checkpoints and log.')
parser.add_argument('--api_path', type=str, default=None, help='The path to the NAS-Bench-102 benchmark file.')
args = parser.parse_args()
vis_save_dir = Path(args.save_dir) / 'visuals'
vis_save_dir = Path(args.save_dir)
vis_save_dir.mkdir(parents=True, exist_ok=True)
meta_file = Path(args.api_path)
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
write_video(vis_save_dir / 'over-time')
visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
visualize_info(str(meta_file), 'cifar100', vis_save_dir)
visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
#visualize_rank_over_time(str(meta_file), vis_save_dir / 'over-time')
#write_video(vis_save_dir / 'over-time')
#visualize_info(str(meta_file), 'cifar10' , vis_save_dir)
#visualize_info(str(meta_file), 'cifar100', vis_save_dir)
#visualize_info(str(meta_file), 'ImageNet16-120', vis_save_dir)
visualize_relative_ranking(vis_save_dir)