85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
from __future__ import print_function
|
|
import numpy as np
|
|
from PIL import Image
|
|
import pickle as pkl
|
|
import os, cv2, csv, glob
|
|
import torch
|
|
import torch.utils.data as data
|
|
|
|
|
|
class TieredImageNet(data.Dataset):
    """tieredImageNet dataset read from pre-packed ``.npz`` / ``.pkl`` files.

    ``split`` may name one split (``'train'``) or several joined with ``-``
    (``'train-val'``).  For every split name ``s`` the constructor expects,
    inside ``root_dir``:

    * ``{s}_images.npz`` -- an archive holding an ``images`` array.  When the
      archive is missing it is rebuilt from ``{s}_images_png.pkl`` by the
      module-level ``decompress`` function.
    * ``{s}_labels.pkl`` -- a pickled mapping with a ``label_specific`` entry
      (one integer label per image).

    The labels of each later split are shifted by the number of classes seen
    so far, so class ids never collide across merged splits.

    Attributes set by the constructor:
        split, root_dir, transform: the constructor arguments, unchanged.
        images: all image arrays concatenated along the first axis.
        labels: all (shifted) labels concatenated along the first axis.
        n_classes: ``max(labels) + 1``.
        dict_index_label: maps each class id to the array of sample indices
            carrying that label.
        length: number of samples.
    """

    def __init__(self, root_dir, split, transform=None):
        """Load and merge every split named in ``split``.

        Args:
            root_dir: directory containing the packed image/label files.
            split: split name, or several names joined with ``-``.
            transform: optional callable applied to the PIL image returned
                by ``__getitem__``.

        Raises:
            ValueError: the ``.npz`` archive is missing and so is the
                ``*_png.pkl`` file it would be rebuilt from.
            AssertionError: the image archive or the label file is missing.
            KeyError: the label pickle has no ``label_specific`` entry.
        """
        self.split = split
        self.root_dir = root_dir
        self.transform = transform

        images, labels, label_offset = [], [], 0
        for split_name in split.split('-'):
            labels_name = '{:}/{:}_labels.pkl'.format(self.root_dir, split_name)
            images_name = '{:}/{:}_images.npz'.format(self.root_dir, split_name)
            # Decompress the images if the npz archive does not exist yet.
            if not os.path.exists(images_name):
                png_pkl = images_name[:-4] + '_png.pkl'
                if not os.path.exists(png_pkl):
                    raise ValueError('png_pkl {:} not exits'.format( png_pkl ))
                decompress(images_name, png_pkl)
            # Kept as an assert so callers see the same exception type as before.
            assert os.path.exists(images_name) and os.path.exists(labels_name), '{:} & {:}'.format(images_name, labels_name)
            print ("Prepare {:} done".format(images_name))

            label_specific = self._read_label_specific(labels_name)
            with np.load(images_name, mmap_mode="r", encoding='latin1') as archive:
                image_data = archive["images"]
            images.append(image_data)

            # Shift this split's labels past every class id already used.
            label_specific = label_specific + label_offset
            labels.append(label_specific)
            label_offset = np.max(label_specific) + 1
            print ("Load {:} done, with image shape = {:}, label shape = {:}, [{:} ~ {:}]".format(images_name, image_data.shape, label_specific.shape, np.min(label_specific), np.max(label_specific)))

        images, labels = np.concatenate(images), np.concatenate(labels)

        self.images = images
        self.labels = labels
        self.n_classes = int(np.max(labels) + 1)
        self.dict_index_label = {
            cls: np.where(labels == cls)[0] for cls in range(self.n_classes)
        }
        self.length = len(labels)
        print ("There are {:} images, {:} labels [{:} ~ {:}]".format(images.shape, labels.shape, np.min(labels), np.max(labels)))

    @staticmethod
    def _read_label_specific(labels_name):
        """Return the ``label_specific`` entry of a label pickle as an array.

        The file is always opened in binary mode.  ``encoding='bytes'`` lets
        Python 3 read pickles written by Python 2, in which case the dict key
        comes back as ``bytes``; pickles written by Python 3 keep a ``str``
        key.  Both spellings are accepted.

        NOTE: ``pickle`` can execute arbitrary code while loading; only point
        this at label files from a trusted source.

        Raises:
            KeyError: neither key spelling is present in the pickle.
        """
        with open(labels_name, 'rb') as f:
            try:
                payload = pkl.load(f, encoding='bytes')
            except TypeError:
                # Python 2: pickle.load() does not accept ``encoding``.
                f.seek(0)
                payload = pkl.load(f)
        for key in ('label_specific', b'label_specific'):
            if key in payload:
                return np.asarray(payload[key])
        raise KeyError('label_specific not found in {:}'.format(labels_name))

    def __repr__(self):
        """Return ``ClassName(length=..., classes=...)``."""
        return '{name}(length={length}, classes={n_classes})'.format(
            name=self.__class__.__name__,
            length=self.length,
            n_classes=self.n_classes,
        )

    def __len__(self):
        """Return the number of samples."""
        return self.length

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The stored array has its last (channel) axis reversed before it is
        wrapped in an RGB PIL image -- presumably because the archives built
        by ``decompress`` come from ``cv2.imdecode`` and are therefore in BGR
        order.  ``transform`` is applied to the PIL image when it is set.

        Raises:
            IndexError: ``index`` is negative or ``>= len(self)``.  An
                ``IndexError`` (rather than a failed assert) is what lets
                plain ``for sample in dataset`` loops terminate cleanly.
        """
        if not 0 <= index < self.length:
            raise IndexError('invalid index = {:}'.format(index))
        label = int(self.labels[index])
        # astype() already returns a fresh array, so the source is not aliased.
        pixels = self.images[index][:, :, ::-1].astype('uint8')
        image = Image.fromarray(pixels, 'RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label
|
|
|
|
|
|
|
|
|
|
def decompress(path, output):
    """Build an ``.npz`` image archive from a pickle of encoded images.

    Args:
        path: destination passed to ``np.savez``; the decoded images are
            stored under the key ``images``.
        output: despite its name, the *input* file -- a pickle holding a
            sequence of encoded image buffers, each decoded with
            ``cv2.imdecode``.

    Every decoded image is written into a pre-allocated uint8 array of shape
    ``(count, 84, 84, 3)``, so each buffer must decode to an 84x84 colour
    image.

    NOTE: ``pickle`` can execute arbitrary code while loading; only use this
    on files from a trusted source.
    """
    with open(output, 'rb') as handle:
        encoded = pkl.load(handle, encoding='bytes')

    decoded = np.zeros((len(encoded), 84, 84, 3), dtype=np.uint8)
    for slot, buf in enumerate(encoded):
        # Flag 1 asks OpenCV for a 3-channel colour image.
        decoded[slot] = cv2.imdecode(buf, 1)

    np.savez(path, images=decoded)
|