Update NAS-Bench-201
This commit is contained in:
parent
076f9c2d41
commit
b464d9dc95
@ -7,7 +7,7 @@
|
||||
###############################################################
|
||||
import os, sys, time, torch, argparse
|
||||
from typing import List, Text, Dict, Any
|
||||
from tqdm import tqdm
|
||||
from shutil import copyfile
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
@ -32,17 +32,28 @@ def obtain_valid_ckp(save_dir: Text, total: int):
|
||||
seed2ckps[seed].append(i)
|
||||
else:
|
||||
miss2ckps[seed].append(i)
|
||||
"""
|
||||
ckps = [x for x in save_dir.glob('arch-{:06d}-seed-*.pth'.format(i))]
|
||||
for ckp in ckps:
|
||||
seed = ckp.name.split('-seed-')[-1].split('.pth')[0]
|
||||
seed2ckps[int(seed)].append(i)
|
||||
"""
|
||||
for seed, xlist in seed2ckps.items():
|
||||
print('[{:}] [seed={:}] has {:}/{:}'.format(save_dir, seed, len(xlist), total))
|
||||
return dict(seed2ckps), dict(miss2ckps)
|
||||
|
||||
|
||||
def copy_data(source_dir, target_dir, meta_path):
|
||||
target_dir = Path(target_dir)
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
miss2ckps = torch.load(meta_path)['miss2ckps']
|
||||
s2t = {}
|
||||
for seed, xlist in miss2ckps.items():
|
||||
for i in xlist:
|
||||
file_name = 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed)
|
||||
source_path = os.path.join(source_dir, file_name)
|
||||
target_path = os.path.join(target_dir, file_name)
|
||||
if os.path.exists(source_path):
|
||||
s2t[source_path] = target_path
|
||||
print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, target_dir, len(s2t)))
|
||||
for s, t in s2t.items():
|
||||
copyfile(s, t)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode', type=str, required=True, choices=['check', 'copy'], help='The script mode.')
|
||||
@ -56,4 +67,14 @@ if __name__ == '__main__':
|
||||
cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
|
||||
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N)
|
||||
torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config))
|
||||
|
||||
elif args.mode == 'copy':
|
||||
for config in possible_configs:
|
||||
cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
|
||||
cur_copy_dir = '{:}/copy-{:}'.format(args.save_dir, config)
|
||||
cur_meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config)
|
||||
if os.path.exists(cur_meta_path):
|
||||
copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
|
||||
else:
|
||||
print('Do not find : {:}'.format(cur_meta_path))
|
||||
else:
|
||||
raise ValueError('invalid mode : {:}'.format(args.mode))
|
||||
|
Loading…
Reference in New Issue
Block a user