65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
import random, tarfile
|
|
import numpy, six
|
|
from six.moves import cPickle as pickle
|
|
from PIL import Image, ImageOps
|
|
|
|
|
|
def train_cifar_augmentation(image, padding=4):
    """Apply training-time augmentation and per-channel normalization.

    Steps:
      1. Horizontal flip with probability 0.5.
      2. Zero-pad every side by ``padding`` pixels, then take a random crop
         with the same width/height as the input image.
      3. Divide pixel values by 255.0 and standardize each channel with the
         fixed mean/std constants below.

    Args:
        image: a PIL image. Assumed to have 3 channels (RGB), since the
            mean/std arrays are broadcast as shape (1, 1, 3).
        padding: number of zero pixels added on each side before cropping.
            Defaults to 4, the value previously hard-coded; for a 32x32
            input this reproduces the original 40x40 -> 32x32 crop exactly.

    Returns:
        numpy.ndarray (height x width x channels, float64) of normalized
        pixel values.
    """
    # Flip left/right half of the time.
    if random.random() < 0.5:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)

    # Random crop: pad, then pick a window the size of the original image.
    # The offset range is [0, 2 * padding] inclusive (randint is inclusive).
    width, height = image.size
    padded = ImageOps.expand(image, border=padding, fill=0)
    top = random.randint(0, 2 * padding)
    left = random.randint(0, 2 * padding)
    # PIL crop box is (left, upper, right, lower).
    cropped = padded.crop((left, top, left + width, top + height))

    # Scale to [0, 1] (for 8-bit pixels) and standardize per channel.
    pixels = numpy.array(cropped) / 255.0
    mean = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3)
    std = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3)
    return (pixels - mean) / std
|
|
|
|
|
|
def valid_cifar_augmentation(image):
    """Normalize an image for evaluation; no random augmentation is applied.

    The input is converted with ``numpy.array`` and divided by 255.0, then
    each channel is standardized with fixed mean/std constants. The
    constants are broadcast as shape (1, 1, 3), so a trailing channel axis
    of size 3 is expected.

    Args:
        image: anything ``numpy.array`` accepts (e.g. a PIL image or an
            H x W x 3 array).

    Returns:
        numpy.ndarray of standardized float64 values, same layout as input.
    """
    channel_mean = (numpy.array([125.3, 123.0, 113.9]) / 255).reshape(1, 1, 3)
    channel_std = (numpy.array([63.0, 62.1, 66.7]) / 255).reshape(1, 1, 3)
    scaled = numpy.array(image) / 255.0
    centered = scaled - channel_mean
    return centered / channel_std
|
|
|
|
|
|
def reader_creator(filename, sub_name, is_train, cycle=False):
    """Build a sample reader over pickled image batches stored in a tar archive.

    Args:
        filename: path of the tar archive (opened with ``tarfile.open`` in
            mode ``'r'``).
        sub_name: substring selecting archive members; every member whose
            name contains it is unpickled as one batch.
        is_train: if true, each image goes through
            ``train_cifar_augmentation``; otherwise through
            ``valid_cifar_augmentation``.
        cycle: if true, loop over the selected batches indefinitely;
            otherwise make a single pass.

    Returns:
        A zero-argument callable. Each call returns a generator yielding
        ``(image, label)`` pairs, where ``image`` is a float32 array of
        shape (3, 32, 32) (channel-first) and ``label`` is an ``int``.
    """

    def read_batch(batch):
        # Batch dicts are keyed by bytes (they are unpickled with
        # encoding='bytes' on Python 3).
        data = batch[six.b('data')]
        # Labels may live under 'labels' or 'fine_labels' (presumably the
        # CIFAR-10 vs CIFAR-100 naming -- confirm against the dataset).
        labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None))
        if labels is None:
            # Explicit raise instead of assert: asserts vanish under -O.
            raise ValueError("batch has neither 'labels' nor 'fine_labels'")

        for sample, label in six.moves.zip(data, labels):
            # Each row is stored channel-first; PIL expects height x width x
            # channels.
            hwc = sample.reshape(3, 32, 32).transpose((1, 2, 0))
            image = Image.fromarray(hwc)
            if is_train:
                ximage = train_cifar_augmentation(image)
            else:
                ximage = valid_cifar_augmentation(image)
            # Back to channel-first layout for the consumer.
            yield ximage.transpose((2, 0, 1)).astype(numpy.float32), int(label)

    def reader():
        with tarfile.open(filename, mode='r') as f:
            # Materialize the names as a list: a generator expression would be
            # exhausted after the first pass, so with cycle=True the loop
            # below would spin forever without yielding anything.
            names = [each_item.name for each_item in f
                     if sub_name in each_item.name]
            if not names:
                # Nothing matches; avoid an endless empty loop when cycling.
                return

            while True:
                for name in names:
                    # NOTE(review): pickle.load can execute arbitrary code;
                    # only use this reader on trusted archives.
                    if six.PY2:
                        batch = pickle.load(f.extractfile(name))
                    else:
                        batch = pickle.load(
                            f.extractfile(name), encoding='bytes')
                    for item in read_batch(batch):
                        yield item
                if not cycle:
                    break

    return reader
|