2019-11-15 07:15:07 +01:00
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
2019-01-31 17:23:55 +01:00
import os , sys , torch
import os . path as osp
2019-09-28 10:24:47 +02:00
import numpy as np
2019-01-31 17:23:55 +01:00
import torchvision . datasets as dset
import torchvision . transforms as transforms
2019-09-28 10:24:47 +02:00
from PIL import Image
from . DownsampledImageNet import ImageNet16
2019-01-31 17:23:55 +01:00
2019-03-31 16:49:43 +02:00
2019-01-31 17:23:55 +01:00
# Number of target classes for every dataset name that get_datasets can load
# (plus the ImageNet16 variants, which share 16x16 downsampled images).
Dataset2Class = dict([
    ('cifar10', 10),
    ('cifar100', 100),
    ('imagenet-1k-s', 1000),
    ('imagenet-1k', 1000),
    ('ImageNet16', 1000),
    ('ImageNet16-150', 150),
    ('ImageNet16-120', 120),
    ('ImageNet16-200', 200),
])
class CUTOUT(object):
    """Cutout augmentation: zero one random square patch of a tensor image.

    The image is indexed as (C, H, W): dims 1 and 2 are treated as height and
    width, and the (H, W) mask is broadcast over every channel. The patch is
    centred on a uniformly drawn pixel and spans ``2 * (length // 2)`` pixels
    per side before being clipped to the image border, so patches near an
    edge are smaller.
    """

    def __init__(self, length):
        # Nominal side length of the square patch, in pixels.
        self.length = length

    def __repr__(self):
        return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))

    def __call__(self, img):
        """Zero the patch in ``img`` in place and return the same tensor."""
        height = img.size(1)
        width = img.size(2)
        half = self.length // 2
        # Draw the row first, then the column (same RNG order as always).
        centre_row = np.random.randint(height)
        centre_col = np.random.randint(width)
        top = np.clip(centre_row - half, 0, height)
        bottom = np.clip(centre_row + half, 0, height)
        left = np.clip(centre_col - half, 0, width)
        right = np.clip(centre_col + half, 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.
        img *= torch.from_numpy(keep).expand_as(img)
        return img
# Default eigen-decomposition consumed by Lighting: three eigenvalues and the
# matching 3x3 eigenvector matrix (presumably the PCA of ImageNet RGB pixel
# values; the Lighting constructor asserts these exact shapes).
imagenet_pca = dict(
    eigval=np.asarray([0.2175, 0.0188, 0.0045]),
    eigvec=np.asarray([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ]),
)
class Lighting(object):
    """PCA-based lighting noise (AlexNet-style colour augmentation).

    A random vector ``alpha ~ N(0, alphastd)`` is drawn per call, and the RGB
    offset ``eigvec @ (alpha * eigval)`` is added to every pixel.

    Args:
      alphastd: standard deviation of the random coefficients; 0 disables
        the transform.
      eigval: shape-(3,) eigenvalues (defaults to ``imagenet_pca``).
      eigvec: shape-(3, 3) eigenvectors, one per column (defaults to
        ``imagenet_pca``).
    """

    def __init__(self, alphastd,
                 eigval=imagenet_pca['eigval'],
                 eigvec=imagenet_pca['eigvec']):
        self.alphastd = alphastd
        assert eigval.shape == (3,)
        assert eigvec.shape == (3, 3)
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        """Return ``img`` with lighting noise added, as an 'RGB' PIL image.

        ``img`` is anything ``np.asarray`` accepts (typically a PIL RGB
        image). When ``alphastd`` is 0 the input is returned untouched.
        """
        if self.alphastd == 0.:
            return img
        rnd = (np.random.randn(3) * self.alphastd).astype('float32')
        inc = np.dot(self.eigvec, (rnd * self.eigval).reshape((3, 1))).reshape((3,))

        arr = np.asarray(img)
        old_dtype = arr.dtype
        if old_dtype == np.uint8:
            # The eigen-statistics are in [0, 1] pixel units: with the default
            # eigval (<= 0.2175) and alphastd=0.1 the raw increment is ~0.02,
            # which vanishes (or just truncates pixels down by 1) on a 0-255
            # scale. Rescale to 8-bit units and round instead of truncating.
            out = arr.astype(np.float32) + inc * 255.0
            out = np.clip(np.rint(out), 0, 255)
        else:
            out = np.add(arr, inc)
        return Image.fromarray(out.astype(old_dtype), 'RGB')

    def __repr__(self):
        return self.__class__.__name__ + '()'
2019-01-31 17:23:55 +01:00
def _dataset_mean_std(name):
    """Return the per-channel (mean, std) normalization lists for `name`.

    Raises TypeError for names without known statistics.
    """
    if name == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif name == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif name.startswith('imagenet-1k'):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif name.startswith('ImageNet16'):
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
    else:
        raise TypeError("Unknow dataset : {:}".format(name))
    return mean, std


def _build_transforms(name, mean, std, cutout):
    """Return (train_transform, test_transform, xshape) for dataset `name`.

    `cutout > 0` appends CUTOUT(cutout) to the training pipeline for the
    CIFAR and ImageNet16 datasets; it is ignored for imagenet-1k variants.
    """
    normalize = transforms.Normalize(mean, std)
    is_cifar = name in ('cifar10', 'cifar100')
    if is_cifar or name.startswith('ImageNet16'):
        # CIFAR images are 32x32; the ImageNet16 variants are 16x16.
        size, padding = (32, 4) if is_cifar else (16, 2)
        train_list = [transforms.RandomHorizontalFlip(),
                      transforms.RandomCrop(size, padding=padding),
                      transforms.ToTensor(),
                      normalize]
        if cutout > 0:
            train_list.append(CUTOUT(cutout))
        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        xshape = (1, 3, size, size)
    elif name.startswith('imagenet-1k'):
        if name == 'imagenet-1k':
            train_list = [transforms.RandomResizedCrop(224),
                          transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                                 saturation=0.4, hue=0.2),
                          Lighting(0.1)]
        elif name == 'imagenet-1k-s':
            train_list = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
        else:
            raise ValueError('invalid name : {:}'.format(name))
        train_list += [transforms.RandomHorizontalFlip(p=0.5),
                       transforms.ToTensor(),
                       normalize]
        test_transform = transforms.Compose([transforms.Resize(256),
                                             transforms.CenterCrop(224),
                                             transforms.ToTensor(),
                                             normalize])
        xshape = (1, 3, 224, 224)
    else:
        raise TypeError("Unknow dataset : {:}".format(name))
    return transforms.Compose(train_list), test_transform, xshape


def _load_splits(name, root, train_transform, test_transform):
    """Build the (train, test) datasets for `name` and check their sizes.

    Raises TypeError for unknown names and AssertionError when the number of
    loaded images differs from the expected split sizes.
    """
    # ImageNet16 variant -> (class count forwarded to ImageNet16 or None,
    #                        (#train images, #test images))
    imagenet16_variants = {
        'ImageNet16': (None, (1281167, 50000)),
        'ImageNet16-120': (120, (151700, 6000)),
        'ImageNet16-150': (150, (190272, 7500)),
        'ImageNet16-200': (200, (254775, 10000)),
    }
    if name in ('cifar10', 'cifar100'):
        dataset_cls = dset.CIFAR10 if name == 'cifar10' else dset.CIFAR100
        train_data = dataset_cls(root, train=True, transform=train_transform, download=True)
        test_data = dataset_cls(root, train=False, transform=test_transform, download=True)
        expected = (50000, 10000)
    elif name.startswith('imagenet-1k'):
        train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
        test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
        expected = (1281167, 50000)
    elif name in imagenet16_variants:
        num_classes, expected = imagenet16_variants[name]
        # Plain 'ImageNet16' is constructed without a class-count argument.
        extra = () if num_classes is None else (num_classes,)
        train_data = ImageNet16(root, True, train_transform, *extra)
        test_data = ImageNet16(root, False, test_transform, *extra)
    else:
        raise TypeError("Unknow dataset : {:}".format(name))
    got = (len(train_data), len(test_data))
    assert got == expected, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(
        got[0], got[1], expected[0], expected[1])
    return train_data, test_data


def get_datasets(name, root, cutout):
    """Build the train/test datasets and transforms for a named dataset.

    Args:
      name: one of the keys of Dataset2Class ('cifar10', 'cifar100',
        'imagenet-1k', 'imagenet-1k-s', 'ImageNet16', 'ImageNet16-120',
        'ImageNet16-150', 'ImageNet16-200').
      root: dataset directory. CIFAR is downloaded there if missing;
        imagenet-1k variants read `root/train` and `root/val` ImageFolders.
      cutout: CUTOUT patch length for CIFAR/ImageNet16 training; <= 0
        disables it.

    Returns:
      (train_data, test_data, xshape, class_num), where xshape is the
      (1, C, H, W) shape of a single transformed input.

    Raises:
      TypeError: unknown dataset name.
      ValueError: unsupported 'imagenet-1k*' variant.
      AssertionError: loaded split sizes differ from the expected counts.
    """
    mean, std = _dataset_mean_std(name)
    train_transform, test_transform, xshape = _build_transforms(name, mean, std, cutout)
    train_data, test_data = _load_splits(name, root, train_transform, test_transform)
    class_num = Dataset2Class[name]
    return train_data, test_data, xshape, class_num
#if __name__ == '__main__':
# train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
# import pdb; pdb.set_trace()