# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from os import path as osp
from copy import deepcopy as copy
from tqdm import tqdm
import warnings
import time
import random
import numpy as np

from pts_utils import generate_label_map
from xvision import denormalize_points
from xvision import identity2affine, solve2theta, affine2image
from .dataset_utils import pil_loader
from .landmark_utils import PointMeta2V
from .augmentation_utils import CutOut
import torch
import torch.utils.data as data


class LandmarkDataset(data.Dataset):
    """Dataset that yields affine-warped images with landmark heatmaps.

    Samples are registered through `load_list` (or `append`).  Each item is
    produced by running the user `transform` on the image and its label and
    then warping the image with the resulting affine parameter(s).
    """

    def __init__(
        self,
        transform,
        sigma,
        downsample,
        heatmap_type,
        shape,
        use_gray,
        mean_file,
        data_indicator,
        cache_images=None,
    ):
        """Store the configuration and start with an empty sample list.

        Args:
            transform: callable invoked as ``transform(image, target)`` and
                expected to return ``(image, target, theta)``; must not be
                None.
            sigma: forwarded to `generate_label_map`.
            downsample: heatmaps are generated at ``shape // downsample``;
                also forwarded to `generate_label_map`.
            heatmap_type: forwarded to `generate_label_map`.
            shape: ``[H, W]`` of the warped output image.
            use_gray: forwarded to `pil_loader`.
            mean_file: path of a file readable by `torch.load`, or None to
                run without a mean face.
            data_indicator: stored as ``dataset_name`` and handed to every
                `PointMeta2V` label.
            cache_images: optional container indexed by image path whose
                values provide ``.clone()``; used instead of reading the
                image from disk when the path is present.
        """
        self.transform = transform
        self.sigma = sigma
        self.downsample = downsample
        self.heatmap_type = heatmap_type
        self.dataset_name = data_indicator
        self.shape = shape  # [H,W]
        self.use_gray = use_gray
        assert transform is not None, "transform : {:}".format(transform)
        self.mean_file = mean_file
        if mean_file is None:
            self.mean_data = None
            warnings.warn("LandmarkDataset initialized with mean_data = None")
        else:
            assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file)
            self.mean_data = torch.load(mean_file)
        # reset() and the cutout assignment must precede the print below,
        # because __repr__ formats attributes they create.
        self.reset()
        self.cutout = None
        self.cache_images = cache_images
        print("The general dataset initialization done : {:}".format(self))
        warnings.simplefilter("once")

    def __repr__(self):
        """Return a one-line summary built from the instance attributes."""
        return "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def set_cutout(self, length):
        """Enable CutOut with the given length, or disable it.

        CutOut is disabled when `length` is None or smaller than 1.
        """
        if length is not None and length >= 1:
            self.cutout = CutOut(int(length))
        else:
            self.cutout = None

    def reset(self, num_pts=-1, boxid="default", only_pts=False):
        """Set the number of points and, unless `only_pts`, clear all samples.

        Args:
            num_pts: new value of ``NUM_PTS``.
            boxid: key used to pick the mean face out of the loaded mean data.
            only_pts: when True only ``NUM_PTS`` is updated and the stored
                samples, box id and mean face are left untouched.
        """
        self.NUM_PTS = num_pts
        if only_pts:
            return
        self.length = 0
        self.datas = []
        self.labels = []
        self.NormDistances = []
        self.BOXID = boxid
        if self.mean_data is None:
            self.mean_face = None
        else:
            # presumably mean_data[boxid] is a numpy array — TODO confirm
            self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
            assert (self.mean_face >= -1).all() and (
                self.mean_face <= 1
            ).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face)

    def __len__(self):
        """Return the number of registered samples."""
        assert len(self.datas) == self.length, "The length is not correct : {}".format(
            self.length
        )
        return self.length

    def append(self, data, label, distance):
        """Register one sample.

        Args:
            data: path of the image; must be an existing file.
            label: label object stored for this image.
            distance: normalization distance stored for this image
                (may be None).
        """
        assert osp.isfile(data), "The image path is not a file : {:}".format(data)
        self.datas.append(data)
        self.labels.append(label)
        self.NormDistances.append(distance)
        self.length = self.length + 1

    def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
        """Load annotation list file(s) and register every sample they hold.

        Args:
            file_lists: a path or a list of paths readable by `torch.load`.
                Each file must hold either a list of annotations or a dict
                whose ``"datas"`` entry holds them.
            num_pts: number of landmark points.
            boxindicator: suffix of the ``"box-<boxindicator>"`` annotation
                key to use; also the box id passed to `reset`.
            normalizeL: suffix of the ``"normalizeL-<normalizeL>"`` annotation
                key, or None to store no normalization distance.
            reset: when True, drop the current samples first; otherwise the
                current ``NUM_PTS`` and ``BOXID`` must match the arguments.

        Raises:
            ValueError: if a loaded file is neither a list nor a dict.
        """
        if reset:
            self.reset(num_pts, boxindicator)
        else:
            assert (
                self.NUM_PTS == num_pts and self.BOXID == boxindicator
            ), "The number of point is inconsistance : {:} vs {:}".format(
                self.NUM_PTS, num_pts
            )
        if isinstance(file_lists, str):
            file_lists = [file_lists]
        samples = []
        for idx, file_path in enumerate(file_lists):
            print(
                ":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path)
            )
            # NOTE(review): torch.load unpickles the file; only load list
            # files that come from a trusted source.
            xdata = torch.load(file_path)
            if isinstance(xdata, list):
                entries = xdata  # image or video dataset list
            elif isinstance(xdata, dict):
                entries = xdata["datas"]  # multi-view dataset list
            else:
                raise ValueError("Invalid Type Error : {:}".format(type(xdata)))
            samples.extend(entries)
        # samples is a list of annotations; each annotation is indexed below
        # by 'current_frame' (the image path), 'points' and 'box-<indicator>'.
        # 'points' is presumably of shape (3, num_pts) — TODO confirm.
        print("GeneralDataset-V2 : {:} samples".format(len(samples)))

        for annotation in tqdm(samples):
            image_path = annotation["current_frame"]
            points = annotation["points"]
            box = annotation["box-{:}".format(boxindicator)]
            label = PointMeta2V(
                self.NUM_PTS, points, box, image_path, self.dataset_name
            )
            if normalizeL is None:
                normDistance = None
            else:
                normDistance = annotation["normalizeL-{:}".format(normalizeL)]
            self.append(image_path, label, normDistance)

        assert (
            len(self.datas) == self.length
        ), "The length and the data is not right {} vs {}".format(
            self.length, len(self.datas)
        )
        assert (
            len(self.labels) == self.length
        ), "The length and the labels is not right {} vs {}".format(
            self.length, len(self.labels)
        )
        assert (
            len(self.NormDistances) == self.length
        ), "The length and the NormDistances is not right {} vs {}".format(
            self.length, len(self.NormDistances)
        )
        print(
            "Load data done for LandmarkDataset, which has {:} images.".format(
                self.length
            )
        )

    def __getitem__(self, index):
        """Load the image and label at `index` and return the processed sample.

        The image comes from `cache_images` when its path is present there,
        otherwise it is read with `pil_loader`.  See `_process_` for the
        returned tuple.
        """
        assert index >= 0 and index < self.length, "Invalid index : {:}".format(index)
        image_key = self.datas[index]
        if self.cache_images is not None and image_key in self.cache_images:
            image = self.cache_images[image_key].clone()
        else:
            image = pil_loader(image_key, self.use_gray)
        target = self.labels[index].copy()
        return self._process_(image, target, index)

    def _process_(self, image, target, index):
        """Transform an image/label pair and build the training tensors.

        Args:
            image: image handed to ``self.transform``.
            target: label handed to ``self.transform``; must provide
                ``is_none()``.
            index: position of the sample in this dataset, or -1 when the
                sample has no stored path.

        Returns:
            A 9-tuple ``(affineImage, heatmaps, mask, norm_trans_points,
            THETA, transpose_theta, torch_index, torch_nopoints,
            torch_shape)``.  When the transform returns a list or tuple of
            thetas, the first six entries are stacked over the thetas.
        """
        # transform the image and points
        image, target, theta = self.transform(image, target)
        _, H, W = image.size()
        # whether this sample comes without landmark annotations
        nopoints = bool(target.is_none())
        image_path = None if index == -1 else self.datas[index]

        if isinstance(theta, (list, tuple)):
            results = [
                self.__process_affine(
                    image,
                    target,
                    one_theta,
                    nopoints,
                    "P[{:}]@{:}".format(index, image_path),
                )
                for one_theta in theta
            ]
            # Stack each of the six outputs across all thetas.
            (
                affineImage,
                heatmaps,
                mask,
                norm_trans_points,
                THETA,
                transpose_theta,
            ) = (
                torch.stack([result[field] for result in results])
                for field in range(6)
            )
        else:
            (
                affineImage,
                heatmaps,
                mask,
                norm_trans_points,
                THETA,
                transpose_theta,
            ) = self.__process_affine(
                image, target, theta, nopoints, "S[{:}]@{:}".format(index, image_path)
            )

        torch_index = torch.IntTensor([index])
        torch_nopoints = torch.ByteTensor([nopoints])
        torch_shape = torch.IntTensor([H, W])

        return (
            affineImage,
            heatmaps,
            mask,
            norm_trans_points,
            THETA,
            transpose_theta,
            torch_index,
            torch_nopoints,
            torch_shape,
        )

    def __process_affine(self, image, target, theta, nopoints, aux_info=None):
        """Warp one image with one affine parameter and build its labels.

        Args:
            image: transformed image tensor (cloned before use).
            target: transformed label (copied before use).
            theta: affine parameter tensor (cloned before use).
            nopoints: True when the sample has no landmark annotations.
            aux_info: text appended to the warning emitted when too few
                points remain visible.

        Returns:
            ``(affineImage, heatmaps, mask, norm_trans_points, theta,
            transpose_theta)``.
        """
        image, target, theta = image.clone(), target.copy(), theta.clone()
        _, H, W = image.size()
        height, width = self.shape
        if nopoints:  # do not have label
            norm_trans_points = torch.zeros((3, self.NUM_PTS))
            heatmaps = torch.zeros(
                (self.NUM_PTS + 1, height // self.downsample, width // self.downsample)
            )
            mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
            transpose_theta = identity2affine(False)
        else:
            # NOTE(review): `apply_affine2point` and `apply_boundary` are not
            # imported anywhere in this file as shown, so this branch raises
            # NameError unless they are provided elsewhere — confirm where
            # they live (possibly `.landmark_utils`) and import them.
            norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W))
            norm_trans_points = apply_boundary(norm_trans_points)
            real_trans_points = norm_trans_points.clone()
            real_trans_points[:2, :] = denormalize_points(
                self.shape, real_trans_points[:2, :]
            )
            heatmaps, mask = generate_label_map(
                real_trans_points.numpy(),
                height // self.downsample,
                width // self.downsample,
                self.sigma,
                self.downsample,
                nopoints,
                self.heatmap_type,
            )  # H*W*C
            # move the channel axis first: H*W*C -> C*H*W
            heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(
                torch.FloatTensor
            )
            mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
            if self.mean_face is None:
                # no mean face available, fall back to the identity transform
                transpose_theta = identity2affine(False)
            elif torch.sum(norm_trans_points[2, :] == 1) < 3:
                # fewer than three entries of row 2 equal 1, so no transform
                # to the mean face is solved and the identity is used
                warnings.warn(
                    "In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}".format(
                        aux_info
                    )
                )
                transpose_theta = identity2affine(False)
            else:
                transpose_theta = solve2theta(
                    norm_trans_points, self.mean_face.clone()
                )

        affineImage = affine2image(image, theta, self.shape)
        if self.cutout is not None:
            affineImage = self.cutout(affineImage)

        return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta