Update NATS-Bench (sss version 1.1)

This commit is contained in:
D-X-Y 2020-08-28 10:21:33 +00:00
parent 2c86d6aa67
commit 469a207945

View File

@ -10,7 +10,7 @@
# Usage: # # Usage: #
# python exps/NATS-Bench/sss-collect.py # # python exps/NATS-Bench/sss-collect.py #
############################################################################## ##############################################################################
import os, re, sys, time, argparse, collections import os, re, sys, time, shutil, argparse, collections
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
@ -26,6 +26,8 @@ from nas_201_api import ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
from utils import get_md5_file from utils import get_md5_file
NATS_TSS_BASE_NAME = 'NATS-tss-v1_0' # 2020.08.28
def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]) -> ArchResults: def account_one_arch(arch_index: int, arch_str: Text, checkpoints: List[Text], datasets: List[Text]) -> ArchResults:
information = ArchResults(arch_index, arch_str) information = ArchResults(arch_index, arch_str)
@ -231,8 +233,16 @@ def simplify(save_dir, save_name, nets, total):
'evaluated_indexes': evaluated_indexes} 'evaluated_indexes': evaluated_indexes}
save_file_name = save_dir / '{:}.npy'.format(save_name) save_file_name = save_dir / '{:}.npy'.format(save_name)
np.save(str(save_file_name), final_infos) np.save(str(save_file_name), final_infos)
import pdb; pdb.set_trace() # move the benchmark file to a new path
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), total, save_file_name)) hd5sum = get_md5_file(save_file_name)
hd5_file_name = save_dir / '{:}-{:}.npy'.format(NATS_TSS_BASE_NAME, hd5sum)
shutil.move(save_file_name, hd5_file_name)
print('Save {:} / {:} architecture results into {:} -> {:}.'.format(len(evaluated_indexes), total, save_file_name, hd5_file_name))
# move the directory to a new path
hd5_full_save_dir = save_dir / '{:}-{:}-full'.format(NATS_TSS_BASE_NAME, hd5sum)
hd5_simple_save_dir = save_dir / '{:}-{:}-simple'.format(NATS_TSS_BASE_NAME, hd5sum)
shutil.move(full_save_dir, hd5_full_save_dir)
shutil.move(simple_save_dir, hd5_simple_save_dir)
def traverse_net(candidates: List[int], N: int): def traverse_net(candidates: List[int], N: int):