85 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			85 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from __future__ import print_function
 | 
						|
import numpy as np
 | 
						|
from PIL import Image
 | 
						|
import pickle as pkl
 | 
						|
import os, cv2, csv, glob
 | 
						|
import torch
 | 
						|
import torch.utils.data as data
 | 
						|
 | 
						|
 | 
						|
class TieredImageNet(data.Dataset):
  """tieredImageNet split(s) loaded from pre-packed pickle / ``.npz`` files.

  For each split name ``s`` (``split`` may join several names with ``'-'``,
  e.g. ``'train-val'``), ``root_dir`` must contain:

  * ``{s}_labels.pkl`` -- a pickled dict whose ``label_specific`` entry is the
    per-image label array;
  * ``{s}_images.npz`` -- an archive whose ``images`` entry is the image
    array. If it is missing but ``{s}_images_png.pkl`` exists, the archive is
    first built with ``decompress``.

  When several splits are given, each split's labels are shifted by one more
  than the largest (already shifted) label of the previous split. Assuming
  the stored labels are non-negative, the splits' label ranges therefore do
  not overlap.

  Args:
    root_dir: directory holding the split files.
    split: a split name, or several joined with ``'-'``.
    transform: optional callable applied to the PIL image in ``__getitem__``.

  Attributes:
    images: concatenated image array, indexed in parallel with ``labels``.
    labels: concatenated (shifted) label array.
    n_classes: ``max(labels) + 1``.
    dict_index_label: maps every class id in ``range(n_classes)`` to the
      ascending indices of the images carrying that label (empty if unused).
    length: number of images.
  """

  def __init__(self, root_dir, split, transform=None):
    self.split = split
    self.root_dir = root_dir
    self.transform = transform

    images, labels, offset = [], [], 0
    for name in split.split('-'):
      labels_name = '{:}/{:}_labels.pkl'.format(self.root_dir, name)
      images_name = '{:}/{:}_images.npz'.format(self.root_dir, name)
      self._ensure_images(images_name)
      # Explicit check instead of ``assert`` so it also runs under ``python -O``.
      if not (os.path.exists(images_name) and os.path.exists(labels_name)):
        raise IOError('missing {:} or {:}'.format(images_name, labels_name))
      print("Prepare {:} done".format(images_name))

      label_specific = self._load_label_specific(labels_name) + offset
      # Indexing the archive reads the member into memory, so ``image_data``
      # stays valid after the archive is closed by the ``with`` block.
      with np.load(images_name, mmap_mode="r", encoding='latin1') as archive:
        image_data = archive["images"]
      # Images and labels are looked up with the same index in __getitem__.
      if len(image_data) != len(label_specific):
        raise ValueError('{:} has {:} images but {:} has {:} labels'.format(
            images_name, len(image_data), labels_name, len(label_specific)))

      images.append(image_data)
      labels.append(label_specific)
      offset = np.max(label_specific) + 1
      print("Load {:} done, with image shape = {:}, label shape = {:}, [{:} ~ {:}]".format(
          images_name, image_data.shape, label_specific.shape,
          np.min(label_specific), np.max(label_specific)))

    self.images = np.concatenate(images)
    self.labels = np.concatenate(labels)
    self.n_classes = int(np.max(self.labels) + 1)
    self.dict_index_label = {
        cls: np.where(self.labels == cls)[0] for cls in range(self.n_classes)}
    self.length = len(self.labels)
    print("There are {:} images, {:} labels [{:} ~ {:}]".format(
        self.images.shape, self.labels.shape,
        np.min(self.labels), np.max(self.labels)))

  @staticmethod
  def _ensure_images(images_name):
    """Build ``images_name`` from its ``*_png.pkl`` source if it is missing.

    Raises:
      ValueError: if neither the archive nor the ``_png.pkl`` source exists.
    """
    if os.path.exists(images_name):
      return
    png_pkl = images_name[:-4] + '_png.pkl'  # strip the ".npz" suffix
    if not os.path.exists(png_pkl):
      raise ValueError('png_pkl {:} does not exist'.format(png_pkl))
    decompress(images_name, png_pkl)

  @staticmethod
  def _load_label_specific(labels_name):
    """Return the ``label_specific`` entry of the pickled dict ``labels_name``.

    A plain load is tried first, which handles pickles written by Python 3
    (and any pickle under Python 2). If it fails, the file is re-read with
    ``encoding='bytes'``, which lets Python 3 read Python-2 pickles. Under
    that encoding the dict keys come back as ``bytes``.
    """
    # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
    try:
      with open(labels_name, 'rb') as f:
        return pkl.load(f)["label_specific"]
    except Exception:  # deliberate fallback; unlike a bare except, Ctrl-C still works
      with open(labels_name, 'rb') as f:
        return pkl.load(f, encoding='bytes')[b'label_specific']

  def __repr__(self):
    return '{name}(length={length}, classes={n_classes})'.format(
        name=self.__class__.__name__, length=self.length,
        n_classes=self.n_classes)

  def __len__(self):
    return self.length

  def __getitem__(self, index):
    """Return ``(image, label)`` for ``0 <= index < len(self)``.

    ``image`` is a PIL RGB image, passed through ``transform`` when one is
    set. ``label`` is a Python int.

    Raises:
      IndexError: for an out-of-range index. This is also what lets plain
        ``for`` iteration over the dataset terminate.
    """
    if not 0 <= index < self.length:
      raise IndexError('invalid index = {:}'.format(index))
    image = self.images[index].copy()
    label = int(self.labels[index])
    # The stored channel order is the reverse of RGB (BGR, the order OpenCV
    # decodes to), so flip the last axis before building the RGB image.
    image = Image.fromarray(image[:, :, ::-1].astype('uint8'), 'RGB')
    if self.transform is not None:
      image = self.transform(image)
    return image, label


def decompress(path, output, image_shape=(84, 84, 3)):
  """Decode a pickled sequence of encoded images into an ``.npz`` archive.

  Note the argument naming: ``output`` is the pickle that is *read*, and
  ``path`` is the archive that is *written* (the array is stored under the
  key ``images``).

  Args:
    path: destination file name. As with ``np.savez``, ``.npz`` is appended
      when the name does not already end with it.
    output: pickle holding a sequence of encoded image buffers, each of
      which is passed to ``cv2.imdecode``.
    image_shape: expected decoded shape of every image (default 84x84x3).

  Raises:
    ValueError: if an image fails to decode or decodes to another shape.
  """
  # NOTE(review): unpickling can execute arbitrary code; only use trusted files.
  with open(output, 'rb') as f:
    # encoding='bytes' lets Python 3 read pickles written by Python 2.
    encoded = pkl.load(f, encoding='bytes')

  images = np.zeros((len(encoded),) + tuple(image_shape), dtype=np.uint8)
  for ii, item in enumerate(encoded):
    im = cv2.imdecode(item, 1)  # flag 1 == cv2.IMREAD_COLOR
    # imdecode reports failure by returning None instead of raising.
    if im is None:
      raise ValueError('cannot decode image #{:} of {:}'.format(ii, output))
    # Reject mismatched shapes rather than letting numpy broadcast them silently.
    if im.shape != images.shape[1:]:
      raise ValueError('image #{:} of {:} has shape {:}, expected {:}'.format(
          ii, output, im.shape, images.shape[1:]))
    images[ii] = im

  # Write to a temporary file and rename it into place. An interrupted run
  # then never leaves a truncated archive at ``path``; such an archive would
  # pass a plain existence check and be treated as complete.
  target = str(path)
  if not target.endswith('.npz'):
    target += '.npz'
  tmp = target + '.tmp'
  with open(tmp, 'wb') as f:
    np.savez(f, images=images)
  os.replace(tmp, target)