# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from os import path as osp
from copy import deepcopy as copy
from tqdm import tqdm
import warnings, time, random, numpy as np

from pts_utils import generate_label_map
from xvision import denormalize_points
from xvision import identity2affine, solve2theta, affine2image
from .dataset_utils import pil_loader
from .landmark_utils import PointMeta2V
# apply_affine2point and apply_boundary are used in __process_affine below;
# they are assumed to live in the same module as PointMeta2V.
from .landmark_utils import apply_affine2point, apply_boundary
from .augmentation_utils import CutOut
import torch
import torch.utils.data as data


class LandmarkDataset(data.Dataset):

  def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None):
    self.transform = transform
    self.sigma = sigma
    self.downsample = downsample
    self.heatmap_type = heatmap_type
    self.dataset_name = data_indicator
    self.shape = shape  # [H, W]
    self.use_gray = use_gray
    assert transform is not None, 'transform : {:}'.format(transform)
    self.mean_file = mean_file
    if mean_file is None:
      self.mean_data = None
      warnings.warn('LandmarkDataset initialized with mean_data = None')
    else:
      assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file)
      self.mean_data = torch.load(mean_file)
    self.reset()
    self.cutout = None
    self.cache_images = cache_images
    print('The general dataset initialization is done : {:}'.format(self))
    warnings.simplefilter('once')

  def __repr__(self):
    return ('{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'.format(name=self.__class__.__name__, **self.__dict__))

  def set_cutout(self, length):
    if length is not None and length >= 1:
      self.cutout = CutOut(int(length))
    else:
      self.cutout = None

  def reset(self, num_pts=-1, boxid='default', only_pts=False):
    self.NUM_PTS = num_pts
    if only_pts:
      return
    self.length = 0
    self.datas = []
    self.labels = []
    self.NormDistances = []
    self.BOXID = boxid
    if self.mean_data is None:
      self.mean_face = None
    else:
      # mean face landmarks, with coordinates normalized to [-1, 1]
      self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
      assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face)
    #assert self.dataset_name is not None, 'The dataset name is None'

  def __len__(self):
    assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length)
    return self.length

  def append(self, data, label, distance):
    assert osp.isfile(data), 'The image path is not a file : {:}'.format(data)
    self.datas.append(data)
    self.labels.append(label)
    self.NormDistances.append(distance)
    self.length = self.length + 1

  def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
    if reset:
      self.reset(num_pts, boxindicator)
    else:
      assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of points is inconsistent : {:} vs {:}'.format(self.NUM_PTS, num_pts)
    if isinstance(file_lists, str):
      file_lists = [file_lists]
    samples = []
    for idx, file_path in enumerate(file_lists):
      print(':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path))
      xdata = torch.load(file_path)
      if isinstance(xdata, list):
        data = xdata           # image or video dataset list
      elif isinstance(xdata, dict):
        data = xdata['datas']  # multi-view dataset list
      else:
        raise ValueError('Invalid Type Error : {:}'.format(type(xdata)))
      samples = samples + data
    # samples is a list; each element is an annotation dict that contains the
    # current frame path, 'points' with shape (3, num_pts), and various boxes.
    print('GeneralDataset-V2 : {:} samples'.format(len(samples)))

    #for index, annotation in enumerate(samples):
    for index in tqdm(range(len(samples))):
      annotation = samples[index]
      image_path = annotation['current_frame']
      points, box = annotation['points'], annotation['box-{:}'.format(boxindicator)]
      label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name)
      if normalizeL is None:
        normDistance = None
      else:
        normDistance = annotation['normalizeL-{:}'.format(normalizeL)]
      self.append(image_path, label, normDistance)

    assert len(self.datas) == self.length, 'The length and the data do not match : {} vs {}'.format(self.length, len(self.datas))
    assert len(self.labels) == self.length, 'The length and the labels do not match : {} vs {}'.format(self.length, len(self.labels))
    assert len(self.NormDistances) == self.length, 'The length and the NormDistances do not match : {} vs {}'.format(self.length, len(self.NormDistances))
    print('Load data done for LandmarkDataset, which has {:} images.'.format(self.length))

  def __getitem__(self, index):
    assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index)
    # use the cached decoded image when available, otherwise load it from disk
    if self.cache_images is not None and self.datas[index] in self.cache_images:
      image = self.cache_images[self.datas[index]].clone()
    else:
      image = pil_loader(self.datas[index], self.use_gray)
    target = self.labels[index].copy()
    return self._process_(image, target, index)

  def _process_(self, image, target, index):
    # transform the image and points
    image, target, theta = self.transform(image, target)
    (C, H, W), (height, width) = image.size(), self.shape

    # determine whether this sample carries landmark annotations
    if target.is_none():
      nopoints = True
    else:
      nopoints = False
    if index == -1:
      __path = None
    else:
      __path = self.datas[index]

    if isinstance(theta, list) or isinstance(theta, tuple):
      # a list/tuple of thetas: process each one and stack the results
      affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = [], [], [], [], [], []
      for _theta in theta:
        _affineImage, _heatmaps, _mask, _norm_trans_points, _theta, _transpose_theta \
            = self.__process_affine(image, target, _theta, nopoints, 'P[{:}]@{:}'.format(index, __path))
        affineImage.append(_affineImage)
        heatmaps.append(_heatmaps)
        mask.append(_mask)
        norm_trans_points.append(_norm_trans_points)
        THETA.append(_theta)
        transpose_theta.append(_transpose_theta)
      affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \
          torch.stack(affineImage), torch.stack(heatmaps), torch.stack(mask), torch.stack(norm_trans_points), torch.stack(THETA), torch.stack(transpose_theta)
    else:
      affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, __path))

    torch_index = torch.IntTensor([index])
    torch_nopoints = torch.ByteTensor([nopoints])
    torch_shape = torch.IntTensor([H, W])

    return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape

  def __process_affine(self, image, target, theta, nopoints, aux_info=None):
    image, target, theta = image.clone(), target.copy(), theta.clone()
    (C, H, W), (height, width) = image.size(), self.shape
    if nopoints:  # this sample has no landmark annotation
      norm_trans_points = torch.zeros((3, self.NUM_PTS))
      heatmaps = torch.zeros((self.NUM_PTS + 1, height // self.downsample, width // self.downsample))
      mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
      transpose_theta = identity2affine(False)
    else:
      norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W))
      norm_trans_points = apply_boundary(norm_trans_points)
      real_trans_points = norm_trans_points.clone()
      real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2, :])
      heatmaps, mask = generate_label_map(real_trans_points.numpy(), height // self.downsample, width // self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type)  # numpy arrays in H*W*C layout
      heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor)
      mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
      if self.mean_face is None:
        #warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
        transpose_theta = identity2affine(False)
      else:
        if torch.sum(norm_trans_points[2, :] == 1) < 3:
          warnings.warn('In LandmarkDataset after transformation, no visible point, using identity instead. Aux: {:}'.format(aux_info))
          transpose_theta = identity2affine(False)
        else:
          transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone())

    affineImage = affine2image(image, theta, self.shape)
    if self.cutout is not None:
      affineImage = self.cutout(affineImage)

    return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta
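
if __name__ == '__main__':
  # ---------------------------------------------------------------------------
  # Minimal smoke-test sketch (illustration only; not part of the original
  # file). The values below (sigma, downsample, heatmap_type, shape, the
  # 'demo-68pts' indicator) are placeholder assumptions, and _dummy_transform
  # only documents the contract that _process_ expects: the real pipeline must
  # return (tensor_image, target, theta).
  def _dummy_transform(image, target):
    raise NotImplementedError('replace with the project transform pipeline')

  dataset = LandmarkDataset(transform=_dummy_transform, sigma=4, downsample=8,
                            heatmap_type='gaussian', shape=(256, 256),
                            use_gray=False, mean_file=None,
                            data_indicator='demo-68pts')
  print(dataset)
  # Loading a cached annotation list would then look roughly like this (the
  # path and the .pth layout consumed by load_list are assumptions):
  #   dataset.load_list('./cache_data/train-list.pth', 68, 'default', None, True)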