From 469a20794510c1f2e72b8946ed965451245aba4a Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Fri, 28 Aug 2020 10:21:33 +0000 Subject: [PATCH] Update NATS-Bench (sss version 1.1) --- exps/NATS-Bench/sss-collect.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/exps/NATS-Bench/sss-collect.py b/exps/NATS-Bench/sss-collect.py index 37afe15..a17f8b1 100644 --- a/exps/NATS-Bench/sss-collect.py +++ b/exps/NATS-Bench/sss-collect.py @@ -10,7 +10,7 @@ # Usage: # # python exps/NATS-Bench/sss-collect.py # ############################################################################## -import os, re, sys, time, argparse, collections +import os, re, sys, time, shutil, argparse, collections import numpy as np import torch from tqdm import tqdm @@ -26,6 +26,8 @@ from nas_201_api import ArchResults, ResultsCount from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders from utils import get_md5_file +NATS_TSS_BASE_NAME = 'NATS-tss-v1_0' # 2020.08.28 + def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]) -> ArchResults: information = ArchResults(arch_index, arch_str) @@ -231,8 +233,16 @@ def simplify(save_dir, save_name, nets, total): 'evaluated_indexes': evaluated_indexes} save_file_name = save_dir / '{:}.npy'.format(save_name) np.save(str(save_file_name), final_infos) - import pdb; pdb.set_trace() - print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), total, save_file_name)) + # move the benchmark file to a new path + hd5sum = get_md5_file(save_file_name) + hd5_file_name = save_dir / '{:}-{:}.npy'.format(NATS_TSS_BASE_NAME, hd5sum) + shutil.move(save_file_name, hd5_file_name) + print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name)) + # move the directory to a new path + hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum) + hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum) + shutil.move(full_save_dir, hd5_full_save_dir) + shutil.move(simple_save_dir, hd5_simple_save_dir) def traverse_net(candidates: List[int], N: int):