65 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			65 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import random, tarfile
 | |
| import numpy, six
 | |
| from six.moves import cPickle as pickle
 | |
| from PIL import Image, ImageOps
 | |
| 
 | |
| 
 | |
def train_cifar_augmentation(image):
  """Randomly flip and crop a PIL image, then normalize it per channel.

  Steps, in order:
    1. With probability 0.5 the image is mirrored left-to-right.
    2. A 4-pixel border filled with 0 is added on every side.
    3. A 32x32 window is cut out at a random offset in [0, 8] on each axis.
    4. The window is converted to a numpy array, divided by 255.0 and
       standardized with fixed per-channel mean/std constants.

  Args:
    image: a PIL image. NOTE(review): the offset range (40 - 32) presumes
      a 32x32 input, i.e. 40x40 after padding -- confirm against callers.

  Returns:
    A float numpy array holding ``(pixels - mean) / std``, where the
    statistics are broadcast from shape (1, 1, 3) over the last axis.
  """
  # Mirror horizontally for (on average) half of the calls.
  mirror = random.random() < 0.5
  flipped = image.transpose(Image.FLIP_LEFT_RIGHT) if mirror else image

  # Zero-pad, then pick the top-left corner of the crop window at random.
  padded = ImageOps.expand(flipped, border=4, fill=0)
  max_offset = 40 - 32
  top = random.randint(0, max_offset)
  left = random.randint(0, max_offset)
  window = padded.crop((left, top, left + 32, top + 32))

  # Convert to numpy and standardize; the statistics are given on a 0-255
  # scale, so they are divided by 255 to match the rescaled pixels.
  pixels = numpy.array(window) / 255.0
  channel_mean = numpy.array(
    [value / 255 for value in (125.3, 123.0, 113.9)]).reshape(1, 1, 3)
  channel_std = numpy.array(
    [value / 255 for value in (63.0, 62.1, 66.7)]).reshape(1, 1, 3)
  return (pixels - channel_mean) / channel_std
 | |
| 
 | |
| 
 | |
def valid_cifar_augmentation(image):
  """Normalize an image per channel without any random augmentation.

  The input is converted with ``numpy.array``, divided by 255.0, and then
  standardized with the fixed per-channel mean/std constants below.

  Args:
    image: anything ``numpy.array`` accepts whose last axis has length 3
      (required for broadcasting against the (1, 1, 3) statistics).

  Returns:
    A float numpy array holding ``(pixels - mean) / std``.
  """
  # The statistics are given on a 0-255 scale; bring them onto the same
  # scale as the rescaled pixels before standardizing.
  channel_mean = numpy.reshape(
    numpy.array([125.3, 123.0, 113.9]) / 255, (1, 1, 3))
  channel_std = numpy.reshape(
    numpy.array([63.0, 62.1, 66.7]) / 255, (1, 1, 3))

  pixels = numpy.array(image) / 255.0
  return (pixels - channel_mean) / channel_std
 | |
| 
 | |
| 
 | |
def reader_creator(filename, sub_name, is_train, cycle=False):
  """Build a sample reader over the pickled batches stored in a tar archive.

  Args:
    filename: path of the tar archive; opened with ``tarfile.open``.
    sub_name: substring used to select archive members by name.
    is_train: when true, every sample goes through
      ``train_cifar_augmentation``; otherwise through
      ``valid_cifar_augmentation``.
    cycle: when true, the reader starts over from the first matching
      member after the last one instead of stopping.

  Returns:
    A zero-argument generator function. Each item it yields is a tuple
    ``(image, label)`` where ``image`` is a float32 numpy array in
    channel-first layout and ``label`` is an ``int``.
  """
  def read_batch(batch):
    # Batches are dicts with byte-string keys; the labels live under
    # either b'labels' or b'fine_labels'.
    data = batch[six.b('data')]
    labels = batch.get(
      six.b('labels'), batch.get(six.b('fine_labels'), None))
    assert labels is not None, 'batch has neither labels nor fine_labels'
    for sample, label in six.moves.zip(data, labels):
      # Each flat sample is viewed as (3, 32, 32) and moved to
      # channel-last (32, 32, 3) so PIL can build an image from it.
      # NOTE(review): presumably uint8 pixel data -- confirm.
      sample = sample.reshape(3, 32, 32)
      sample = sample.transpose((1, 2, 0))
      image  = Image.fromarray(sample)
      if is_train:
        ximage = train_cifar_augmentation(image)
      else:
        ximage = valid_cifar_augmentation(image)
      # Back to channel-first for the consumer.
      ximage = ximage.transpose((2, 0, 1))
      yield ximage.astype(numpy.float32), int(label)

  def reader():
    with tarfile.open(filename, mode='r') as f:
      # Materialize the member names. A generator expression would be
      # exhausted after the first pass, so with cycle=True every later
      # pass would be empty and the loop below would spin forever
      # without yielding anything.
      names = [each_item.name for each_item in f
               if sub_name in each_item.name]

      while True:
        produced = False
        for name in names:
          # NOTE(security): pickle.load executes arbitrary code embedded
          # in the payload; only use archives from a trusted source.
          if six.PY2:
            batch = pickle.load(f.extractfile(name))
          else:
            batch = pickle.load(
              f.extractfile(name), encoding='bytes')
          for item in read_batch(batch):
            produced = True
            yield item
        # Stop after one pass unless cycling; also stop when a full pass
        # produced nothing, since repeating it could never yield a sample.
        if not cycle or not produced:
          break

  return reader
 |