autodl-projects/lib/datasets/get_dataset_with_transform.py
2019-04-08 11:04:08 +08:00

75 lines
3.1 KiB
Python

import os, sys, torch
import os.path as osp
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from utils import Cutout
from .TieredImageNet import TieredImageNet
# Maps each supported dataset name to its number of target classes.
# NOTE(review): 'tiered' maps to -1, which looks like a placeholder for a
# class count that is not fixed in this file -- confirm before relying on it.
Dataset2Class = {
    'cifar10': 10,
    'cifar100': 100,
    'tiered': -1,
    'imagenet-1k': 1000,
    'imagenet-100': 100,
}
def _get_mean_std(name):
    """Return the per-channel ``(mean, std)`` normalization lists for ``name``.

    Raises:
        TypeError: if ``name`` is not a supported dataset name.
    """
    if name == 'cifar10':
        # CIFAR statistics are written on the 0-255 scale and divided by 255.
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif name == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif name in ('tiered', 'imagenet-1k', 'imagenet-100'):
        # 'tiered' and both ImageNet variants share the same statistics.
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        raise TypeError("Unknown dataset : {:}".format(name))
    return mean, std


def _build_transforms(name, mean, std, cutout):
    """Return ``(train_transform, test_transform)`` for an already validated ``name``."""
    if name in ('cifar10', 'cifar100', 'tiered'):
        # These datasets share one recipe; only the crop size differs, and the
        # 'tiered' test pipeline additionally center-crops to that size.
        crop_size = 80 if name == 'tiered' else 32
        train_steps = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        if cutout > 0:
            # Cutout is appended last, so it runs on the normalized tensor.
            train_steps.append(Cutout(cutout))
        test_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
        if name == 'tiered':
            test_steps.insert(0, transforms.CenterCrop(crop_size))
        return transforms.Compose(train_steps), transforms.Compose(test_steps)

    # Remaining names are 'imagenet-1k' / 'imagenet-100'; `cutout` is not used here.
    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, test_transform


def get_datasets(name, root, cutout):
    """Build the train/test datasets and class count for a named dataset.

    Args:
        name: dataset name; one of 'cifar10', 'cifar100', 'imagenet-1k' or
            'imagenet-100'. 'tiered' has transforms defined but no dataset
            loader in this function, so it currently raises ``TypeError``.
        root: for CIFAR, the root directory handed to the torchvision dataset
            (created with ``download=True``); for ImageNet, the directory
            whose 'train' and 'val' sub-folders are read by ``ImageFolder``.
        cutout: when greater than 0, ``Cutout(cutout)`` is appended to the
            training transforms of CIFAR / tiered. It is ignored for ImageNet.
            NOTE(review): presumably the size of the erased patch -- confirm
            against ``utils.Cutout``.

    Returns:
        A tuple ``(train_data, test_data, class_num)`` where ``class_num`` is
        ``Dataset2Class[name]``.

    Raises:
        TypeError: if ``name`` is unknown, or is known but has no loader.
            The exception type is kept as ``TypeError`` for existing callers.
    """
    # Mean + Std (also rejects unknown names before any transform is built)
    mean, std = _get_mean_std(name)

    # Data Augmentation
    train_transform, test_transform = _build_transforms(name, mean, std, cutout)

    if name in ('cifar10', 'cifar100'):
        cifar_cls = dset.CIFAR10 if name == 'cifar10' else dset.CIFAR100
        train_data = cifar_cls(root, train=True, transform=train_transform, download=True)
        test_data = cifar_cls(root, train=False, transform=test_transform, download=True)
    elif name in ('imagenet-1k', 'imagenet-100'):
        train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
        test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
    else:
        # Only 'tiered' can reach this branch: it has statistics and
        # transforms, but no dataset construction is implemented for it here.
        raise TypeError("No dataset loader for : {:}".format(name))

    class_num = Dataset2Class[name]
    return train_data, test_data, class_num