add aircraft root
This commit is contained in:
parent
cd80aa277c
commit
4df61fcbb3
@ -25,6 +25,32 @@ import torch
|
|||||||
from .imagenet16 import *
|
from .imagenet16 import *
|
||||||
|
|
||||||
|
|
||||||
|
class CUTOUT(object):
|
||||||
|
def __init__(self, length):
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}(length={length})".format(
|
||||||
|
name=self.__class__.__name__, **self.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
h, w = img.size(1), img.size(2)
|
||||||
|
mask = np.ones((h, w), np.float32)
|
||||||
|
y = np.random.randint(h)
|
||||||
|
x = np.random.randint(w)
|
||||||
|
|
||||||
|
y1 = np.clip(y - self.length // 2, 0, h)
|
||||||
|
y2 = np.clip(y + self.length // 2, 0, h)
|
||||||
|
x1 = np.clip(x - self.length // 2, 0, w)
|
||||||
|
x2 = np.clip(x + self.length // 2, 0, w)
|
||||||
|
|
||||||
|
mask[y1:y2, x1:x2] = 0.0
|
||||||
|
mask = torch.from_numpy(mask)
|
||||||
|
mask = mask.expand_as(img)
|
||||||
|
img *= mask
|
||||||
|
return img
|
||||||
|
|
||||||
def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):
|
def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_workers, resize=None, datadir='_dataset'):
|
||||||
# print(dataset)
|
# print(dataset)
|
||||||
if 'ImageNet16' in dataset:
|
if 'ImageNet16' in dataset:
|
||||||
@ -74,7 +100,8 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
|
|||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean,std),
|
transforms.Normalize(mean,std),
|
||||||
])
|
])
|
||||||
root = '/nfs/data3/hanzhang/MeCo/data'
|
root = '/home/iicd/MeCo/data'
|
||||||
|
aircraft_dataset_root = '/mnt/Study/DataSet/DataSet/fgvc-aircraft-2013b/fgvc-aircraft-2013b/data'
|
||||||
|
|
||||||
if dataset == 'cifar10':
|
if dataset == 'cifar10':
|
||||||
train_dataset = CIFAR10(datadir, True, train_transform, download=True)
|
train_dataset = CIFAR10(datadir, True, train_transform, download=True)
|
||||||
@ -84,18 +111,18 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
|
|||||||
test_dataset = CIFAR100(datadir, False, test_transform, download=True)
|
test_dataset = CIFAR100(datadir, False, test_transform, download=True)
|
||||||
elif dataset == 'aircraft':
|
elif dataset == 'aircraft':
|
||||||
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||||
# if resize != None :
|
if resize != None :
|
||||||
# print(resize)
|
print(resize)
|
||||||
# lists += [CUTOUT(resize)]
|
lists += [CUTOUT(resize)]
|
||||||
train_transform = transforms.Compose(lists)
|
train_transform = transforms.Compose(lists)
|
||||||
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||||
train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
|
train_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'train_sorted_images'), train_transform)
|
||||||
test_data = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
|
test_data = dset.ImageFolder(os.path.join(aircraft_dataset_root, 'test_sorted_images'), test_transform)
|
||||||
elif dataset == 'oxford':
|
elif dataset == 'oxford':
|
||||||
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
|
||||||
# if resize != None :
|
if resize != None :
|
||||||
# print(resize)
|
print(resize)
|
||||||
# lists += [CUTOUT(resize)]
|
lists += [CUTOUT(resize)]
|
||||||
train_transform = transforms.Compose(lists)
|
train_transform = transforms.Compose(lists)
|
||||||
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||||
|
|
||||||
@ -172,4 +199,4 @@ if __name__ == '__main__':
|
|||||||
tr, te = get_cifar_dataloaders(64, 64, 'random', 2, resize=None, datadir='_dataset')
|
tr, te = get_cifar_dataloaders(64, 64, 'random', 2, resize=None, datadir='_dataset')
|
||||||
for x, y in tr:
|
for x, y in tr:
|
||||||
print(x.size(), y.size())
|
print(x.size(), y.size())
|
||||||
break
|
break
|
||||||
|
Loading…
Reference in New Issue
Block a user