# python exps/prepare.py --name cifar10     --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth     --ratio 0.5
# python exps/prepare.py --name cifar100    --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth    --ratio 0.5
# python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012   --save ./data/imagenet-1k.split.pth --ratio 0.5
import sys, time, torch, random, argparse
from collections import defaultdict
import os.path as osp
from PIL     import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy    import deepcopy
from pathlib import Path
import torchvision
import torchvision.datasets as dset

# Prepend the sibling ../lib directory (relative to this script) to sys.path so
# project modules placed there can be imported.
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))

# Command-line interface. Every option is consumed unconditionally by main(),
# so each one is required: a missing option now yields an argparse usage error
# instead of an obscure crash later on.
parser = argparse.ArgumentParser(description='Prepare splits for searching', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--name' , type=str,   required=True, help='The dataset name (cifar10, cifar100 or imagenet-1k).')
parser.add_argument('--root' , type=str,   required=True, help='The directory to the dataset.')
parser.add_argument('--save' , type=str,   required=True, help='The save path.')
parser.add_argument('--ratio', type=float, required=True, help='Fraction of each class assigned to the train split; the rest goes to the valid split.')
args = parser.parse_args()

def _load_dataset(name, root):
  """Return the training set of the dataset called `name`, located under `root`.

  Raises:
    TypeError: if `name` is not one of 'cifar10', 'cifar100', 'imagenet-1k'.
  """
  if name == 'cifar10':
    return dset.CIFAR10(root, train=True)
  if name == 'cifar100':
    return dset.CIFAR100(root, train=True)
  if name == 'imagenet-1k':
    return dset.ImageFolder(osp.join(root, 'train'))
  raise TypeError("Unknown dataset : {:}".format(name))


def _get_targets(dataset):
  """Return the per-sample class labels of `dataset`, indexed like the dataset.

  Different torchvision versions expose labels as `targets`, `train_labels`,
  or only via `imgs` (a list of (path, label) pairs), so each is tried in turn.

  Raises:
    ValueError: if none of those attributes is present.
  """
  if hasattr(dataset, 'targets'):
    return dataset.targets
  if hasattr(dataset, 'train_labels'):
    return dataset.train_labels
  if hasattr(dataset, 'imgs'):
    return [x[1] for x in dataset.imgs]
  raise ValueError('invalid pattern')


def _split_per_class(targets, ratio, seed=111):
  """Split sample indices into (train, valid) lists, stratified by class.

  For each class (visited in sorted label order), int(count * ratio) of its
  indices are sampled into train and the remainder go to valid. A private
  RNG seeded with `seed` reproduces exactly what the former global
  `random.seed(111)` + `random.sample` sequence produced, without mutating
  the global `random` state. Both returned lists are sorted ascending.
  """
  class2index = defaultdict(list)
  for index, cls in enumerate(targets):
    class2index[cls].append(index)
  rng = random.Random(seed)
  train, valid = [], []
  for cls in sorted(class2index):
    xlist = class2index[cls]
    xtrain = rng.sample(xlist, int(len(xlist) * ratio))
    chosen = set(xtrain)
    train += xtrain
    valid += [index for index in xlist if index not in chosen]
  train.sort()
  valid.sort()
  return train, valid


def _count_per_class(indices, targets):
  """Return a dict mapping class label -> number of `indices` with that label."""
  counts = defaultdict(int)
  for index in indices:
    counts[targets[index]] += 1
  return dict(counts)


def main():
  """Build a per-class train/valid index split of a dataset and save it.

  Reads the module-level `args` (--name, --root, --save, --ratio). The dict
  written with `torch.save` to `args.save` has keys:
    'train', 'valid'                   -- sorted lists of sample indices;
    'class2numTrain', 'class2numValid' -- class label -> count in each split.

  Raises:
    FileExistsError: if `args.save` already exists (never overwrite a split).
    ValueError: if --ratio is missing/NaN or outside [0, 1], or the dataset
      exposes no labels.
    TypeError: if --name is not a supported dataset.
  """
  save_path = Path(args.save)
  # Validate cheap preconditions first: loading imagenet-1k scans the whole
  # training directory, so failing after that would waste a lot of time.
  if save_path.exists():
    raise FileExistsError('{:} already exists'.format(save_path))
  ratio = args.ratio
  # `not (0 <= ratio <= 1)` also rejects NaN; ratio > 1 or < 0 would otherwise
  # make random.sample raise deep inside the per-class loop.
  if ratio is None or not 0 <= ratio <= 1:
    raise ValueError('--ratio must be within [0, 1], got {:}'.format(ratio))
  save_path.parent.mkdir(parents=True, exist_ok=True)
  print ('torchvision version : {:}'.format(torchvision.__version__))

  dataset = _load_dataset(args.name, args.root)
  targets = _get_targets(dataset)
  print ('There are {:} samples in this dataset.'.format( len(targets) ))

  train, valid = _split_per_class(targets, ratio)
  torch.save({'train': train,
              'valid': valid,
              'class2numTrain': _count_per_class(train, targets),
              'class2numValid': _count_per_class(valid, targets)}, save_path)
  print ('-'*80)
  print ('-'*80)

if __name__ == '__main__':
  main()