# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from typing import Tuple

from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.models.core.model_utils import get_points_on_a_grid

class EvaluationPredictor(torch.nn.Module):
    """Wraps a CoTracker2 model for evaluation: resizes the input video to a
    fixed resolution, optionally tracks each query point independently with
    local/global support grids, and rescales predictions back to the
    original resolution."""

    def __init__(
        self,
        cotracker_model: CoTracker2,
        interp_shape: Tuple[int, int] = (384, 512),
        grid_size: int = 5,
        local_grid_size: int = 8,
        single_point: bool = True,
        n_iters: int = 6,
    ) -> None:
        super(EvaluationPredictor, self).__init__()
        self.grid_size = grid_size
        self.local_grid_size = local_grid_size
        self.single_point = single_point
        self.interp_shape = interp_shape
        self.n_iters = n_iters

        self.model = cotracker_model
        self.model.eval()
    def forward(self, video, queries):
        queries = queries.clone()
        B, T, C, H, W = video.shape
        B, N, D = queries.shape

        assert D == 3  # each query is (frame index, x, y)

        # Resize every frame to the fixed evaluation resolution.
        video = video.reshape(B * T, C, H, W)
        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])

        device = video.device

        # Rescale query coordinates from the original resolution to interp_shape.
        queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
        queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)

        if self.single_point:
            traj_e = torch.zeros((B, T, N, 2), device=device)
            vis_e = torch.zeros((B, T, N), device=device)
            for pind in range(N):
                query = queries[:, pind : pind + 1]

                t = query[0, 0, 0].long()

                # Track each point independently, starting from its query frame.
                traj_e_pind, vis_e_pind = self._process_one_point(video, query)
                traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
                vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
        else:
            if self.grid_size > 0:
                # Append a regular support grid of points, queried at frame 0.
                xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
                queries = torch.cat([queries, xy], dim=1)

            traj_e, vis_e, __ = self.model(
                video=video,
                queries=queries,
                iters=self.n_iters,
            )

        # Rescale predicted trajectories back to the original resolution.
        traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
        traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
        return traj_e, vis_e

    def _process_one_point(self, video, query):
        t = query[0, 0, 0].long()

        device = query.device
        if self.local_grid_size > 0:
            # Append a local support grid within a 50x50 window around the query point.
            xy_target = get_points_on_a_grid(
                self.local_grid_size,
                (50, 50),
                [query[0, 0, 2].item(), query[0, 0, 1].item()],
            )
            xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
                device
            )
            query = torch.cat([query, xy_target], dim=1)

        if self.grid_size > 0:
            # Append a global support grid covering the whole frame.
            xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
            query = torch.cat([query, xy], dim=1)

        # Crop the video to start from the queried frame; the query's frame
        # index becomes 0 in the cropped clip.
        query[0, 0, 0] = 0
        traj_e_pind, vis_e_pind, __ = self.model(
            video=video[:, t:], queries=query, iters=self.n_iters
        )

        return traj_e_pind, vis_e_pind
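

# Minimal usage sketch, assuming a CoTracker2 built with its default config;
# real use would load pretrained weights and real video frames. Tensor shapes
# follow forward() above; the random inputs are for illustration only.
if __name__ == "__main__":
    model = CoTracker2()  # assumption: default constructor; load a checkpoint in practice
    predictor = EvaluationPredictor(model, grid_size=5, single_point=True)

    B, T, C, H, W = 1, 24, 3, 256, 256
    video = torch.randn(B, T, C, H, W)  # dummy clip in place of real frames
    # Each query is (frame index, x, y) in original-resolution pixel coordinates.
    queries = torch.tensor([[[0.0, 100.0, 120.0], [4.0, 50.0, 60.0]]])

    with torch.no_grad():
        traj, vis = predictor(video, queries)
    print(traj.shape, vis.shape)  # (B, T, N, 2) and (B, T, N)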