add oxford and aircraft
Some checks failed: the Test Spaces, Test Xmisc, and Test Super Model builds (push) were cancelled for every combination of OS (macos-latest, ubuntu-18.04, ubuntu-20.04) and Python version (3.6, 3.7, 3.8, 3.9).
This commit is contained in:
parent 889bd1974c · commit 4612cd198b
@@ -20,7 +20,92 @@ from functions import evaluate_for_seed
 from torchvision import datasets, transforms
 
+# NASBENCH201_CONFIG_PATH = os.path.join( os.getcwd(), 'main_exp', 'transfer_nag')
+NASBENCH201_CONFIG_PATH = '/lustre/hpe/ws11/ws11.1/ws/xmuhanma-nbdit/autodl-projects/configs/nas-benchmark'
+
+
+def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
+                          arch_config, workers, logger):
+    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
+    all_infos = {'info': machine_info}
+    all_dataset_keys = []
+    # look all the datasets
+    for dataset, xpath, split in zip(datasets, xpaths, splits):
+        # train valid data
+        task = None
+        train_data, valid_data, xshape, class_num = get_datasets(
+            dataset, xpath, -1, task)
+        # load the configuration
+        if dataset in ['mnist', 'svhn', 'aircraft', 'oxford']:
+            if use_less:
+                # config_path = os.path.join(
+                #     NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
+                config_path = os.path.join(
+                    NASBENCH201_CONFIG_PATH, 'LESS.config')
+            else:
+                # config_path = os.path.join(
+                #     NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset))
+                config_path = os.path.join(
+                    NASBENCH201_CONFIG_PATH, '{}.config'.format(dataset))
+
+            p = os.path.join(
+                NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset))
+            if not os.path.exists(p):
+                import json
+                label_list = list(range(len(train_data)))
+                random.shuffle(label_list)
+                strlist = [str(label_list[i]) for i in range(len(label_list))]
+                splited = {'train': ["int", strlist[:len(train_data) // 2]],
+                           'valid': ["int", strlist[len(train_data) // 2:]]}
+                with open(p, 'w') as f:
+                    f.write(json.dumps(splited))
+            split_info = load_config(os.path.join(
+                NASBENCH201_CONFIG_PATH, '{:}-split.txt'.format(dataset)), None, None)
+        else:
+            raise ValueError('invalid dataset : {:}'.format(dataset))
+
+        config = load_config(
+            config_path, {'class_num': class_num, 'xshape': xshape}, logger)
+        # data loader
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
+                                                   shuffle=True, num_workers=workers, pin_memory=True)
+        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
+                                                   shuffle=False, num_workers=workers, pin_memory=True)
+        splits = load_config(os.path.join(
+            NASBENCH201_CONFIG_PATH, '{}-test-split.txt'.format(dataset)), None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
+                                                             sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                                 splits.xvalid),
+                                                             num_workers=workers, pin_memory=True),
+                      'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
+                                                            sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                                splits.xtest),
+                                                            num_workers=workers, pin_memory=True)
+                      }
+        dataset_key = '{:}'.format(dataset)
+        if bool(split):
+            dataset_key = dataset_key + '-valid'
+        logger.log(
+            'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
+            format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
+        logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
+            dataset_key, config))
+        for key, value in ValLoaders.items():
+            logger.log(
+                'Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
+
+        results = evaluate_for_seed(
+            arch_config, config, arch, train_loader, ValLoaders, seed, logger)
+        all_infos[dataset_key] = results
+        all_dataset_keys.append(dataset_key)
+    all_infos['all_dataset_keys'] = all_dataset_keys
+    return all_infos
+
+
-def evaluate_all_datasets(
+def evaluate_all_datasets1(
     arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
 ):
     machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
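Note: the `{:}-split.txt` file written above is plain JSON whose index lists are stored as strings under an `["int", [...]]` pair, which tells `load_config` to cast them back to integers. A minimal standalone sketch of the same round trip (the file name and size here are placeholders, not the cluster paths above):

```python
import json
import random

# Emulate the 50/50 train/valid split that evaluate_all_datasets writes once
# per dataset and then reuses on every later run.
num_images = 10  # placeholder for len(train_data)
labels = list(range(num_images))
random.shuffle(labels)
split = {'train': ["int", [str(i) for i in labels[:num_images // 2]]],
         'valid': ["int", [str(i) for i in labels[num_images // 2:]]]}

with open('toy-split.txt', 'w') as f:
    f.write(json.dumps(split))

# Reading it back and casting, as the config loader does.
with open('toy-split.txt') as f:
    raw = json.load(f)
train_idx = [int(s) for s in raw['train'][1]]
valid_idx = [int(s) for s in raw['valid'][1]]
assert sorted(train_idx + valid_idx) == list(range(num_images))
```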
@@ -55,7 +140,14 @@ def evaluate_all_datasets(
             split_info = load_config(
                 "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
             )
+        elif dataset.startswith("oxford"):
+            if use_less:
+                config_path = "configs/nas-benchmark/LESS.config"
+            else:
+                config_path = "configs/nas-benchmark/oxford.config"
+            split_info = load_config(
+                "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
+            )
         else:
             raise ValueError("invalid dataset : {:}".format(dataset))
         config = load_config(
@@ -126,6 +218,31 @@ def evaluate_all_datasets(
             sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
             num_workers=workers,
             pin_memory=True)
+        elif dataset == "oxford":
+            ValLoaders = {
+                "ori-test": torch.utils.data.DataLoader(
+                    valid_data,
+                    batch_size=config.batch_size,
+                    shuffle=False,
+                    num_workers=workers,
+                    pin_memory=True
+                )
+            }
+            # train_data_v2 = deepcopy(train_data)
+            # train_data_v2.transform = valid_data.transform
+            train_loader = torch.utils.data.DataLoader(
+                train_data,
+                batch_size=config.batch_size,
+                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
+                num_workers=workers,
+                pin_memory=True)
+            valid_loader = torch.utils.data.DataLoader(
+                valid_data,
+                batch_size=config.batch_size,
+                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
+                num_workers=workers,
+                pin_memory=True)
+
         else:
             # data loader
             train_loader = torch.utils.data.DataLoader(
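Note: the oxford branch keeps a fixed "ori-test" loader but draws train/valid batches from disjoint index sets of the same underlying data via SubsetRandomSampler. A minimal sketch of that pattern (TensorDataset stands in for the real datasets, and the index lists stand in for split_info.train / split_info.valid):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

data = TensorDataset(torch.arange(10).float().unsqueeze(1))
train_idx, valid_idx = list(range(0, 5)), list(range(5, 10))

train_loader = DataLoader(data, batch_size=2, sampler=SubsetRandomSampler(train_idx))
valid_loader = DataLoader(data, batch_size=2, sampler=SubsetRandomSampler(valid_idx))

# Each loader only ever yields samples from its own index set.
seen = {x.item() for (batch,) in train_loader for x in batch}
assert seen <= set(train_idx)
```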
@@ -142,7 +259,7 @@ def evaluate_all_datasets(
             num_workers=workers,
             pin_memory=True,
         )
-        if dataset == "cifar10" or dataset == "aircraft":
+        if dataset == "cifar10" or dataset == "aircraft" or dataset == "oxford":
             ValLoaders = {"ori-test": valid_loader}
         elif dataset == "cifar100":
             cifar100_splits = load_config(
@@ -46,7 +46,7 @@ OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
 --mode ${mode} --save_dir ${save_dir} --max_node 4 \
 --use_less ${use_less} \
 --datasets aircraft \
---xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/ \
+--xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \
 --channel 16 \
 --splits 1 \
 --num_cells 5 \
@@ -54,4 +54,15 @@ OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
 --srange ${xstart} ${xend} --arch_index ${arch_index} \
 --seeds ${all_seeds}
+
+# OMP_NUM_THREADS=4 python ./exps/NAS-Bench-201/main.py \
+# --mode ${mode} --save_dir ${save_dir} --max_node 4 \
+# --use_less ${use_less} \
+# --datasets oxford \
+# --xpaths /lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/ \
+# --channel 16 \
+# --splits 1 \
+# --num_cells 5 \
+# --workers 4 \
+# --srange ${xstart} ${xend} --arch_index ${arch_index} \
+# --seeds ${all_seeds}
@@ -1,42 +1,39 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
+# Modified by Hayeon Lee, Eunyoung Hyung 2021. 03.
 ##################################################
-import os, sys, torch
+import os
+import sys
+import torch
 import os.path as osp
 import numpy as np
 import torchvision.datasets as dset
 import torchvision.transforms as transforms
 from copy import deepcopy
-from PIL import Image
-
-from xautodl.config_utils import load_config
-
-from .DownsampledImageNet import ImageNet16
 from .SearchDatasetWrap import SearchDataset
+
+# from PIL import Image
+import random
+import pdb
+from .aircraft import FGVCAircraft
+from .pets import PetDataset
+from config_utils import load_config
 
-Dataset2Class = {
-    "cifar10": 10,
-    "cifar100": 100,
-    "imagenet-1k-s": 1000,
-    "imagenet-1k": 1000,
-    "ImageNet16": 1000,
-    "ImageNet16-150": 150,
-    "ImageNet16-120": 120,
-    "ImageNet16-200": 200,
-    "aircraft": 100,
-    "oxford": 102
-}
+Dataset2Class = {'cifar10': 10,
+                 'cifar100': 100,
+                 'mnist': 10,
+                 'svhn': 10,
+                 'aircraft': 30,
+                 'oxford': 37}
 
 
 class CUTOUT(object):
 
     def __init__(self, length):
         self.length = length
 
     def __repr__(self):
-        return "{name}(length={length})".format(
-            name=self.__class__.__name__, **self.__dict__
-        )
+        return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
 
     def __call__(self, img):
         h, w = img.size(1), img.size(2)
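Note: the new Dataset2Class values encode the label granularity used further down in get_datasets: 30 matches the manufacturer-level labels of FGVC-Aircraft (the benchmark also defines 70 families and 100 variants), and 37 is the number of Oxford-IIIT Pet breeds. A hedged sanity check, assuming `train_data` was built by the aircraft/oxford branches below and yields (image, label) pairs:

```python
Dataset2Class = {'cifar10': 10, 'cifar100': 100, 'mnist': 10,
                 'svhn': 10, 'aircraft': 30, 'oxford': 37}

def check_class_num(name, train_data):
    # Collect the distinct labels actually emitted by the dataset and
    # compare against the table; slow, but a one-off consistency check.
    labels = {label for _, label in train_data}
    assert len(labels) == Dataset2Class[name], (name, len(labels))
```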
@@ -49,7 +46,7 @@ class CUTOUT(object):
         x1 = np.clip(x - self.length // 2, 0, w)
         x2 = np.clip(x + self.length // 2, 0, w)
 
-        mask[y1:y2, x1:x2] = 0.0
+        mask[y1: y2, x1: x2] = 0.
         mask = torch.from_numpy(mask)
         mask = mask.expand_as(img)
         img *= mask
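Note: CUTOUT zeroes one random `length`×`length` patch in a CHW tensor; the np.clip calls shrink the patch at image borders instead of wrapping it. A self-contained sketch of the masking step:

```python
import numpy as np
import torch

img = torch.ones(3, 8, 8)   # stand-in for a normalized CHW image
length, y, x = 4, 2, 6      # patch size and a sampled patch center
y1, y2 = np.clip(y - length // 2, 0, 8), np.clip(y + length // 2, 0, 8)
x1, x2 = np.clip(x - length // 2, 0, 8), np.clip(x + length // 2, 0, 8)

mask = np.ones((8, 8), np.float32)
mask[y1:y2, x1:x2] = 0.     # the patch is clipped at the right border
img *= torch.from_numpy(mask).expand_as(img)
# rows 0-3, columns 4-7 of every channel are now zero
```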
@@ -57,21 +54,19 @@ class CUTOUT(object):
 
 
 imagenet_pca = {
-    "eigval": np.asarray([0.2175, 0.0188, 0.0045]),
-    "eigvec": np.asarray(
-        [
+    'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
+    'eigvec': np.asarray([
         [-0.5675, 0.7192, 0.4009],
         [-0.5808, -0.0045, -0.8140],
         [-0.5836, -0.6948, 0.4203],
-        ]
-    ),
+    ])
 }
 
 
 class Lighting(object):
-    def __init__(
-        self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"]
-    ):
+    def __init__(self, alphastd,
+                 eigval=imagenet_pca['eigval'],
+                 eigvec=imagenet_pca['eigvec']):
         self.alphastd = alphastd
         assert eigval.shape == (3,)
         assert eigvec.shape == (3, 3)
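Note: Lighting is the AlexNet-style PCA color augmentation: each image gets a per-channel shift along the ImageNet RGB eigenvectors, scaled by the eigenvalues and a fresh normal draw of strength alphastd. A sketch of the core computation on a float array, mirroring the __call__ body shown in the next hunk:

```python
import numpy as np

eigval = np.asarray([0.2175, 0.0188, 0.0045])
eigvec = np.asarray([[-0.5675, 0.7192, 0.4009],
                     [-0.5808, -0.0045, -0.8140],
                     [-0.5836, -0.6948, 0.4203]])

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 4, 3)).astype(np.float32)

alphastd = 0.1
alpha = rng.standard_normal(3).astype('float32') * alphastd
inc = np.dot(eigvec, (alpha * eigval).reshape(3, 1)).reshape(3)
jittered = np.clip(img + inc, 0, 255).astype(np.uint8)
```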
@@ -79,10 +74,10 @@ class Lighting(object):
         self.eigvec = eigvec
 
     def __call__(self, img):
-        if self.alphastd == 0.0:
+        if self.alphastd == 0.:
             return img
         rnd = np.random.randn(3) * self.alphastd
-        rnd = rnd.astype("float32")
+        rnd = rnd.astype('float32')
         v = rnd
         old_dtype = np.asarray(img).dtype
         v = v * self.eigval
@@ -91,292 +86,222 @@ class Lighting(object):
         img = np.add(img, inc)
         if old_dtype == np.uint8:
             img = np.clip(img, 0, 255)
-        img = Image.fromarray(img.astype(old_dtype), "RGB")
+        img = Image.fromarray(img.astype(old_dtype), 'RGB')
         return img
 
     def __repr__(self):
-        return self.__class__.__name__ + "()"
+        return self.__class__.__name__ + '()'
 
 
-def get_datasets(name, root, cutout):
+def get_datasets(name, root, cutout, use_num_cls=None):
-    if name == "cifar10":
+    if name == 'cifar10':
         mean = [x / 255 for x in [125.3, 123.0, 113.9]]
         std = [x / 255 for x in [63.0, 62.1, 66.7]]
-    elif name == "cifar100":
+    elif name == 'cifar100':
         mean = [x / 255 for x in [129.3, 124.1, 112.4]]
         std = [x / 255 for x in [68.2, 65.4, 70.4]]
-    elif name.startswith("imagenet-1k"):
-        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
-    elif name.startswith("ImageNet16"):
-        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
-        std = [x / 255 for x in [63.22, 61.26, 65.09]]
-    elif name == 'aircraft':
-        mean = [0.4785, 0.5100, 0.5338]
-        std = [0.1845, 0.1830, 0.2060]
-    elif name == 'oxford':
-        mean = [0.4811, 0.4492, 0.3957]
-        std = [0.2260, 0.2231, 0.2249]
+    elif name.startswith('mnist'):
+        mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]
+    elif name.startswith('svhn'):
+        mean, std = [0.4376821, 0.4437697, 0.47280442], [0.19803012, 0.20101562, 0.19703614]
+    elif name.startswith('aircraft'):
+        mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883]
+        std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115]
+    elif name.startswith('oxford'):
+        mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783]
+        std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917]
     else:
         raise TypeError("Unknow dataset : {:}".format(name))
 
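Note: the aircraft/oxford numbers above are dataset-wide per-channel statistics of the training images. If they ever need to be regenerated for a new split, a sketch of one common convention (averaging per-image channel means and stds over ToTensor'd images in [0, 1]):

```python
import torch
from torch.utils.data import DataLoader

def channel_stats(dataset, batch_size=256):
    """Average per-image channel mean/std over a dataset of CHW tensors."""
    loader = DataLoader(dataset, batch_size=batch_size)
    n, mean_sum, std_sum = 0, torch.zeros(3), torch.zeros(3)
    for imgs, _ in loader:
        flat = imgs.view(imgs.size(0), 3, -1)
        mean_sum += flat.mean(dim=2).sum(dim=0)
        std_sum += flat.std(dim=2).sum(dim=0)
        n += imgs.size(0)
    return (mean_sum / n).tolist(), (std_sum / n).tolist()
```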
     # Data Argumentation
-    if name == "cifar10" or name == "cifar100":
-        lists = [
-            transforms.RandomHorizontalFlip(),
-            transforms.RandomCrop(32, padding=4),
-            transforms.ToTensor(),
-            transforms.Normalize(mean, std),
-        ]
+    if name == 'cifar10' or name == 'cifar100':
+        lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
+                 transforms.Normalize(mean, std)]
         if cutout > 0:
             lists += [CUTOUT(cutout)]
         train_transform = transforms.Compose(lists)
         test_transform = transforms.Compose(
-            [transforms.ToTensor(), transforms.Normalize(mean, std)]
-        )
+            [transforms.ToTensor(), transforms.Normalize(mean, std)])
         xshape = (1, 3, 32, 32)
-    elif name.startswith("aircraft") or name.startswith("oxford"):
-        lists = [transforms.RandomCrop(16, padding=0), transforms.ToTensor(), transforms.Normalize(mean, std)]
-        if cutout > 0:
-            lists += [CUTOUT(cutout)]
-        train_transform = transforms.Compose(lists)
-        test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
-        xshape = (1, 3, 16, 16)
-    elif name.startswith("ImageNet16"):
-        lists = [
-            transforms.RandomHorizontalFlip(),
-            transforms.RandomCrop(16, padding=2),
-            transforms.ToTensor(),
-            transforms.Normalize(mean, std),
-        ]
-        if cutout > 0:
-            lists += [CUTOUT(cutout)]
-        train_transform = transforms.Compose(lists)
-        test_transform = transforms.Compose(
-            [transforms.ToTensor(), transforms.Normalize(mean, std)]
-        )
-        xshape = (1, 3, 16, 16)
-    elif name == "tiered":
-        lists = [
-            transforms.RandomHorizontalFlip(),
-            transforms.RandomCrop(80, padding=4),
-            transforms.ToTensor(),
-            transforms.Normalize(mean, std),
-        ]
-        if cutout > 0:
-            lists += [CUTOUT(cutout)]
-        train_transform = transforms.Compose(lists)
-        test_transform = transforms.Compose(
-            [
-                transforms.CenterCrop(80),
-                transforms.ToTensor(),
-                transforms.Normalize(mean, std),
-            ]
-        )
-        xshape = (1, 3, 32, 32)
-    elif name.startswith("imagenet-1k"):
-        normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-        )
-        if name == "imagenet-1k":
-            xlists = [transforms.RandomResizedCrop(224)]
-            xlists.append(
-                transforms.ColorJitter(
-                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
-                )
-            )
-            xlists.append(Lighting(0.1))
-        elif name == "imagenet-1k-s":
-            xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
-        else:
-            raise ValueError("invalid name : {:}".format(name))
-        xlists.append(transforms.RandomHorizontalFlip(p=0.5))
-        xlists.append(transforms.ToTensor())
-        xlists.append(normalize)
-        train_transform = transforms.Compose(xlists)
-        test_transform = transforms.Compose(
-            [
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]
-        )
-        xshape = (1, 3, 224, 224)
+    elif name.startswith('cub200'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        xshape = (1, 3, 32, 32)
+    elif name.startswith('mnist'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
+            transforms.Normalize(mean, std),
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
+            transforms.Normalize(mean, std)
+        ])
+        xshape = (1, 3, 32, 32)
+    elif name.startswith('svhn'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        xshape = (1, 3, 32, 32)
+    elif name.startswith('aircraft'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+        xshape = (1, 3, 32, 32)
+    elif name.startswith('oxford'):
+        train_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std)
+        ])
+        test_transform = transforms.Compose([
+            transforms.Resize((32, 32)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+        xshape = (1, 3, 32, 32)
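Note: only the mnist branch inserts transforms.Lambda(lambda x: x.repeat(3, 1, 1)): MNIST tensors are 1×32×32 after Resize + ToTensor, and the repeat tiles the single channel so the 3-channel Normalize (and the reported xshape) apply. In isolation:

```python
import torch

x = torch.rand(1, 32, 32)   # grayscale CHW tensor, as ToTensor yields for MNIST
x3 = x.repeat(3, 1, 1)      # tile the channel dimension -> 3x32x32
assert x3.shape == (3, 32, 32) and torch.equal(x3[0], x3[2])
```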
     else:
         raise TypeError("Unknow dataset : {:}".format(name))
 
-    if name == "cifar10":
+    if name == 'cifar10':
         train_data = dset.CIFAR10(
-            root, train=True, transform=train_transform, download=True
-        )
+            root, train=True, transform=train_transform, download=True)
         test_data = dset.CIFAR10(
-            root, train=False, transform=test_transform, download=True
-        )
+            root, train=False, transform=test_transform, download=True)
         assert len(train_data) == 50000 and len(test_data) == 10000
-    elif name == "cifar100":
+    elif name == 'cifar100':
         train_data = dset.CIFAR100(
-            root, train=True, transform=train_transform, download=True
-        )
+            root, train=True, transform=train_transform, download=True)
         test_data = dset.CIFAR100(
-            root, train=False, transform=test_transform, download=True
-        )
+            root, train=False, transform=test_transform, download=True)
         assert len(train_data) == 50000 and len(test_data) == 10000
-    elif name == "aircraft":
-        train_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=train_transform)
-        test_data = dset.ImageFolder(root='/lustre/hpe/ws11/ws11.1/ws/xmuhanma-SWAP/train_datasets/datasets/fgvc-aircraft-2013b/data/train_sorted_image', transform=test_transform)
-    elif name.startswith("imagenet-1k"):
-        train_data = dset.ImageFolder(osp.join(root, "train"), train_transform)
-        test_data = dset.ImageFolder(osp.join(root, "val"), test_transform)
-        assert (
-            len(train_data) == 1281167 and len(test_data) == 50000
-        ), "invalid number of images : {:} & {:} vs {:} & {:}".format(
-            len(train_data), len(test_data), 1281167, 50000
-        )
-    elif name == "ImageNet16":
-        train_data = ImageNet16(root, True, train_transform)
-        test_data = ImageNet16(root, False, test_transform)
-        assert len(train_data) == 1281167 and len(test_data) == 50000
-    elif name == "ImageNet16-120":
-        train_data = ImageNet16(root, True, train_transform, 120)
-        test_data = ImageNet16(root, False, test_transform, 120)
-        assert len(train_data) == 151700 and len(test_data) == 6000
-    elif name == "ImageNet16-150":
-        train_data = ImageNet16(root, True, train_transform, 150)
-        test_data = ImageNet16(root, False, test_transform, 150)
-        assert len(train_data) == 190272 and len(test_data) == 7500
-    elif name == "ImageNet16-200":
-        train_data = ImageNet16(root, True, train_transform, 200)
-        test_data = ImageNet16(root, False, test_transform, 200)
-        assert len(train_data) == 254775 and len(test_data) == 10000
+    elif name == 'mnist':
+        train_data = dset.MNIST(
+            root, train=True, transform=train_transform, download=True)
+        test_data = dset.MNIST(
+            root, train=False, transform=test_transform, download=True)
+        assert len(train_data) == 60000 and len(test_data) == 10000
+    elif name == 'svhn':
+        train_data = dset.SVHN(root, split='train',
+                               transform=train_transform, download=True)
+        test_data = dset.SVHN(root, split='test',
+                              transform=test_transform, download=True)
+        assert len(train_data) == 73257 and len(test_data) == 26032
+    elif name == 'aircraft':
+        train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval',
+                                  transform=train_transform, download=False)
+        test_data = FGVCAircraft(root, class_type='manufacturer', split='test',
+                                 transform=test_transform, download=False)
+        assert len(train_data) == 6667 and len(test_data) == 3333
+    elif name == 'oxford':
+        train_data = PetDataset(root, train=True, num_cl=37,
+                                val_split=0.15, transforms=train_transform)
+        test_data = PetDataset(root, train=False, num_cl=37,
+                               val_split=0.15, transforms=test_transform)
     else:
         raise TypeError("Unknow dataset : {:}".format(name))
 
-    class_num = Dataset2Class[name]
+    class_num = Dataset2Class[name] if use_num_cls is None else len(
+        use_num_cls)
     return train_data, test_data, xshape, class_num
 
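Note: with download=False, root must already contain the extracted FGVC-Aircraft / Oxford-IIIT Pet files. A usage sketch (the paths are placeholders for the cluster paths used in the scripts):

```python
# Both calls return (train_data, test_data, xshape, class_num).
train_data, test_data, xshape, class_num = get_datasets(
    'aircraft', '/path/to/fgvc-aircraft-2013b', -1)
assert (xshape, class_num) == ((1, 3, 32, 32), 30)

train_data, test_data, xshape, class_num = get_datasets(
    'oxford', '/path/to/oxford-iiit-pet', -1)
assert (xshape, class_num) == ((1, 3, 32, 32), 37)
```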
-def get_nas_search_loaders(
-    train_data, valid_data, dataset, config_root, batch_size, workers
-):
+def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None):
     if isinstance(batch_size, (list, tuple)):
         batch, test_batch = batch_size
     else:
         batch, test_batch = batch_size, batch_size
-    if dataset == "cifar10":
+    if dataset == 'cifar10':
         # split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
-        cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None)
-        train_split, valid_split = (
-            cifar_split.train,
-            cifar_split.valid,
-        )  # search over the proposed training and validation set
+        cifar_split = load_config(
+            '{:}/cifar-split.txt'.format(config_root), None, None)
+        # search over the proposed training and validation set
+        train_split, valid_split = cifar_split.train, cifar_split.valid
         # logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set
         # To split data
         xvalid_data = deepcopy(train_data)
-        if hasattr(xvalid_data, "transforms"):  # to avoid a print issue
+        if hasattr(xvalid_data, 'transforms'):  # to avoid a print issue
             xvalid_data.transforms = valid_data.transform
         xvalid_data.transform = deepcopy(valid_data.transform)
-        search_data = SearchDataset(dataset, train_data, train_split, valid_split)
+        search_data = SearchDataset(
+            dataset, train_data, train_split, valid_split)
         # data loader
-        search_loader = torch.utils.data.DataLoader(
-            search_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        train_loader = torch.utils.data.DataLoader(
-            train_data,
-            batch_size=batch,
-            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
-            num_workers=workers,
-            pin_memory=True,
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            xvalid_data,
-            batch_size=test_batch,
-            sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
-            num_workers=workers,
-            pin_memory=True,
-        )
+        search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
+                                                    pin_memory=True)
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch,
+                                                   sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                       train_split),
+                                                   num_workers=workers, pin_memory=True)
+        valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch,
+                                                   sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                       valid_split),
+                                                   num_workers=workers, pin_memory=True)
-    elif dataset == "cifar100":
+    elif dataset == 'cifar100':
         cifar100_test_split = load_config(
-            "{:}/cifar100-test-split.txt".format(config_root), None, None
-        )
+            '{:}/cifar100-test-split.txt'.format(config_root), None, None)
         search_train_data = train_data
         search_valid_data = deepcopy(valid_data)
         search_valid_data.transform = train_data.transform
-        search_data = SearchDataset(
-            dataset,
-            [search_train_data, search_valid_data],
-            list(range(len(search_train_data))),
-            cifar100_test_split.xvalid,
-        )
+        search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
+                                    list(range(len(search_train_data))),
+                                    cifar100_test_split.xvalid)
-        search_loader = torch.utils.data.DataLoader(
-            search_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        train_loader = torch.utils.data.DataLoader(
-            train_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            valid_data,
-            batch_size=test_batch,
-            sampler=torch.utils.data.sampler.SubsetRandomSampler(
-                cifar100_test_split.xvalid
-            ),
-            num_workers=workers,
-            pin_memory=True,
-        )
+        search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
+                                                    pin_memory=True)
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, num_workers=workers,
+                                                   pin_memory=True)
+        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
+                                                   sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                       cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
-    elif dataset == "ImageNet16-120":
-        imagenet_test_split = load_config(
-            "{:}/imagenet-16-120-test-split.txt".format(config_root), None, None
-        )
-        search_train_data = train_data
-        search_valid_data = deepcopy(valid_data)
-        search_valid_data.transform = train_data.transform
-        search_data = SearchDataset(
-            dataset,
-            [search_train_data, search_valid_data],
-            list(range(len(search_train_data))),
-            imagenet_test_split.xvalid,
-        )
-        search_loader = torch.utils.data.DataLoader(
-            search_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        train_loader = torch.utils.data.DataLoader(
-            train_data,
-            batch_size=batch,
-            shuffle=True,
-            num_workers=workers,
-            pin_memory=True,
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            valid_data,
-            batch_size=test_batch,
-            sampler=torch.utils.data.sampler.SubsetRandomSampler(
-                imagenet_test_split.xvalid
-            ),
-            num_workers=workers,
-            pin_memory=True,
-        )
+    elif dataset in ['mnist', 'svhn', 'aircraft', 'oxford']:
+        if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)):
+            import json
+            label_list = list(range(len(valid_data)))
+            random.shuffle(label_list)
+            strlist = [str(label_list[i]) for i in range(len(label_list))]
+            split = {'xvalid': ["int", strlist[:len(valid_data) // 2]],
+                     'xtest': ["int", strlist[len(valid_data) // 2:]]}
+            with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f:
+                f.write(json.dumps(split))
+        test_split = load_config(
+            '{:}/{}-test-split.txt'.format(config_root, dataset), None, None)
+
+        search_train_data = train_data
+        search_valid_data = deepcopy(valid_data)
+        search_valid_data.transform = train_data.transform
+        search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
+                                    list(range(len(search_train_data))), test_split.xvalid)
+        search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True,
+                                                    num_workers=workers, pin_memory=True)
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True,
+                                                   num_workers=workers, pin_memory=True)
+        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
+                                                   sampler=torch.utils.data.sampler.SubsetRandomSampler(
+                                                       test_split.xvalid), num_workers=workers, pin_memory=True)
     else:
-        raise ValueError("invalid dataset : {:}".format(dataset))
+        raise ValueError('invalid dataset : {:}'.format(dataset))
     return search_loader, train_loader, valid_loader
 
 
-# if __name__ == '__main__':
-#     train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
-#     import pdb; pdb.set_trace()
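Note: for the four new datasets, the first call writes `{config_root}/{dataset}-test-split.txt` and every later call reloads it, so the xvalid/xtest halves stay fixed across runs. A usage sketch (config_root and the dataset path are placeholders):

```python
train_data, valid_data, xshape, class_num = get_datasets(
    'aircraft', '/path/to/fgvc-aircraft-2013b', -1)
search_loader, train_loader, valid_loader = get_nas_search_loaders(
    train_data, valid_data, 'aircraft',
    'configs/nas-benchmark', batch_size=(64, 128), workers=4)
```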