update to oxford and aircraft

This commit is contained in:
mhz
2024-11-26 11:02:56 +01:00
parent 0d830dd2f6
commit a6e411a94b
2 changed files with 53 additions and 4 deletions

View File

@@ -18,6 +18,7 @@
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, SVHN
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms
import torchvision.datasets as dset
from torch.utils.data import TensorDataset, DataLoader
import torch
@@ -44,6 +45,14 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
#resize = 256
elif dataset == 'aircraft':
mean = (0.4785, 0.5100, 0.5338)
std = (0.1845, 0.1830, 0.2060)
size, pad = 224, 2
elif dataset == 'oxford':
mean = (0.4811, 0.4492, 0.3957)
std = (0.2260, 0.2231, 0.2249)
size, pad = 32, 0
elif 'random' in dataset:
mean = (0.5, 0.5, 0.5)
std = (1, 1, 1)
@@ -65,6 +74,7 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
transforms.ToTensor(),
transforms.Normalize(mean,std),
])
root = '/nfs/data3/hanzhang/MeCo/data'
if dataset == 'cifar10':
train_dataset = CIFAR10(datadir, True, train_transform, download=True)
@@ -72,6 +82,40 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
elif dataset == 'cifar100':
train_dataset = CIFAR100(datadir, True, train_transform, download=True)
test_dataset = CIFAR100(datadir, False, test_transform, download=True)
elif dataset == 'aircraft':
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
# if resize != None :
# print(resize)
# lists += [CUTOUT(resize)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
train_data = dset.ImageFolder(os.path.join(root, 'train_sorted_images'), train_transform)
test_data = dset.ImageFolder(os.path.join(root, 'test_sorted_images'), test_transform)
elif dataset == 'oxford':
lists = [transforms.RandomCrop(size, padding=pad), transforms.ToTensor(), transforms.Normalize(mean, std)]
# if resize != None :
# print(resize)
# lists += [CUTOUT(resize)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose([transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean, std)])
train_data = torch.load(os.path.join(root, 'train85.pth'))
test_data = torch.load(os.path.join(root, 'test15.pth'))
train_tensor_data = [(image, label) for image, label in train_data]
test_tensor_data = [(image, label) for image, label in test_data]
sum_data = train_tensor_data + test_tensor_data
train_images = [image for image, label in train_tensor_data]
train_labels = torch.tensor([label for image, label in train_tensor_data])
test_images = [image for image, label in test_tensor_data]
test_labels = torch.tensor([label for image, label in test_tensor_data])
train_tensors = torch.stack([train_transform(image) for image in train_images])
test_tensors = torch.stack([test_transform(image) for image in test_images])
train_dataset = TensorDataset(train_tensors, train_labels)
test_dataset = TensorDataset(test_tensors, test_labels)
elif dataset == 'svhn':
train_dataset = SVHN(datadir, split='train', transform=train_transform, download=True)
test_dataset = SVHN(datadir, split='test', transform=test_transform, download=True)
@@ -97,8 +141,6 @@ def get_cifar_dataloaders(train_batch_size, test_batch_size, dataset, num_worker
shuffle=False,
num_workers=num_workers,
pin_memory=True)
return train_loader, test_loader