# coding=utf-8
import numpy as np
import torch


class MetaBatchSampler(object):

  def __init__(self, labels, classes_per_it, num_samples, iterations):
    '''
    Initialize MetaBatchSampler.
    Args:
    - labels: an iterable (e.g., a numpy array) containing all the labels of the
      current dataset; sample indexes will be inferred from this iterable.
    - classes_per_it: number of random classes drawn for each iteration
    - num_samples: number of samples per class for each iteration (support + query)
    - iterations: number of iterations (episodes) per epoch
    '''
    super(MetaBatchSampler, self).__init__()
    self.labels = labels.copy()
    self.classes_per_it = classes_per_it
    self.sample_per_class = num_samples
    self.iterations = iterations

    self.classes, self.counts = np.unique(self.labels, return_counts=True)
    # labels are expected to be consecutive integers starting from 0
    assert len(self.classes) == np.max(self.classes) + 1 and np.min(self.classes) == 0
    assert classes_per_it < len(self.classes), '{:} vs. {:}'.format(classes_per_it, len(self.classes))
    self.classes = torch.LongTensor(self.classes)

    # build a lookup table: for every class c, self.indexes[c] is a LongTensor
    # holding the dataset indices of all samples that belong to class c
    indexes = {x.item(): [] for x in self.classes}
    for idx, label in enumerate(self.labels):
      indexes[label.item()].append(idx)
    self.indexes = {key: torch.LongTensor(value) for key, value in indexes.items()}
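    # Illustrative example (not from the original file): for labels
    # np.array([0, 0, 1, 2, 2, 2]), the lookup table above would be
    #   {0: LongTensor([0, 1]), 1: LongTensor([2]), 2: LongTensor([3, 4, 5])}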

  def __iter__(self):
    # yield a batch of indexes (one few-shot episode) at each iteration
    spc = self.sample_per_class
    cpi = self.classes_per_it

    for it in range(self.iterations):
      batch_size = spc * cpi
      batch = torch.LongTensor(batch_size)
      assert cpi < len(self.classes), '{:} vs. {:}'.format(cpi, len(self.classes))
      # draw cpi distinct classes at random for this episode
      c_idxs = torch.randperm(len(self.classes))[:cpi]
      for i, cls in enumerate(self.classes[c_idxs]):
        s = slice(i * spc, (i + 1) * spc)
        num = self.indexes[cls.item()].nelement()
        assert spc < num, '{:} vs. {:}'.format(spc, num)
        # draw spc distinct samples of class cls and place them in the batch
        sample_idxs = torch.randperm(num)[:spc]
        batch[s] = self.indexes[cls.item()][sample_idxs]

      # shuffle so that samples of the same class are not contiguous
      batch = batch[torch.randperm(len(batch))]
      yield batch

  def __len__(self):
    # returns the number of iterations (episodes) per epoch
    return self.iterations
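

# Usage sketch (illustrative, not part of the original file): the sampler is
# intended to be passed to a torch DataLoader through its `batch_sampler`
# argument. The toy dataset below (10 classes, 20 samples each, labels given
# as a numpy array, which is what the constructor expects) is an assumption
# for demonstration purposes only.
if __name__ == '__main__':
  import torch.utils.data

  labels = np.repeat(np.arange(10), 20)  # hypothetical labels: classes 0..9
  data = torch.randn(len(labels), 3)     # hypothetical 3-dim features
  dataset = torch.utils.data.TensorDataset(data, torch.LongTensor(labels))

  # 5-way episodes with 4 samples per class (e.g. 1 support + 3 query),
  # 100 episodes per epoch
  sampler = MetaBatchSampler(labels, classes_per_it=5, num_samples=4, iterations=100)
  loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)

  for features, targets in loader:
    assert features.shape[0] == 5 * 4  # each batch contains exactly one episode
    break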