# coding=utf-8
import numpy as np
import torch


class MetaBatchSampler(object):
    """Episodic batch sampler that yields index batches for few-shot training.

    Each iteration (episode) draws ``classes_per_it`` distinct classes at
    random and, for each of them, ``num_samples`` distinct sample indices.
    The resulting ``classes_per_it * num_samples`` indices are shuffled and
    yielded as a 1-D ``torch.long`` tensor.

    Attributes set by ``__init__``:
      - labels: a copy of the label container passed in.
      - classes_per_it / sample_per_class / iterations: the sampling settings.
      - classes: ``LongTensor`` with the sorted unique class labels.
      - counts: NumPy array with the number of samples of every class.
      - indexes: dict mapping a class label to a ``LongTensor`` holding the
        positions (in ``labels``) of the samples that belong to that class.
    """

    def __init__(self, labels, classes_per_it, num_samples, iterations):
        """Initialize the sampler.

        Args:
          - labels: sequence with the label of every dataset sample; sample
            indexes are inferred from the position in this sequence. It must
            provide ``.copy()`` (e.g. a NumPy array or a list). Elements may
            be plain Python ints or scalars exposing ``.item()``.
            The labels must be exactly the integers ``0 .. C-1``.
          - classes_per_it: number of random classes for each iteration;
            must be strictly smaller than the number of classes.
          - num_samples: number of samples for each iteration for each class
            (support + query).
          - iterations: number of iterations (episodes) per epoch.

        Raises:
          AssertionError: if the labels do not cover ``0 .. C-1`` or if
            ``classes_per_it`` is not smaller than the number of classes.
        """
        super(MetaBatchSampler, self).__init__()
        self.labels = labels.copy()
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.iterations = iterations

        self.classes, self.counts = np.unique(self.labels, return_counts=True)
        assert len(self.classes) == np.max(self.classes) + 1 and np.min(self.classes) == 0, \
            'class labels must cover 0..num_classes-1 without gaps'
        assert classes_per_it < len(self.classes), '{:} vs. {:}'.format(classes_per_it, len(self.classes))
        self.classes = torch.LongTensor(self.classes)

        # Group the sample positions by class label. The dict is seeded from
        # self.classes so every class has an entry, in sorted class order.
        grouped = {x.item(): [] for x in self.classes}
        for idx, label in enumerate(self.labels):
            # NumPy / tensor scalars are unwrapped with .item(); plain Python
            # numbers (e.g. when labels is a list of ints) are used directly.
            key = label.item() if hasattr(label, 'item') else label
            grouped[key].append(idx)
        self.indexes = {key: torch.LongTensor(value) for key, value in grouped.items()}

    def __iter__(self):
        """Yield ``iterations`` shuffled batches of sample indexes.

        Every batch is a 1-D ``torch.long`` tensor with
        ``sample_per_class * classes_per_it`` entries.

        Raises:
          AssertionError: if a selected class does not have strictly more
            than ``sample_per_class`` samples, or if ``classes_per_it`` is
            not smaller than the number of classes.
        """
        spc = self.sample_per_class
        cpi = self.classes_per_it
        batch_size = spc * cpi

        for _ in range(self.iterations):
            # Every slot is overwritten below, so no initialization is needed.
            batch = torch.empty(batch_size, dtype=torch.long)
            assert cpi < len(self.classes), '{:} vs. {:}'.format(cpi, len(self.classes))
            # Pick cpi distinct classes for this episode.
            c_idxs = torch.randperm(len(self.classes))[:cpi]
            for i, cls in enumerate(self.classes[c_idxs]):
                class_indexes = self.indexes[cls.item()]
                num = class_indexes.nelement()
                assert spc < num, '{:} vs. {:}'.format(spc, num)
                # Pick spc distinct samples of this class.
                sample_idxs = torch.randperm(num)[:spc]
                batch[i * spc:(i + 1) * spc] = class_indexes[sample_idxs]
            # Shuffle so that samples of the same class are not contiguous.
            yield batch[torch.randperm(len(batch))]

    def __len__(self):
        """Return the number of iterations (episodes) per epoch."""
        return self.iterations