mps / cpu support
This commit is contained in:
		| @@ -25,14 +25,14 @@ from cotracker.models.core.embeddings import ( | |||||||
| torch.manual_seed(0) | torch.manual_seed(0) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)): | def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'): | ||||||
|     if grid_size == 1: |     if grid_size == 1: | ||||||
|         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[ |         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[ | ||||||
|             None, None |             None, None | ||||||
|         ].cuda() |         ].to(device) | ||||||
|  |  | ||||||
|     grid_y, grid_x = meshgrid2d( |     grid_y, grid_x = meshgrid2d( | ||||||
|         1, grid_size, grid_size, stack=False, norm=False, device="cuda" |         1, grid_size, grid_size, stack=False, norm=False, device=device | ||||||
|     ) |     ) | ||||||
|     step = interp_shape[1] // 64 |     step = interp_shape[1] // 64 | ||||||
|     if grid_center[0] != 0 or grid_center[1] != 0: |     if grid_center[0] != 0 or grid_center[1] != 0: | ||||||
| @@ -47,7 +47,7 @@ def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)): | |||||||
|  |  | ||||||
|     grid_y = grid_y + grid_center[0] |     grid_y = grid_y + grid_center[0] | ||||||
|     grid_x = grid_x + grid_center[1] |     grid_x = grid_x + grid_center[1] | ||||||
|     xy = torch.stack([grid_x, grid_y], dim=-1).cuda() |     xy = torch.stack([grid_x, grid_y], dim=-1).to(device) | ||||||
|     return xy |     return xy | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import ( | |||||||
|  |  | ||||||
| class CoTrackerPredictor(torch.nn.Module): | class CoTrackerPredictor(torch.nn.Module): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth" |         self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None | ||||||
|     ): |     ): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.interp_shape = (384, 512) |         self.interp_shape = (384, 512) | ||||||
| @@ -25,7 +25,8 @@ class CoTrackerPredictor(torch.nn.Module): | |||||||
|         model = build_cotracker(checkpoint) |         model = build_cotracker(checkpoint) | ||||||
|  |  | ||||||
|         self.model = model |         self.model = model | ||||||
|         self.model.to("cuda") |         self.device = device or 'cuda' | ||||||
|  |         self.model.to(self.device) | ||||||
|         self.model.eval() |         self.model.eval() | ||||||
|  |  | ||||||
|     @torch.no_grad() |     @torch.no_grad() | ||||||
| @@ -72,7 +73,7 @@ class CoTrackerPredictor(torch.nn.Module): | |||||||
|         grid_width = W // grid_step |         grid_width = W // grid_step | ||||||
|         grid_height = H // grid_step |         grid_height = H // grid_step | ||||||
|         tracks = visibilities = None |         tracks = visibilities = None | ||||||
|         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda") |         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device) | ||||||
|         grid_pts[0, :, 0] = grid_query_frame |         grid_pts[0, :, 0] = grid_query_frame | ||||||
|         for offset in tqdm(range(grid_step * grid_step)): |         for offset in tqdm(range(grid_step * grid_step)): | ||||||
|             ox = offset % grid_step |             ox = offset % grid_step | ||||||
| @@ -107,10 +108,10 @@ class CoTrackerPredictor(torch.nn.Module): | |||||||
|         assert B == 1 |         assert B == 1 | ||||||
|  |  | ||||||
|         video = video.reshape(B * T, C, H, W) |         video = video.reshape(B * T, C, H, W) | ||||||
|         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda() |         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device) | ||||||
|         video = video.reshape( |         video = video.reshape( | ||||||
|             B, T, 3, self.interp_shape[0], self.interp_shape[1] |             B, T, 3, self.interp_shape[0], self.interp_shape[1] | ||||||
|         ).cuda() |         ).to(self.device) | ||||||
|  |  | ||||||
|         if queries is not None: |         if queries is not None: | ||||||
|             queries = queries.clone() |             queries = queries.clone() | ||||||
| @@ -119,7 +120,7 @@ class CoTrackerPredictor(torch.nn.Module): | |||||||
|             queries[:, :, 1] *= self.interp_shape[1] / W |             queries[:, :, 1] *= self.interp_shape[1] / W | ||||||
|             queries[:, :, 2] *= self.interp_shape[0] / H |             queries[:, :, 2] *= self.interp_shape[0] / H | ||||||
|         elif grid_size > 0: |         elif grid_size > 0: | ||||||
|             grid_pts = get_points_on_a_grid(grid_size, self.interp_shape) |             grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device) | ||||||
|             if segm_mask is not None: |             if segm_mask is not None: | ||||||
|                 segm_mask = F.interpolate( |                 segm_mask = F.interpolate( | ||||||
|                     segm_mask, tuple(self.interp_shape), mode="nearest" |                     segm_mask, tuple(self.interp_shape), mode="nearest" | ||||||
| @@ -136,7 +137,7 @@ class CoTrackerPredictor(torch.nn.Module): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         if add_support_grid: |         if add_support_grid: | ||||||
|             grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape) |             grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=self.device) | ||||||
|             grid_pts = torch.cat( |             grid_pts = torch.cat( | ||||||
|                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 |                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 | ||||||
|             ) |             ) | ||||||
|   | |||||||
| @@ -63,6 +63,7 @@ class Visualizer: | |||||||
|         self, |         self, | ||||||
|         video: torch.Tensor,  # (B,T,C,H,W) |         video: torch.Tensor,  # (B,T,C,H,W) | ||||||
|         tracks: torch.Tensor,  # (B,T,N,2) |         tracks: torch.Tensor,  # (B,T,N,2) | ||||||
|  |         visibility: torch.Tensor,  # (B, T, N, 1) bool | ||||||
|         gt_tracks: torch.Tensor = None,  # (B,T,N,2) |         gt_tracks: torch.Tensor = None,  # (B,T,N,2) | ||||||
|         segm_mask: torch.Tensor = None,  # (B,1,H,W) |         segm_mask: torch.Tensor = None,  # (B,1,H,W) | ||||||
|         filename: str = "video", |         filename: str = "video", | ||||||
| @@ -94,6 +95,7 @@ class Visualizer: | |||||||
|         res_video = self.draw_tracks_on_video( |         res_video = self.draw_tracks_on_video( | ||||||
|             video=video, |             video=video, | ||||||
|             tracks=tracks, |             tracks=tracks, | ||||||
|  |             visibility=visibility, | ||||||
|             segm_mask=segm_mask, |             segm_mask=segm_mask, | ||||||
|             gt_tracks=gt_tracks, |             gt_tracks=gt_tracks, | ||||||
|             query_frame=query_frame, |             query_frame=query_frame, | ||||||
| @@ -127,6 +129,7 @@ class Visualizer: | |||||||
|         self, |         self, | ||||||
|         video: torch.Tensor, |         video: torch.Tensor, | ||||||
|         tracks: torch.Tensor, |         tracks: torch.Tensor, | ||||||
|  |         visibility: torch.Tensor, | ||||||
|         segm_mask: torch.Tensor = None, |         segm_mask: torch.Tensor = None, | ||||||
|         gt_tracks=None, |         gt_tracks=None, | ||||||
|         query_frame: int = 0, |         query_frame: int = 0, | ||||||
| @@ -228,11 +231,13 @@ class Visualizer: | |||||||
|                     if not compensate_for_camera_motion or ( |                     if not compensate_for_camera_motion or ( | ||||||
|                         compensate_for_camera_motion and segm_mask[i] > 0 |                         compensate_for_camera_motion and segm_mask[i] > 0 | ||||||
|                     ): |                     ): | ||||||
|  |  | ||||||
|                         cv2.circle( |                         cv2.circle( | ||||||
|                             res_video[t], |                             res_video[t], | ||||||
|                             coord, |                             coord, | ||||||
|                             int(self.linewidth * 2), |                             int(self.linewidth * 2), | ||||||
|                             vector_colors[t, i].tolist(), |                             vector_colors[t, i].tolist(), | ||||||
|  |                             thickness=-1 if visibility[0, t, i] else 2 | ||||||
|                             -1, |                             -1, | ||||||
|                         ) |                         ) | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										7
									
								
								demo.py
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								demo.py
									
									
									
									
									
								
							| @@ -32,6 +32,11 @@ if __name__ == "__main__": | |||||||
|         default="./checkpoints/cotracker_stride_4_wind_8.pth", |         default="./checkpoints/cotracker_stride_4_wind_8.pth", | ||||||
|         help="cotracker model", |         help="cotracker model", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--device", | ||||||
|  |         default="cuda", | ||||||
|  |         help="Device to use for inference", | ||||||
|  |     ) | ||||||
|     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size") |     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size") | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--grid_query_frame", |         "--grid_query_frame", | ||||||
| @@ -54,7 +59,7 @@ if __name__ == "__main__": | |||||||
|     segm_mask = np.array(Image.open(os.path.join(args.mask_path))) |     segm_mask = np.array(Image.open(os.path.join(args.mask_path))) | ||||||
|     segm_mask = torch.from_numpy(segm_mask)[None, None] |     segm_mask = torch.from_numpy(segm_mask)[None, None] | ||||||
|  |  | ||||||
|     model = CoTrackerPredictor(checkpoint=args.checkpoint) |     model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device) | ||||||
|  |  | ||||||
|     pred_tracks, pred_visibility = model( |     pred_tracks, pred_visibility = model( | ||||||
|         video, |         video, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user