220 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			220 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Facebook, Inc. and its affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the license found in the
 | |
| # LICENSE file in the root directory of this source tree.
 | |
| #
 | |
| import copy, math, torch, numpy as np
 | |
| from xvision import normalize_points
 | |
| from xvision import denormalize_points
 | |
| 
 | |
| 
 | |
class PointMeta:
    """Landmark annotation for a single image.

    Attributes:
        num_point: number of landmarks.
        points: ``3 x num_point`` float tensor whose rows are
            (x, y, indicator); row 2 is the visibility/occlusion indicator
            (a point counts as visible when it is > 0).  ``None`` when the
            image carries no landmark annotation.
        box: 4-element float tensor ``[x1, y1, x2, y2]`` or ``None``.
        image_path: path of the annotated image.
        datasets: dataset name; its prefix selects the point-reordering
            table used by ``apply_horizontal_flip``.
    """

    # 1-based index of the mirrored counterpart of each of the 68 face
    # landmarks (assumes the 300W 68-point ordering), converted to 0-based.
    # Every entry pairs i <-> j with j -> i, i.e. the table is an involution.
    _FACE68_MIRROR = (
        np.array(
            [
                17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
                27, 26, 25, 24, 23, 22, 21, 20, 19, 18,
                28, 29, 30, 31,
                36, 35, 34, 33, 32,
                46, 45, 44, 43, 48, 47,
                40, 39, 38, 37, 42, 41,
                55, 54, 53, 52, 51, 50, 49,
                60, 59, 58, 57, 56,
                65, 64, 63, 62, 61,
                68, 67, 66,
            ]
        )
        - 1
    )

    # For 42-point hand data the first 21 and last 21 points swap places
    # (presumably the two hands -- TODO confirm).  Also an involution.
    _HAND_MIRROR = np.array(list(range(21, 42)) + list(range(0, 21)))

    # 0-based indices dropped by the "68to49" conversion (68 - 19 = 49).
    _DROP_FOR_49 = tuple(range(17)) + (60, 64)

    def __init__(self, num_point, points, box, image_path, dataset_name):
        """Store a copy of the annotation.

        Args:
            num_point: expected number of landmarks.
            points: ``None`` or a 2-D array-like with ``.shape`` and
                ``.copy()`` (e.g. a NumPy array) of shape ``(3, num_point)``.
            box: ``None`` or a 4-element tuple/list ``[x1, y1, x2, y2]``.
            image_path: path of the image.
            dataset_name: dataset identifier, stored as ``self.datasets``.
        """
        self.num_point = num_point
        if box is None:
            self.box = None
        else:
            assert isinstance(box, (tuple, list)) and len(box) == 4
            self.box = torch.Tensor(box)
        if points is None:
            self.points = None
        else:
            assert (
                len(points.shape) == 2
                and points.shape[0] == 3
                and points.shape[1] == self.num_point
            ), "The shape of point is not right : {}".format(points.shape)
            # Copy so later in-place edits never alias the caller's array.
            self.points = torch.Tensor(points.copy())
        self.image_path = image_path
        self.datasets = dataset_name

    def __repr__(self):
        if self.box is None:
            boxstr = "None"
        else:
            boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist())
        return "{:}(points={:}, {:})".format(
            type(self).__name__, self.num_point, boxstr
        )

    def get_box(self, return_diagonal=False):
        """Return a copy of the box, its diagonal length, or ``None`` if unset."""
        if self.box is None:
            return None
        if not return_diagonal:
            return self.box.clone()
        W = (self.box[2] - self.box[0]).item()
        H = (self.box[3] - self.box[1]).item()
        return math.sqrt(H * H + W * W)

    def get_points(self, ignore_indicator=False):
        """Return a copy of the points.

        Args:
            ignore_indicator: if True, return only the (x, y) rows.

        Returns:
            ``2 x num_point`` or ``3 x num_point`` tensor; all zeros when no
            points are annotated.
        """
        last = 2 if ignore_indicator else 3
        if self.points is None:
            return torch.zeros((last, self.num_point))
        return self.points[:last, :].clone()

    def is_none(self):
        """Return True when the image has no landmark annotation."""
        return self.points is None

    def copy(self):
        """Return a deep copy of this annotation."""
        return copy.deepcopy(self)

    def visiable_pts_num(self):
        """Return the number of points with a positive indicator (0 if none annotated)."""
        if self.points is None:
            return 0
        with torch.no_grad():
            return int((self.points[2, :] > 0).sum().item())

    def special_fun(self, indicator):
        """Apply a named in-place conversion of the landmark set.

        Supported indicators:
            "68to49": keep 49 of 68 points by dropping indices 0-16, 60 and 64
                (used for 300W / 300VW).

        Raises:
            ValueError: for an unknown indicator.
            AssertionError: if "68to49" is requested on non-68-point data.
        """
        if indicator != "68to49":
            raise ValueError("Invalid indicator : {:}".format(indicator))
        assert self.num_point == 68, "num-point must be 68 vs. {:}".format(
            self.num_point
        )
        self.num_point = 49
        # Boolean mask: uint8 masks are deprecated for indexing in PyTorch.
        keep = torch.ones(68, dtype=torch.bool)
        keep[list(self._DROP_FOR_49)] = False
        if self.points is not None:
            # Boolean-mask indexing already returns a new tensor.
            self.points = self.points[:, keep]

    def apply_horizontal_flip(self):
        """Reorder points in place so each takes the slot of its mirror twin.

        Only the point order changes; x-coordinates are NOT reflected here
        (presumably the caller mirrors them, e.g. ``width - x - 1`` -- TODO
        confirm).  A no-op when no points are annotated.

        Raises:
            ValueError: if ``self.datasets`` has no known reordering table.
        """
        if self.datasets.startswith("HandsyROT"):
            mirror = self._HAND_MIRROR
        elif self.datasets.startswith("face68"):
            mirror = self._FACE68_MIRROR
        else:
            raise ValueError("Does not support {:}".format(self.datasets))
        if self.points is None:
            return
        # Both tables are involutions, so a gather with the table equals the
        # scatter the two branches used to perform.  The RHS is a copy, so
        # the in-place write is safe.
        self.points[:, :] = self.points[:, mirror]
| 
 | |
| 
 | |
def apply_affine2point(points, theta, shape):
    """Map the visible points through ``theta`` in normalized coordinates.

    Each visible column ``p`` (x and y normalized via ``normalize_points``)
    is replaced by the solution ``q`` of ``theta @ q = p``, i.e. it is
    multiplied by the inverse of ``theta``.

    Args:
        points: ``3 x N`` tensor with rows (x, y, indicator); only columns
            whose indicator equals 1 are transformed.  Not modified.
        theta: square matrix compatible with the 3-row points (presumably a
            3 x 3 affine matrix -- TODO confirm).
        shape: image size ``(H, W)`` forwarded to ``normalize_points``.

    Returns:
        ``3 x N`` tensor: solved columns for the visible points and all-zero
        columns for the others.

    Raises:
        AssertionError: if ``points`` does not have 3 rows or no point has
            indicator 1.
    """
    assert points.size(0) == 3, "invalid points shape : {:}".format(points.size())
    with torch.no_grad():
        ok_points = points[2, :] == 1
        assert torch.sum(ok_points).item() > 0, "there is no visiable point"
        # Normalize a copy so the caller's tensor is left untouched.
        points = points.clone()
        points[:2, :] = normalize_points(shape, points[:2, :])

        # Float mask: ones for visible columns (overwritten below), zeros elsewhere.
        norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()

        visible = points[:, ok_points]
        # torch.gesv was removed from PyTorch; pick the solver this build has.
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "solve"):
            trans_points = linalg.solve(theta, visible)
        elif hasattr(torch, "solve"):
            trans_points, _ = torch.solve(visible, theta)
        else:
            trans_points, _ = torch.gesv(visible, theta)

        norm_trans_points[:, ok_points] = trans_points

    return norm_trans_points
 | |
| 
 | |
| 
 | |
def apply_boundary(norm_trans_points):
    """Return a copy whose indicator row marks points inside the normalized frame.

    A column keeps indicator 1 only when its x (row 0) and y (row 1) lie
    strictly inside ``(-1, 1)`` and its incoming indicator (row 2) is
    positive; every other column gets indicator 0.  The other rows are
    copied unchanged and the input tensor is not modified.
    """
    with torch.no_grad():
        bounded = norm_trans_points.clone()
        x_row, y_row, flag_row = bounded[0], bounded[1], bounded[2]
        inside_x = (x_row > -1) & (x_row < 1)
        inside_y = (y_row > -1) & (y_row < 1)
        keep = inside_x & inside_y & (flag_row > 0)
        # Bool values are cast to the tensor's dtype (1 / 0) on assignment.
        bounded[2, :] = keep
    return bounded
 |