fix multi-batch inference
This commit is contained in:
		| @@ -23,11 +23,11 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|     @torch.no_grad() | ||||
|     def forward( | ||||
|         self, | ||||
|         video,  # (1, T, 3, H, W) | ||||
|         video,  # (B, T, 3, H, W) | ||||
|         # input prompt types: | ||||
|         # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. | ||||
|         # *backward_tracking=True* will compute tracks in both directions. | ||||
|         # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates. | ||||
|         # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates. | ||||
|         # - grid_size. Grid of N*N points from the first frame. If segm_mask is provided, the grid is computed only for the mask. | ||||
|         # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. | ||||
|         queries: torch.Tensor = None, | ||||
| @@ -120,13 +120,14 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|             queries = torch.cat( | ||||
|                 [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], | ||||
|                 dim=2, | ||||
|             ) | ||||
|             ).repeat(B, 1, 1) | ||||
|  | ||||
|         if add_support_grid: | ||||
|             grid_pts = get_points_on_a_grid( | ||||
|                 self.support_grid_size, self.interp_shape, device=video.device | ||||
|             ) | ||||
|             grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2) | ||||
|             grid_pts = grid_pts.repeat(B, 1, 1) | ||||
|             queries = torch.cat([queries, grid_pts], dim=1) | ||||
|  | ||||
|         tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6) | ||||
| @@ -173,7 +174,7 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|         inv_visibilities = inv_visibilities.flip(1) | ||||
|         arange = torch.arange(video.shape[1], device=queries.device)[None, :, None] | ||||
|  | ||||
|         mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2) | ||||
|         mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2) | ||||
|  | ||||
|         tracks[mask] = inv_tracks[mask] | ||||
|         visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user