Compare commits

..

19 Commits

Author SHA1 Message Date
mhz
be0891967b add time measure codes and update resolution 2024-08-12 22:37:51 +02:00
mhz
40e628ac73 try to generate the tracking f1 video but too many points 2024-08-11 13:42:03 +02:00
mhz
f208a962b9 add a demo code 2024-08-10 23:04:27 +02:00
15cdb3027c add the video 2024-08-05 23:37:45 +02:00
6e7bcd2d26 add some comments 2024-08-05 23:36:58 +02:00
Hanzhang ma
36d1566750 add some comments 2024-07-10 00:05:34 +02:00
Hanzhang ma
9ed8669a50 Merge branch 'main' of https://github.com/facebookresearch/co-tracker 2024-07-10 00:04:14 +02:00
Hanzhang ma
eeda4d3c98 comment the code 2024-07-09 10:54:29 +02:00
Ben Evans
9ed05317b7 Fix URL in README Example (#76) 2024-06-28 21:15:15 +01:00
Nikita Karaev
19767a9d65 Update README.md 2024-06-14 15:06:32 +01:00
Iurii Makarov
e29e938311 readme.md update, demo flexible save path (#83) 2024-05-11 15:34:09 +01:00
Nikita Karaev
0f9d32869a Update README.md 2024-01-22 11:59:03 +00:00
Nikita Karaev
9460eefecc Update README.md 2024-01-09 16:00:07 +00:00
Patrick Pfreundschuh
9921cf0895 fix ignored input video argument (#57) 2024-01-07 15:14:28 +00:00
Nikita Karaev
941c24fd40 add meta copyright 2024-01-05 16:17:50 +00:00
Nikita Karaev
fac27989b3 fixed a small online processing bug 2024-01-05 14:55:54 +00:00
Nikita Karaev
f084a93f28 fix multi-batch inference 2024-01-04 16:53:22 +00:00
Nikita Karaev
3716e36249 fix online demo 2023-12-29 16:12:42 +00:00
Nikita Karaev
721fcc237b remove assert B==1 2023-12-28 17:27:30 +00:00
62 changed files with 4086 additions and 4089 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
__pycache__/
.vscode/
cotracker/__pycache__/

View File

@@ -4,7 +4,7 @@
[Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/) [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
[[`Paper`](https://arxiv.org/abs/2307.07635)] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)] ### [Project Page](https://co-tracker.github.io/) | [Paper](https://arxiv.org/abs/2307.07635) | [X Thread](https://twitter.com/n_karaev/status/1742638906355470772) | [BibTeX](#citing-cotracker)
<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb"> <a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@@ -26,6 +26,7 @@ CoTracker can track:
Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or in the [Hugging Face Space 🤗](https://huggingface.co/spaces/facebook/cotracker). Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or in the [Hugging Face Space 🤗](https://huggingface.co/spaces/facebook/cotracker).
**Updates:** **Updates:**
- [June 14, 2024] 📣 We have released the code for [VGGSfM](https://github.com/facebookresearch/vggsfm), a model for recovering camera poses and 3D structure from any image sequences based on point tracking! VGGSfM is the first fully differentiable SfM framework that unlocks scalability and outperforms conventional SfM methods on standard benchmarks.
- [December 27, 2023] 📣 CoTracker2 is now available! It can now track many more (up to **265*265**!) points jointly and it has a cleaner and more memory-efficient implementation. It also supports online processing. See the [updated paper](https://arxiv.org/abs/2307.07635) for more details. The old version remains available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212). - [December 27, 2023] 📣 CoTracker2 is now available! It can now track many more (up to **265*265**!) points jointly and it has a cleaner and more memory-efficient implementation. It also supports online processing. See the [updated paper](https://arxiv.org/abs/2307.07635) for more details. The old version remains available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
@@ -39,7 +40,7 @@ The easiest way to use CoTracker is to load a pretrained model from `torch.hub`:
```python ```python
import torch import torch
# Download the video # Download the video
url = 'https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4' url = 'https://github.com/facebookresearch/co-tracker/raw/main/assets/apple.mp4'
import imageio.v3 as iio import imageio.v3 as iio
frames = iio.imread(url, plugin="FFMPEG") # plugin="pyav" frames = iio.imread(url, plugin="FFMPEG") # plugin="pyav"
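# Hedged sketch (not part of this diff): the quick start typically continues by turning
# the decoded frames into a (B, T, C, H, W) tensor and loading the model from torch.hub.
# The "cotracker2" entry point and grid_size=10 are assumptions based on the CoTracker2
# release, not taken from the excerpt above.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(device)  # B T C H W
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to(device)
pred_tracks, pred_visibility = cotracker(video, grid_size=10)  # B T N 2,  B T N 1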
@@ -119,7 +120,7 @@ We strongly recommend installing both PyTorch and TorchVision with CUDA support,
git clone https://github.com/facebookresearch/co-tracker git clone https://github.com/facebookresearch/co-tracker
cd co-tracker cd co-tracker
pip install -e . pip install -e .
pip install matplotlib flow_vis tqdm tensorboard pip install matplotlib flow_vis tqdm tensorboard imageio[ffmpeg]
``` ```
You can manually download the CoTracker2 checkpoint from the links below and place it in the `checkpoints` folder as follows: You can manually download the CoTracker2 checkpoint from the links below and place it in the `checkpoints` folder as follows:
@@ -132,6 +133,11 @@ cd ..
``` ```
For old checkpoints, see [this section](#previous-version). For old checkpoints, see [this section](#previous-version).
After installation, this is how you could run the model on `./assets/apple.mp4` (results will be saved to `./saved_videos/apple.mp4`):
```bash
python demo.py --checkpoint checkpoints/cotracker2.pth
```
## Evaluation ## Evaluation
To reproduce the results presented in the paper, download the following datasets: To reproduce the results presented in the paper, download the following datasets:
@@ -203,6 +209,15 @@ make -C docs html
## Previous version ## Previous version
You can use CoTracker v1 directly via pytorch hub:
```python
import torch
import einops
import timm
import tqdm
cotracker = torch.hub.load("facebookresearch/co-tracker:v1.0", "cotracker_w8")
```
The old version of the code is available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212). The old version of the code is available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
You can also download the corresponding checkpoints: You can also download the corresponding checkpoints:
```bash ```bash

BIN
assets/F1_shorts.mp4 Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -191,6 +191,7 @@ class BasicEncoder(nn.Module):
x = self.norm1(x) x = self.norm1(x)
x = self.relu1(x) x = self.relu1(x)
# Four stages of residual blocks
a = self.layer1(x) a = self.layer1(x)
b = self.layer2(a) b = self.layer2(a)
c = self.layer3(b) c = self.layer3(b)

View File

@@ -41,6 +41,7 @@ class CoTracker2(nn.Module):
self.hidden_dim = 256 self.hidden_dim = 256
self.latent_dim = 128 self.latent_dim = 128
self.add_space_attn = add_space_attn self.add_space_attn = add_space_attn
self.fnet = BasicEncoder(output_dim=self.latent_dim) self.fnet = BasicEncoder(output_dim=self.latent_dim)
self.num_virtual_tracks = num_virtual_tracks self.num_virtual_tracks = num_virtual_tracks
self.model_resolution = model_resolution self.model_resolution = model_resolution
@@ -107,6 +108,7 @@ class CoTracker2(nn.Module):
B, S_init, N, __ = track_mask.shape B, S_init, N, __ = track_mask.shape
B, S, *_ = fmaps.shape B, S, *_ = fmaps.shape
# Pad track_mask so its number of frames matches the feature maps
track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant") track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
track_mask_vis = ( track_mask_vis = (
torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2) torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
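Not part of the diff, but to make the padding comment above concrete: a minimal, self-contained sketch (toy shapes assumed) of how `F.pad` extends `track_mask` along the time dimension until it matches the feature maps.

```python
import torch
import torch.nn.functional as F

B, S_init, S, N = 1, 8, 16, 5                # toy sizes: the window has S frames
track_mask = torch.ones(B, S_init, N, 1)     # mask known for the first S_init frames
# F.pad pads trailing dimensions first: (0, 0) for dim -1, (0, 0) for dim -2,
# then (0, S - S_init) appends zero frames along the time dimension (dim 1).
padded = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
print(padded.shape)  # torch.Size([1, 16, 5, 1])
```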
@@ -171,6 +173,7 @@ class CoTracker2(nn.Module):
], ],
dim=-1, dim=-1,
) )
# Bilinear sampling
sample_track_feats = sample_features5d(fmaps, sample_coords) sample_track_feats = sample_features5d(fmaps, sample_coords)
return sample_track_feats return sample_track_feats
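`sample_features5d` is CoTracker's own helper; purely as an illustration (names and shapes below are assumptions, not the repo's API), bilinear sampling of per-point features from a single feature map can be expressed with `torch.nn.functional.grid_sample`:

```python
import torch
import torch.nn.functional as F

B, C, H, W, N = 1, 128, 48, 64, 10
fmap = torch.randn(B, C, H, W)                              # one frame of features
xy = torch.rand(B, N, 2) * torch.tensor([W - 1., H - 1.])   # pixel coords (x, y)
grid = 2 * xy / torch.tensor([W - 1., H - 1.]) - 1          # normalize to [-1, 1]
feats = F.grid_sample(fmap, grid.view(B, N, 1, 2), mode="bilinear", align_corners=True)
print(feats.shape)  # torch.Size([1, 128, 10, 1]) -> one C-dim feature per query point
```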
@@ -227,22 +230,24 @@ class CoTracker2(nn.Module):
# The first channel is the frame number # The first channel is the frame number
# The rest are the coordinates of points we want to track # The rest are the coordinates of points we want to track
queried_frames = queries[:, :, 0].long() queried_frames = queries[:, :, 0].long()  # frame index of each query
queried_coords = queries[..., 1:] queried_coords = queries[..., 1:]
queried_coords = queried_coords / self.stride queried_coords = queried_coords / self.stride  # scale to the feature-map stride
# We store our predictions here # We store our predictions here
coords_predicted = torch.zeros((B, T, N, 2), device=device) coords_predicted = torch.zeros((B, T, N, 2), device=device)  # predicted points, filled in below
vis_predicted = torch.zeros((B, T, N), device=device) vis_predicted = torch.zeros((B, T, N), device=device)
if is_online: if is_online:
# In online mode: on the first call, initialize the stored predictions with zeros;
# on later calls, pad the stored predictions with zeros for the new window
if self.online_coords_predicted is None: if self.online_coords_predicted is None:
# Init online predictions with zeros # Init online predictions with zeros
self.online_coords_predicted = coords_predicted self.online_coords_predicted = coords_predicted
self.online_vis_predicted = vis_predicted self.online_vis_predicted = vis_predicted
else: else:
# Pad online predictions with zeros for the current window # Pad online predictions with zeros for the current window
pad = min(step, T - step) pad = min(step, T - step)  # make sure the padding does not exceed the remaining frames
coords_predicted = F.pad( coords_predicted = F.pad(
self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant" self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
) )
@@ -250,19 +255,24 @@ class CoTracker2(nn.Module):
all_coords_predictions, all_vis_predictions = [], [] all_coords_predictions, all_vis_predictions = [], []
# Pad the video so that an integer number of sliding windows fit into it # Pad the video so that an integer number of sliding windows fit into it
# TODO: we may drop this requirement because the transformer should not care # TODO: we may drop this requirement because the transformer should not care
# TODO: pad the features instead of the video # TODO: pad the features instead of the video
# The next line computes how many frames of padding are needed
pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0 pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0
# The padding replicates the last frame pad times
video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape( video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
B, -1, C, H, W B, -1, C, H, W
) )
# Compute convolutional features for the video or for the current chunk in case of online mode # Compute convolutional features for the video or for the current chunk in case of online mode
fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape( fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
B, -1, self.latent_dim, H // self.stride, W // self.stride B, -1, self.latent_dim, H // self.stride, W // self.stride
) )
# We compute track features # We compute track features
# Internally this bilinearly samples the feature maps at the query locations
track_feat = self.get_track_feat( track_feat = self.get_track_feat(
fmaps, fmaps,
queried_frames - self.online_ind if is_online else queried_frames, queried_frames - self.online_ind if is_online else queried_frames,
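Not part of the diff: a quick check of the offline window-padding arithmetic used earlier in this hunk (`pad = (S - T % S) % S`), with made-up values for the video length T and window length S.

```python
S = 16                       # sliding-window length
for T in (16, 20, 33):       # offline mode: pad T up to a multiple of S
    pad = (S - T % S) % S    # no padding when T is already a multiple of S
    print(T, "->", T + pad)  # 16 -> 16, 20 -> 32, 33 -> 48
```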
@@ -284,14 +294,17 @@ class CoTracker2(nn.Module):
# We process only the current video chunk in the online mode # We process only the current video chunk in the online mode
indices = [self.online_ind] if is_online else range(0, step * num_windows, step) indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
# Broadcast the queried coordinates to every frame of the window
coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float() coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
vis_init = torch.ones((B, S, N, 1), device=device).float() * 10 vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
for ind in indices: for ind in indices:
# We copy over coords and vis for tracks that are queried # We copy over coords and vis for tracks that are queried
# by the end of the previous window, which is ind + overlap # by the end of the previous window, which is ind + overlap
# Handle the overlap with the previous window
if ind > 0: if ind > 0:
overlap = S - step overlap = S - step
copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1 copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1
# Copy over the previous window's predictions
coords_prev = torch.nn.functional.pad( coords_prev = torch.nn.functional.pad(
coords_predicted[:, ind : ind + overlap] / self.stride, coords_predicted[:, ind : ind + overlap] / self.stride,
(0, 0, 0, 0, 0, step), (0, 0, 0, 0, 0, step),
@@ -304,16 +317,18 @@ class CoTracker2(nn.Module):
) # B S N 1 ) # B S N 1
coords_init = torch.where( coords_init = torch.where(
copy_over.expand_as(coords_init), coords_prev, coords_init copy_over.expand_as(coords_init), coords_prev, coords_init
) )  # where copy_over is True take coords_prev, otherwise keep coords_init
vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init) vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
# The attention mask is 1 for the spatio-temporal points within # The attention mask is 1 for the spatio-temporal points within
# a track which is updated in the current window # a track which is updated in the current window
# i.e. the spatio-temporal points that get updated in the current window
attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N
# The track mask is 1 for the spatio-temporal points that actually # The track mask is 1 for the spatio-temporal points that actually
# need updating: only after begin queried, and not if contained # need updating: only after begin queried, and not if contained
# in a previous window # in a previous window
# track_mask marks the points that actually need updating
track_mask = ( track_mask = (
queried_frames[:, None, :, None] queried_frames[:, None, :, None]
<= torch.arange(ind, ind + S, device=device)[None, :, None, None] <= torch.arange(ind, ind + S, device=device)[None, :, None, None]
@@ -323,6 +338,7 @@ class CoTracker2(nn.Module):
track_mask[:, :overlap, :, :] = False track_mask[:, :overlap, :, :] = False
# Predict the coordinates and visibility for the current window # Predict the coordinates and visibility for the current window
# forward_window updates coords and vis for this window
coords, vis = self.forward_window( coords, vis = self.forward_window(
fmaps=fmaps if is_online else fmaps[:, ind : ind + S], fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
coords=coords_init, coords=coords_init,

View File

@@ -38,7 +38,6 @@ class EvaluationPredictor(torch.nn.Module):
B, N, D = queries.shape B, N, D = queries.shape
assert D == 3 assert D == 3
assert B == 1
video = video.reshape(B * T, C, H, W) video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True) video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)

View File

@@ -17,17 +17,21 @@ class CoTrackerPredictor(torch.nn.Module):
self.support_grid_size = 6 self.support_grid_size = 6
model = build_cotracker(checkpoint) model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution self.interp_shape = model.model_resolution
print(self.interp_shape)
self.model = model self.model = model
self.model.eval() self.model.eval()
self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,
video, # (1, T, 3, H, W) video, # (B, T, 3, H, W) batch size, time, RGB channels, height, width
# input prompt types: # input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions. # *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates. # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask. # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None, queries: torch.Tensor = None,
@@ -55,18 +59,31 @@ class CoTrackerPredictor(torch.nn.Module):
return tracks, visibilities return tracks, visibilities
# gpu dense inference time
# raft gpu comparison
# vision effects
# raft integrated
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False): def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
*_, H, W = video.shape *_, H, W = video.shape
grid_step = W // grid_size grid_step = W // grid_size
grid_width = W // grid_step grid_width = W // grid_step
grid_height = H // grid_step grid_height = H // grid_step # cover the frame with a grid spaced grid_step pixels apart (about grid_size points per row)
tracks = visibilities = None tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device) grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
# (batch_size, num_grid_points, (t, x, y))
grid_pts[0, :, 0] = grid_query_frame grid_pts[0, :, 0] = grid_query_frame
# iterate over every pixel offset within a grid cell
for offset in range(grid_step * grid_step): for offset in range(grid_step * grid_step):
print(f"step {offset} / {grid_step * grid_step}") print(f"step {offset} / {grid_step * grid_step}")
ox = offset % grid_step ox = offset % grid_step
oy = offset // grid_step oy = offset // grid_step
# Worked example for the two lines below:
#   grid_width = 4, grid_height = 4, grid_step = 10, ox = 1
#   torch.arange(grid_width)                     = [0, 1, 2, 3]
#   torch.arange(grid_width).repeat(grid_height) = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
#   ... * grid_step + ox                         = [1, 11, 21, 31, 1, 11, 21, 31, ...]
# i.e. the x pixel coordinate of every grid point for this offset
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
grid_pts[0, :, 2] = ( grid_pts[0, :, 2] = (
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
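Not part of the diff: an executable version of the worked example in the comments above (toy numbers), showing how one offset (ox, oy) selects a shifted set of grid points.

```python
import torch

grid_width, grid_height, grid_step, ox, oy = 4, 4, 10, 1, 2
xs = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
ys = torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
print(xs.tolist())  # [1, 11, 21, 31, 1, 11, 21, 31, 1, 11, 21, 31, 1, 11, 21, 31]
print(ys.tolist())  # [2, 2, 2, 2, 12, 12, 12, 12, 22, 22, 22, 22, 32, 32, 32, 32]
```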
@@ -92,15 +109,18 @@ class CoTrackerPredictor(torch.nn.Module):
backward_tracking=False, backward_tracking=False,
): ):
B, T, C, H, W = video.shape B, T, C, H, W = video.shape
assert B == 1
video = video.reshape(B * T, C, H, W) video = video.reshape(B * T, C, H, W)
# F.interpolate bilinearly resizes every frame of the video to self.interp_shape
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True) video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None: if queries is not None:
B, N, D = queries.shape B, N, D = queries.shape # batch_size, number of points, (t,x,y)
assert D == 3 assert D == 3
# After resizing, rescale the query (x, y) coordinates by (interp_shape - 1) / (original size - 1)
queries = queries.clone() queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor( queries[:, :, 1:] *= queries.new_tensor(
[ [
@@ -108,6 +128,7 @@ class CoTrackerPredictor(torch.nn.Module):
(self.interp_shape[0] - 1) / (H - 1), (self.interp_shape[0] - 1) / (H - 1),
] ]
) )
# Otherwise, generate query points on a regular grid
elif grid_size > 0: elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device) grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
if segm_mask is not None: if segm_mask is not None:
@@ -121,13 +142,16 @@ class CoTrackerPredictor(torch.nn.Module):
queries = torch.cat( queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2, dim=2,
) ).repeat(B, 1, 1)
# Add the support grid points
if add_support_grid: if add_support_grid:
grid_pts = get_points_on_a_grid( grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video.device self.support_grid_size, self.interp_shape, device=video.device
) )
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2) grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
grid_pts = grid_pts.repeat(B, 1, 1)
queries = torch.cat([queries, grid_pts], dim=1) queries = torch.cat([queries, grid_pts], dim=1)
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6) tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
@@ -174,7 +198,7 @@ class CoTrackerPredictor(torch.nn.Module):
inv_visibilities = inv_visibilities.flip(1) inv_visibilities = inv_visibilities.flip(1)
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None] arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2) mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
tracks[mask] = inv_tracks[mask] tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
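Not part of the diff: a toy illustration of why the indexing change above (from `queries[None, :, :, 0]` to `queries[:, None, :, 0]`) matters once B > 1.

```python
import torch

B, T, N = 2, 5, 3
queries = torch.randint(0, T, (B, N, 3)).float()  # first channel = query frame index
arange = torch.arange(T)[None, :, None]           # (1, T, 1)
# Old indexing gave shape (1, B, N), which only broadcasts against arange when B == 1.
# New indexing gives (B, 1, N), producing a per-sample (B, T, N) mask.
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
print(mask.shape)  # torch.Size([2, 5, 3, 2])
```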
@@ -201,6 +225,7 @@ class CoTrackerOnlinePredictor(torch.nn.Module):
grid_query_frame: int = 0, grid_query_frame: int = 0,
add_support_grid=False, add_support_grid=False,
): ):
B, T, C, H, W = video_chunk.shape
# Initialize online video processing and save queried points # Initialize online video processing and save queried points
# This needs to be done before processing *each new video* # This needs to be done before processing *each new video*
if is_first_step: if is_first_step:
@@ -231,7 +256,7 @@ class CoTrackerOnlinePredictor(torch.nn.Module):
queries = torch.cat([queries, grid_pts], dim=1) queries = torch.cat([queries, grid_pts], dim=1)
self.queries = queries self.queries = queries
return (None, None) return (None, None)
B, T, C, H, W = video_chunk.shape
video_chunk = video_chunk.reshape(B * T, C, H, W) video_chunk = video_chunk.reshape(B * T, C, H, W)
video_chunk = F.interpolate( video_chunk = F.interpolate(
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -83,11 +83,12 @@ if __name__ == "__main__":
print("computed") print("computed")
# save a video with predicted tracks # save a video with predicted tracks
seq_name = args.video_path.split("/")[-1] seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize( vis.visualize(
video, video,
pred_tracks, pred_tracks,
pred_visibility, pred_visibility,
query_frame=0 if args.backward_tracking else args.grid_query_frame, query_frame=0 if args.backward_tracking else args.grid_query_frame,
filename=seq_name,
) )

82
demo1.py Normal file
View File

@@ -0,0 +1,82 @@
import os
import torch
from base64 import b64encode
from cotracker.utils.visualizer import Visualizer, read_video_from_path
import numpy as np
from PIL import Image
import time
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
start_time = time.time()
print(f'Using device: {device}')
print(f'start loading video')
video = read_video_from_path('./assets/F1_shorts.mp4')
print(f'video shape: {video.shape}')
# video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float().to(device)
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
end_time = time.time()
print(f'video shape after permute: {video.shape}')
print("Load video Time taken: {:.2f} seconds".format(end_time - start_time))
from cotracker.predictor import CoTrackerPredictor
model = CoTrackerPredictor(
checkpoint=os.path.join(
'./checkpoints/cotracker2.pth'
)
)
# pred_tracks, pred_visibility = model(video, grid_size=30)
# vis = Visualizer(save_dir='./videos', pad_value=100)
# vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility, filename='teaser');
grid_query_frame=20
import torch.nn.functional as F
# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
interp_size = (720, 1280)
video_interp = F.interpolate(video[0], [interp_size[0], interp_size[1]], mode="bilinear")[None].to(device)
print(f'video_interp shape: {video_interp.shape}')
start_time = time.time()
# pred_tracks, pred_visibility = model(video_interp,
input_mask = './assets/F1_mask.png'
segm_mask = Image.open(input_mask)
# PIL's resize expects (width, height); use the resized mask instead of re-reading the file
segm_mask = segm_mask.resize((interp_size[1], interp_size[0]), Image.BILINEAR)
segm_mask = torch.tensor(np.array(segm_mask)).to(device)
# pred_tracks, pred_visibility = model(video,
pred_tracks, pred_visibility = model(video_interp,
grid_query_frame=grid_query_frame, backward_tracking=True,
segm_mask=segm_mask )
end_time = time.time()
print("Time taken: {:.2f} seconds".format(end_time - start_time))
start_time = time.time()
print(f'start visualizing')
vis = Visualizer(
save_dir='./videos',
pad_value=20,
linewidth=1,
mode='optical_flow'
)
print(f'vis initialized')
end_time = time.time()
print("Time taken: {:.2f} seconds".format(end_time - start_time))
start_time = time.time()
print(f'start visualize')
vis.visualize(
video=video_interp,
# video=video,
tracks=pred_tracks,
visibility=pred_visibility,
filename='dense2');
print(f'done')
end_time = time.time()
print("Time taken: {:.2f} seconds".format(end_time - start_time))

View File

@@ -1,3 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os import os
import torch import torch
import gradio as gr import gradio as gr
@@ -22,7 +29,12 @@ def cotracker_demo(
model = model.cuda() model = model.cuda()
load_video = load_video.cuda() load_video = load_video.cuda()
model(video_chunk=load_video, is_first_step=True, grid_size=grid_size) model(
video_chunk=load_video,
is_first_step=True,
grid_size=grid_size,
grid_query_frame=grid_query_frame,
)
for ind in range(0, load_video.shape[1] - model.step, model.step): for ind in range(0, load_video.shape[1] - model.step, model.step):
pred_tracks, pred_visibility = model( pred_tracks, pred_visibility = model(
video_chunk=load_video[:, ind : ind + model.step * 2] video_chunk=load_video[:, ind : ind + model.step * 2]

File diff suppressed because one or more lines are too long

View File

@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import os
import torch import torch
import argparse import argparse
import imageio.v3 as iio import imageio.v3 as iio
@@ -44,6 +45,9 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if not os.path.isfile(args.video_path):
raise ValueError("Video file does not exist")
if args.checkpoint is not None: if args.checkpoint is not None:
model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint) model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
else: else:
@@ -52,25 +56,33 @@ if __name__ == "__main__":
window_frames = [] window_frames = []
def _process_step(window_frames, is_first_step, grid_size): def _process_step(window_frames, is_first_step, grid_size, grid_query_frame):
video_chunk = ( video_chunk = (
torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE) torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
.float() .float()
.permute(0, 3, 1, 2)[None] .permute(0, 3, 1, 2)[None]
) # (1, T, 3, H, W) ) # (1, T, 3, H, W)
return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size) return model(
video_chunk,
is_first_step=is_first_step,
grid_size=grid_size,
grid_query_frame=grid_query_frame,
)
# Iterating over video frames, processing one window at a time: # Iterating over video frames, processing one window at a time:
is_first_step = True is_first_step = True
for i, frame in enumerate( for i, frame in enumerate(
iio.imiter( iio.imiter(
"https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4", args.video_path,
plugin="FFMPEG", plugin="FFMPEG",
) )
): ):
if i % model.step == 0 and i != 0: if i % model.step == 0 and i != 0:
pred_tracks, pred_visibility = _process_step( pred_tracks, pred_visibility = _process_step(
window_frames, is_first_step, grid_size=args.grid_size window_frames,
is_first_step,
grid_size=args.grid_size,
grid_query_frame=args.grid_query_frame,
) )
is_first_step = False is_first_step = False
window_frames.append(frame) window_frames.append(frame)
@@ -79,12 +91,13 @@ if __name__ == "__main__":
window_frames[-(i % model.step) - model.step - 1 :], window_frames[-(i % model.step) - model.step - 1 :],
is_first_step, is_first_step,
grid_size=args.grid_size, grid_size=args.grid_size,
grid_query_frame=args.grid_query_frame,
) )
print("Tracks are computed") print("Tracks are computed")
# save a video with predicted tracks # save a video with predicted tracks
seq_name = args.video_path.split("/")[-1] seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None] video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame) vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame, filename=seq_name)