From b464d9dc9537aa224086ad8445378c4abb31b0d7 Mon Sep 17 00:00:00 2001
From: D-X-Y <280835372@qq.com>
Date: Fri, 10 Apr 2020 10:02:13 +0000
Subject: [PATCH] Update NAS-Bench-201

---
 exps/NAS-Bench-201/xshape-file.py | 37 ++++++++++++++++++++++++-------
 1 file changed, 29 insertions(+), 8 deletions(-)

diff --git a/exps/NAS-Bench-201/xshape-file.py b/exps/NAS-Bench-201/xshape-file.py
index 0eee246..16754fa 100644
--- a/exps/NAS-Bench-201/xshape-file.py
+++ b/exps/NAS-Bench-201/xshape-file.py
@@ -7,7 +7,7 @@
 ###############################################################
 import os, sys, time, torch, argparse
 from typing import List, Text, Dict, Any
-from tqdm import tqdm
+from shutil import copyfile
 from collections import defaultdict
 from copy    import deepcopy
 from pathlib import Path
@@ -32,17 +32,28 @@ def obtain_valid_ckp(save_dir: Text, total: int):
         seed2ckps[seed].append(i)
       else:
         miss2ckps[seed].append(i)
-    """
-    ckps = [x for x in save_dir.glob('arch-{:06d}-seed-*.pth'.format(i))]
-    for ckp in ckps:
-      seed = ckp.name.split('-seed-')[-1].split('.pth')[0]
-      seed2ckps[int(seed)].append(i)
-    """
   for seed, xlist in seed2ckps.items():
     print('[{:}] [seed={:}] has {:}/{:}'.format(save_dir, seed, len(xlist), total))
   return dict(seed2ckps), dict(miss2ckps)
     
 
+def copy_data(source_dir, target_dir, meta_path):
+  target_dir = Path(target_dir)
+  target_dir.mkdir(parents=True, exist_ok=True)
+  miss2ckps = torch.load(meta_path)['miss2ckps']
+  s2t = {}
+  for seed, xlist in miss2ckps.items():
+    for i in xlist:
+      file_name = 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed)
+      source_path = os.path.join(source_dir, file_name)
+      target_path = os.path.join(target_dir, file_name)
+      if os.path.exists(source_path):
+        s2t[source_path] = target_path
+  print('Map from {:} to {:}, find {:} missed ckps.'.format(source_dir, target_dir, len(s2t)))
+  for s, t in s2t.items():
+    copyfile(s, t)
+
+
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   parser.add_argument('--mode',        type=str, required=True, choices=['check', 'copy'], help='The script mode.')
@@ -56,4 +67,14 @@ if __name__ == '__main__':
       cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
       seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N)
       torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config))
-  
+  elif args.mode == 'copy':
+    for config in possible_configs:
+      cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
+      cur_copy_dir = '{:}/copy-{:}'.format(args.save_dir, config)
+      cur_meta_path = '{:}/meta-{:}.pth'.format(args.save_dir, config)
+      if os.path.exists(cur_meta_path):
+        copy_data(cur_save_dir, cur_copy_dir, cur_meta_path)
+      else:
+        print('Do not find : {:}'.format(cur_meta_path))
+  else:
+    raise ValueError('invalid mode : {:}'.format(args.mode))