Update NAS-Bench-201

This commit is contained in:
D-X-Y 2020-04-10 10:02:13 +00:00
parent 076f9c2d41
commit b464d9dc95

View File

@ -7,7 +7,7 @@
###############################################################
import os, sys, time, torch, argparse
from typing import List, Text, Dict, Any
from tqdm import tqdm
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
@ -32,17 +32,28 @@ def obtain_valid_ckp(save_dir: Text, total: int):
seed2ckps[seed].append(i)
else:
miss2ckps[seed].append(i)
"""
ckps = [x for x in save_dir.glob('arch-{:06d}-seed-*.pth'.format(i))]
for ckp in ckps:
seed = ckp.name.split('-seed-')[-1].split('.pth')[0]
seed2ckps[int(seed)].append(i)
"""
for seed, xlist in seed2ckps.items():
print('[{:}] [seed={:}] has {:}/{:}'.format(save_dir, seed, len(xlist), total))
return dict(seed2ckps), dict(miss2ckps)
def copy_data(source_dir, target_dir, meta_path):
target_dir = Path(target_dir)
target_dir.mkdir(parents=True, exist_ok=True)
miss2ckps = torch.load(meta_path)['miss2ckps']
s2t = {}
for seed, xlist in miss2ckps.items():
for i in xlist:
file_name = 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed)
source_path = os.path.join(source_dir, file_name)
target_path = os.path.join(target_dir, file_name)
if os.path.exists(source_path):
s2t[source_path] = target_path
print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, target_dir, len(s2t)))
for s, t in s2t.items():
copyfile(s, t)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--mode', type=str, required=True, choices=['check', 'copy'], help='The script mode.')
@ -56,4 +67,14 @@ if __name__ == '__main__':
cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N)
torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config))
elif args.mode == 'copy':
for config in possible_configs:
cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
cur_copy_dir = '{:}/copy-{:}'.format(args.save_dir, config)
cur_meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config)
if os.path.exists(cur_meta_path):
copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
else:
print('Do not find : {:}'.format(cur_meta_path))
else:
raise ValueError('invalid mode : {:}'.format(args.mode))