# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from os import path as osp
from copy import deepcopy as copy
from tqdm import tqdm
import warnings, time, random, numpy as np

from pts_utils import generate_label_map
from xvision import denormalize_points
from xvision import identity2affine, solve2theta, affine2image
from .dataset_utils import pil_loader
# apply_affine2point and apply_boundary are used in __process_affine below but
# were never imported; they are assumed to live alongside PointMeta2V.
from .landmark_utils import PointMeta2V, apply_affine2point, apply_boundary
from .augmentation_utils import CutOut
import torch
import torch.utils.data as data


class LandmarkDataset(data.Dataset):
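    """A dataset pairing face images with landmark annotations: __getitem__
    applies an affine transform and returns the warped image, per-point
    heatmaps, a visibility mask, normalized points, and affine parameters."""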
    def __init__(
        self,
        transform,
        sigma,
        downsample,
        heatmap_type,
        shape,
        use_gray,
        mean_file,
        data_indicator,
        cache_images=None,
    ):

        self.transform = transform
        self.sigma = sigma
        self.downsample = downsample
        self.heatmap_type = heatmap_type
        self.dataset_name = data_indicator
        self.shape = shape  # [H,W]
        self.use_gray = use_gray
        assert transform is not None, "transform : {:}".format(transform)
        self.mean_file = mean_file
        if mean_file is None:
            self.mean_data = None
            warnings.warn("LandmarkDataset initialized with mean_data = None")
        else:
            assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file)
            self.mean_data = torch.load(mean_file)
        self.reset()
        self.cutout = None
        self.cache_images = cache_images
        print("General dataset initialization done : {:}".format(self))
        warnings.simplefilter("once")

    def __repr__(self):
        return "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def set_cutout(self, length):
        if length is not None and length >= 1:
            self.cutout = CutOut(int(length))
        else:
            self.cutout = None

    def reset(self, num_pts=-1, boxid="default", only_pts=False):
        self.NUM_PTS = num_pts
        if only_pts:
            return
        self.length = 0
        self.datas = []
        self.labels = []
        self.NormDistances = []
        self.BOXID = boxid
        if self.mean_data is None:
            self.mean_face = None
        else:
            # mean-face points are stored in normalized [-1, 1] coordinates
            self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
            assert (self.mean_face >= -1).all() and (
                self.mean_face <= 1
            ).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face)
        # assert self.dataset_name is not None, 'The dataset name is None'

    def __len__(self):
        assert len(self.datas) == self.length, "The length is not correct : {}".format(
            self.length
        )
        return self.length

    def append(self, data, label, distance):
        assert osp.isfile(data), "The image path is not a file : {:}".format(data)
        self.datas.append(data)
        self.labels.append(label)
        self.NormDistances.append(distance)
        self.length = self.length + 1

    def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
        if reset:
            self.reset(num_pts, boxindicator)
        else:
            assert (
                self.NUM_PTS == num_pts and self.BOXID == boxindicator
            ), "The number of points is inconsistent : {:} vs {:}".format(
                self.NUM_PTS, num_pts
            )
        if isinstance(file_lists, str):
            file_lists = [file_lists]
        samples = []
        for idx, file_path in enumerate(file_lists):
            print(
                ":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path)
            )
            xdata = torch.load(file_path)
            if isinstance(xdata, list):
                data = xdata  # image or video dataset list
            elif isinstance(xdata, dict):
                data = xdata["datas"]  # multi-view dataset list
            else:
                raise ValueError("Invalid Type Error : {:}".format(type(xdata)))
            samples = samples + data
        # samples is a list of annotations; each annotation is a dict that
        # contains 'points' (3, num_pts) and one box entry per box indicator
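        # A hypothetical example of one annotation (key names follow the code
        # below; the values shown are purely illustrative):
        #   {'current_frame'      : '/data/images/0001.png',
        #    'points'             : np.ndarray of shape (3, num_pts),  # x, y, visibility
        #    'box-default'        : [x1, y1, x2, y2],
        #    'normalizeL-default' : inter-ocular distance}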
        print("GeneralDataset-V2 : {:} samples".format(len(samples)))

        # for index, annotation in enumerate(samples):
        for index in tqdm(range(len(samples))):
            annotation = samples[index]
            image_path = annotation["current_frame"]
            points, box = (
                annotation["points"],
                annotation["box-{:}".format(boxindicator)],
            )
            label = PointMeta2V(
                self.NUM_PTS, points, box, image_path, self.dataset_name
            )
            if normalizeL is None:
                normDistance = None
            else:
                normDistance = annotation["normalizeL-{:}".format(normalizeL)]
            self.append(image_path, label, normDistance)

        assert (
            len(self.datas) == self.length
        ), "The length and the data do not match : {} vs {}".format(
            self.length, len(self.datas)
        )
        assert (
            len(self.labels) == self.length
        ), "The length and the labels do not match : {} vs {}".format(
            self.length, len(self.labels)
        )
        assert (
            len(self.NormDistances) == self.length
        ), "The length and the NormDistances do not match : {} vs {}".format(
            self.length, len(self.NormDistances)
        )
        print(
            "Load data done for LandmarkDataset, which has {:} images.".format(
                self.length
            )
        )

    def __getitem__(self, index):
        assert index >= 0 and index < self.length, "Invalid index : {:}".format(index)
        if self.cache_images is not None and self.datas[index] in self.cache_images:
            image = self.cache_images[self.datas[index]].clone()
        else:
            image = pil_loader(self.datas[index], self.use_gray)
        target = self.labels[index].copy()
        return self._process_(image, target, index)

    def _process_(self, image, target, index):

        # transform the image and points
        image, target, theta = self.transform(image, target)
        (C, H, W), (height, width) = image.size(), self.shape

        # obtain the visible indicator vector
        if target.is_none():
            nopoints = True
        else:
            nopoints = False
        if index == -1:
            __path = None
        else:
            __path = self.datas[index]
        if isinstance(theta, (list, tuple)):
            affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
                [],
                [],
                [],
                [],
                [],
                [],
            )
            for _theta in theta:
                # note: the unpacking below rebinds _theta to the (cloned)
                # theta returned by __process_affine
                (
                    _affineImage,
                    _heatmaps,
                    _mask,
                    _norm_trans_points,
                    _theta,
                    _transpose_theta,
                ) = self.__process_affine(
                    image, target, _theta, nopoints, "P[{:}]@{:}".format(index, __path)
                )
                affineImage.append(_affineImage)
                heatmaps.append(_heatmaps)
                mask.append(_mask)
                norm_trans_points.append(_norm_trans_points)
                THETA.append(_theta)
                transpose_theta.append(_transpose_theta)
            affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
                torch.stack(affineImage),
                torch.stack(heatmaps),
                torch.stack(mask),
                torch.stack(norm_trans_points),
                torch.stack(THETA),
                torch.stack(transpose_theta),
            )
        else:
            (
                affineImage,
                heatmaps,
                mask,
                norm_trans_points,
                THETA,
                transpose_theta,
            ) = self.__process_affine(
                image, target, theta, nopoints, "S[{:}]@{:}".format(index, __path)
            )

        torch_index = torch.IntTensor([index])
        torch_nopoints = torch.ByteTensor([nopoints])
        torch_shape = torch.IntTensor([H, W])

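        # Return order (each entry gains a leading stack dimension when theta
        # is a list of affine parameters):
        #   affineImage       - warped input image(s)
        #   heatmaps          - (NUM_PTS + 1) target maps at 1/downsample resolution
        #   mask              - per-point visibility mask for the heatmaps
        #   norm_trans_points - 3 x NUM_PTS points in normalized coordinates
        #   THETA             - the affine parameters actually applied
        #   transpose_theta   - affine aligning points to the mean face (or identity)
        #   torch_index, torch_nopoints, torch_shape - bookkeeping tensors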
        return (
            affineImage,
            heatmaps,
            mask,
            norm_trans_points,
            THETA,
            transpose_theta,
            torch_index,
            torch_nopoints,
            torch_shape,
        )

    def __process_affine(self, image, target, theta, nopoints, aux_info=None):
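        # Warp the image by theta and build the heatmap / mask / point targets
        # in the warped frame; without an annotation, return zero heatmaps, a
        # full mask, and an identity transpose_theta.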
        image, target, theta = image.clone(), target.copy(), theta.clone()
        (C, H, W), (height, width) = image.size(), self.shape
        if nopoints:  # no label is available
            norm_trans_points = torch.zeros((3, self.NUM_PTS))
            heatmaps = torch.zeros(
                (self.NUM_PTS + 1, height // self.downsample, width // self.downsample)
            )
            mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
            transpose_theta = identity2affine(False)
        else:
            norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W))
            norm_trans_points = apply_boundary(norm_trans_points)
            real_trans_points = norm_trans_points.clone()
            real_trans_points[:2, :] = denormalize_points(
                self.shape, real_trans_points[:2, :]
            )
            heatmaps, mask = generate_label_map(
                real_trans_points.numpy(),
                height // self.downsample,
                width // self.downsample,
                self.sigma,
                self.downsample,
                nopoints,
                self.heatmap_type,
            )  # H*W*C
            heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(
                torch.FloatTensor
            )
            mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
            if self.mean_face is None:
                # warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
                transpose_theta = identity2affine(False)
            else:
                if torch.sum(norm_trans_points[2, :] == 1) < 3:
                    warnings.warn(
                        "In LandmarkDataset after transformation, no visible point; using identity instead. Aux: {:}".format(
                            aux_info
                        )
                    )
                    transpose_theta = identity2affine(False)
                else:
                    transpose_theta = solve2theta(
                        norm_trans_points, self.mean_face.clone()
                    )

        affineImage = affine2image(image, theta, self.shape)
        if self.cutout is not None:
            affineImage = self.cutout(affineImage)

        return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta
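

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the transform, paths, and parameter
# values below are hypothetical placeholders, not part of this module):
#
#   import torch.utils.data
#
#   dataset = LandmarkDataset(
#       transform=my_affine_transform,   # must return (image, target, theta)
#       sigma=4,
#       downsample=8,
#       heatmap_type="gaussian",
#       shape=(256, 256),                # [H, W] of the warped crop
#       use_gray=False,
#       mean_file=None,                  # or a .pth file with mean-face points
#       data_indicator="300W-68",
#   )
#   dataset.load_list(
#       "./lists/train.pth", num_pts=68,
#       boxindicator="default", normalizeL=None, reset=True,
#   )
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)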