Compare commits

Comparing 9c9a97d158...main (19 commits)
Commits (SHA1):

- be0891967b
- 40e628ac73
- f208a962b9
- 15cdb3027c
- 6e7bcd2d26
- 36d1566750
- 9ed8669a50
- eeda4d3c98
- 9ed05317b7
- 19767a9d65
- e29e938311
- 0f9d32869a
- 9460eefecc
- 9921cf0895
- 941c24fd40
- fac27989b3
- f084a93f28
- 3716e36249
- 721fcc237b
.gitignore (vendored, new file, +3)
@@ -0,0 +1,3 @@
+__pycache__/
+.vscode/
+cotracker/__pycache__/
README.md (21 changes)
@@ -4,7 +4,7 @@
 
 [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
 
-[[`Paper`](https://arxiv.org/abs/2307.07635)] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
+### [Project Page](https://co-tracker.github.io/) | [Paper](https://arxiv.org/abs/2307.07635) | [X Thread](https://twitter.com/n_karaev/status/1742638906355470772) | [BibTeX](#citing-cotracker)
 
 <a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@@ -26,6 +26,7 @@ CoTracker can track:
 Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or in the [Hugging Face Space 🤗](https://huggingface.co/spaces/facebook/cotracker).
 
 **Updates:**
+- [June 14, 2024] 📣 We have released the code for [VGGSfM](https://github.com/facebookresearch/vggsfm), a model for recovering camera poses and 3D structure from any image sequences based on point tracking! VGGSfM is the first fully differentiable SfM framework that unlocks scalability and outperforms conventional SfM methods on standard benchmarks.
 
 - [December 27, 2023] 📣 CoTracker2 is now available! It can now track many more (up to **265*265**!) points jointly and it has a cleaner and more memory-efficient implementation. It also supports online processing. See the [updated paper](https://arxiv.org/abs/2307.07635) for more details. The old version remains available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
@@ -39,7 +40,7 @@ The easiest way to use CoTracker is to load a pretrained model from `torch.hub`:
 ```python
 import torch
 # Download the video
-url = 'https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4'
+url = 'https://github.com/facebookresearch/co-tracker/raw/main/assets/apple.mp4'
 
 import imageio.v3 as iio
 frames = iio.imread(url, plugin="FFMPEG")  # plugin="pyav"
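The URL fix matters because a `blob/` link returns a GitHub HTML page rather than the video bytes, which breaks imageio's FFMPEG reader. For context, a sketch of the full quick-start this hunk belongs to (the grid_size value and device handling here are illustrative, not part of the diff):

```python
import torch
import imageio.v3 as iio

# raw/ serves the actual MP4 bytes; blob/ would return an HTML page
url = 'https://github.com/facebookresearch/co-tracker/raw/main/assets/apple.mp4'
frames = iio.imread(url, plugin="FFMPEG")  # (T, H, W, 3) uint8 array

device = 'cuda' if torch.cuda.is_available() else 'cpu'
video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(device)  # (B, T, 3, H, W)

cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to(device)
pred_tracks, pred_visibility = cotracker(video, grid_size=10)  # (B, T, N, 2), (B, T, N)
```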
@@ -119,7 +120,7 @@ We strongly recommend installing both PyTorch and TorchVision with CUDA support,
 git clone https://github.com/facebookresearch/co-tracker
 cd co-tracker
 pip install -e .
-pip install matplotlib flow_vis tqdm tensorboard
+pip install matplotlib flow_vis tqdm tensorboard imageio[ffmpeg]
 ```
 
 You can manually download the CoTracker2 checkpoint from the links below and place it in the `checkpoints` folder as follows:
@@ -132,6 +133,11 @@ cd ..
 ```
 For old checkpoints, see [this section](#previous-version).
 
+After installation, this is how you could run the model on `./assets/apple.mp4` (results will be saved to `./saved_videos/apple.mp4`):
+```bash
+python demo.py --checkpoint checkpoints/cotracker2.pth
+```
+
 ## Evaluation
 
 To reproduce the results presented in the paper, download the following datasets:
@@ -203,6 +209,15 @@ make -C docs html
 
 
 ## Previous version
+You can use CoTracker v1 directly via pytorch hub:
+```python
+import torch
+import einops
+import timm
+import tqdm
+
+cotracker = torch.hub.load("facebookresearch/co-tracker:v1.0", "cotracker_w8")
+```
 The old version of the code is available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
 You can also download the corresponding checkpoints:
 ```bash
BIN assets/F1_shorts.mp4 (new file; binary file not shown)
BIN cotracker/__pycache__/__init__.cpython-38.pyc (new file; binary file not shown)
BIN cotracker/__pycache__/__init__.cpython-39.pyc (new file; binary file not shown)
BIN cotracker/__pycache__/predictor.cpython-38.pyc (new file; binary file not shown)
BIN cotracker/__pycache__/predictor.cpython-39.pyc (new file; binary file not shown)
BIN cotracker/models/__pycache__/__init__.cpython-38.pyc (new file; binary file not shown)
BIN cotracker/models/__pycache__/__init__.cpython-39.pyc (new file; binary file not shown)
BIN cotracker/models/__pycache__/build_cotracker.cpython-38.pyc (new file; binary file not shown)
BIN cotracker/models/__pycache__/build_cotracker.cpython-39.pyc (new file; binary file not shown)
BIN cotracker/models/core/__pycache__/__init__.cpython-38.pyc (new file; binary file not shown)
BIN cotracker/models/core/__pycache__/__init__.cpython-39.pyc (new file; binary file not shown)
BIN cotracker/models/core/__pycache__/embeddings.cpython-38.pyc (new file; binary file not shown)
BIN cotracker/models/core/__pycache__/embeddings.cpython-39.pyc (new file; binary file not shown)
BIN cotracker/models/core/__pycache__/model_utils.cpython-38.pyc (new file; binary file not shown)
BIN cotracker/models/core/__pycache__/model_utils.cpython-39.pyc (new file; binary file not shown)
BIN 6 further binary files (names not shown)
@@ -191,6 +191,7 @@ class BasicEncoder(nn.Module):
         x = self.norm1(x)
         x = self.relu1(x)
 
+        # four residual blocks
         a = self.layer1(x)
         b = self.layer2(a)
         c = self.layer3(b)
@@ -41,6 +41,7 @@ class CoTracker2(nn.Module):
         self.hidden_dim = 256
         self.latent_dim = 128
         self.add_space_attn = add_space_attn
 
         self.fnet = BasicEncoder(output_dim=self.latent_dim)
         self.num_virtual_tracks = num_virtual_tracks
         self.model_resolution = model_resolution
@@ -107,6 +108,7 @@ class CoTracker2(nn.Module):
         B, S_init, N, __ = track_mask.shape
         B, S, *_ = fmaps.shape
 
+        # pad so that track_mask covers the same number of frames as the feature maps
         track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
         track_mask_vis = (
             torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
@@ -171,6 +173,7 @@ class CoTracker2(nn.Module):
             ],
             dim=-1,
         )
+        # bilinear sampling
         sample_track_feats = sample_features5d(fmaps, sample_coords)
         return sample_track_feats
 
@@ -227,22 +230,24 @@ class CoTracker2(nn.Module):
 
         # The first channel is the frame number
         # The rest are the coordinates of points we want to track
-        queried_frames = queries[:, :, 0].long()
+        queried_frames = queries[:, :, 0].long()  # extract the frame indices
 
         queried_coords = queries[..., 1:]
-        queried_coords = queried_coords / self.stride
+        queried_coords = queried_coords / self.stride  # scale down to feature-map resolution
 
         # We store our predictions here
-        coords_predicted = torch.zeros((B, T, N, 2), device=device)
+        coords_predicted = torch.zeros((B, T, N, 2), device=device)  # buffer for the predicted points
         vis_predicted = torch.zeros((B, T, N), device=device)
         if is_online:
+            # online mode: initialize coordinates to zero and visibility to False;
+            # otherwise, pad the existing predictions with zeros
             if self.online_coords_predicted is None:
                 # Init online predictions with zeros
                 self.online_coords_predicted = coords_predicted
                 self.online_vis_predicted = vis_predicted
             else:
                 # Pad online predictions with zeros for the current window
-                pad = min(step, T - step)
+                pad = min(step, T - step)  # ensure the padding does not exceed the remaining frames
                 coords_predicted = F.pad(
                     self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
                 )
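The division by `self.stride` converts query pixel coordinates into feature-map coordinates, since the encoder downsamples frames by that factor. A minimal sketch, assuming a stride of 4 for illustration:

```python
import torch

stride = 4  # illustrative; the model defines self.stride
queries = torch.tensor([[[0., 120., 64.],      # (t, x, y): track pixel (120, 64) from frame 0
                         [5., 300., 200.]]])   # track pixel (300, 200) from frame 5

queried_frames = queries[:, :, 0].long()    # tensor([[0, 5]])
queried_coords = queries[..., 1:] / stride  # tensor([[[30., 16.], [75., 50.]]])
```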
@@ -250,19 +255,24 @@ class CoTracker2(nn.Module):
         all_coords_predictions, all_vis_predictions = [], []
 
         # Pad the video so that an integer number of sliding windows fit into it
+        # pad the video so that a whole number of sliding windows covers it
         # TODO: we may drop this requirement because the transformer should not care
         # TODO: pad the features instead of the video
+        # the next line computes the number of frames to pad
         pad = S - T if is_online else (S - T % S) % S  # We don't want to pad if T % S == 0
+        # the padding replicates the last frame pad times
         video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
             B, -1, C, H, W
         )
 
         # Compute convolutional features for the video or for the current chunk in case of online mode
+        # compute convolutional features of the video, or of the current chunk when online
         fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
             B, -1, self.latent_dim, H // self.stride, W // self.stride
         )
 
         # We compute track features
+        # internally this uses bilinear sampling over the feature maps
         track_feat = self.get_track_feat(
             fmaps,
             queried_frames - self.online_ind if is_online else queried_frames,
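The padding expression is worth checking on concrete numbers; here S is the sliding-window length and T the current number of frames:

```python
S = 8  # sliding-window length

for T, is_online in [(8, False), (10, False), (5, True)]:
    pad = S - T if is_online else (S - T % S) % S
    print(T, is_online, pad)
# 8  False 0 -> T is a multiple of S, no padding needed
# 10 False 6 -> padded to 16 frames, i.e. two full windows
# 5  True  3 -> an online chunk is padded up to one full window
```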
@@ -284,14 +294,17 @@ class CoTracker2(nn.Module):
         # We process only the current video chunk in the online mode
         indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
 
+        # reshape the queried coordinates
         coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
         vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
         for ind in indices:
             # We copy over coords and vis for tracks that are queried
             # by the end of the previous window, which is ind + overlap
+            # i.e. handle the overlap between consecutive windows
             if ind > 0:
                 overlap = S - step
                 copy_over = (queried_frames < ind + overlap)[:, None, :, None]  # B 1 N 1
+                # copy over the predictions from the previous window
                 coords_prev = torch.nn.functional.pad(
                     coords_predicted[:, ind : ind + overlap] / self.stride,
                     (0, 0, 0, 0, 0, step),
@@ -304,16 +317,18 @@ class CoTracker2(nn.Module):
                 )  # B S N 1
                 coords_init = torch.where(
                     copy_over.expand_as(coords_init), coords_prev, coords_init
-                )
+                )  # where True take coords_prev, where False keep coords_init
                 vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
 
             # The attention mask is 1 for the spatio-temporal points within
             # a track which is updated in the current window
+            # i.e. it marks the spatio-temporal points to update in the current window
             attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1)  # B S N
 
             # The track mask is 1 for the spatio-temporal points that actually
             # need updating: only after being queried, and not if contained
             # in a previous window
+            # track_mask marks the points that actually need updating
             track_mask = (
                 queried_frames[:, None, :, None]
                 <= torch.arange(ind, ind + S, device=device)[None, :, None, None]
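The copy_over selection relies on broadcasting a (B, 1, N, 1) boolean against (B, S, N, 2) tensors, so each track is switched wholesale across all frames and both coordinates. A toy-sized sketch of the mechanism:

```python
import torch

B, S, N = 1, 4, 3
queried_frames = torch.tensor([[0, 2, 7]])  # per-track start frames
ind, step = 4, 2
overlap = S - step  # 2

copy_over = (queried_frames < ind + overlap)[:, None, :, None]  # (B, 1, N, 1): [True, True, False]
coords_init = torch.zeros(B, S, N, 2)
coords_prev = torch.ones(B, S, N, 2)

# True picks the previous window's prediction, False keeps the fresh init
coords = torch.where(copy_over.expand_as(coords_init), coords_prev, coords_init)
print(coords[0, 0])  # tracks 0 and 1 come from coords_prev, track 2 stays zero
```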
@@ -323,6 +338,7 @@ class CoTracker2(nn.Module):
             track_mask[:, :overlap, :, :] = False
 
             # Predict the coordinates and visibility for the current window
+            # forward_window updates coords and vis
             coords, vis = self.forward_window(
                 fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
                 coords=coords_init,
@@ -38,7 +38,6 @@ class EvaluationPredictor(torch.nn.Module):
         B, N, D = queries.shape
 
         assert D == 3
         assert B == 1
 
         video = video.reshape(B * T, C, H, W)
         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
@@ -17,17 +17,21 @@ class CoTrackerPredictor(torch.nn.Module):
         self.support_grid_size = 6
         model = build_cotracker(checkpoint)
         self.interp_shape = model.model_resolution
+        print(self.interp_shape)
         self.model = model
         self.model.eval()
 
+        self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+
     @torch.no_grad()
     def forward(
         self,
-        video,  # (1, T, 3, H, W)
+        video,  # (B, T, 3, H, W): batch size, time, rgb, height, width
         # input prompt types:
         # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
         # *backward_tracking=True* will compute tracks in both directions.
-        # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
+        # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
         # - grid_size. Grid of N*N points from the first frame. If segm_mask is provided, then computed only for the mask.
         # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
         queries: torch.Tensor = None,
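Given the (B, N, 3) query format documented above, a minimal calling sketch (the checkpoint path, clip, and point locations are illustrative stand-ins):

```python
import torch
from cotracker.predictor import CoTrackerPredictor

model = CoTrackerPredictor(checkpoint='./checkpoints/cotracker2.pth')

video = torch.randn(1, 24, 3, 384, 512)  # (B, T, 3, H, W); stand-in for a real clip
video = video.to(model.device)           # this fork pins the model to a device in __init__
queries = torch.tensor([[
    [0., 400., 350.],   # track pixel (400, 350) starting at frame 0
    [10., 600., 500.],  # track pixel (600, 500) starting at frame 10
]]).to(model.device)    # (B, N, 3) in (t, x, y) order

tracks, visibilities = model(video, queries=queries)
print(tracks.shape, visibilities.shape)  # (1, 24, 2, 2) and (1, 24, 2)
```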
@@ -55,18 +59,31 @@ class CoTrackerPredictor(torch.nn.Module):
 
         return tracks, visibilities
 
+    # gpu dense inference time
+    # raft gpu comparison
+    # vision effects
+    # raft integrated
     def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
         *_, H, W = video.shape
         grid_step = W // grid_size
         grid_width = W // grid_step
-        grid_height = H // grid_step
+        grid_height = H // grid_step  # split the whole frame into a grid with grid_size-pixel coverage
         tracks = visibilities = None
         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
+        # (batch_size, grid_number, (t, x, y))
         grid_pts[0, :, 0] = grid_query_frame
+        # iterate over every offset of the grid
        for offset in range(grid_step * grid_step):
+            print(f"step {offset} / {grid_step * grid_step}")
             ox = offset % grid_step
             oy = offset // grid_step
+            # initialize; for example, with
+            # grid_width = 4, grid_height = 4, grid_step = 10, ox = 1:
+            # torch.arange(grid_width) = [0, 1, 2, 3]
+            # torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
+            # torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
+            # get the location in the image
             grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
             grid_pts[0, :, 2] = (
                 torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
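Dense tracking works by sweeping a coarse grid over every offset: grid_step * grid_step passes of grid_width * grid_height points together cover each pixel exactly once. With an illustrative 1280x720 frame and the default grid_size=80:

```python
W, H, grid_size = 1280, 720, 80
grid_step = W // grid_size    # 16-pixel spacing between grid points
grid_width = W // grid_step   # 80 points per row
grid_height = H // grid_step  # 45 points per column

points_per_pass = grid_width * grid_height  # 3600 tracks per forward pass
passes = grid_step * grid_step              # 256 shifted copies of the grid
print(points_per_pass * passes == W * H)    # True: every pixel gets a track
```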
@@ -92,15 +109,18 @@ class CoTrackerPredictor(torch.nn.Module):
         backward_tracking=False,
     ):
         B, T, C, H, W = video.shape
         assert B == 1
 
         video = video.reshape(B * T, C, H, W)
+        # interpolate the video to interp_shape
         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
         video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
 
         if queries is not None:
-            B, N, D = queries.shape
+            B, N, D = queries.shape  # batch size, number of points, (t, x, y)
             assert D == 3
+            # after interpolation, rescale the queries by (interp_shape - 1) / (original size - 1)
             queries = queries.clone()
             queries[:, :, 1:] *= queries.new_tensor(
                 [
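The (interp_shape - 1) / (size - 1) factor keeps query points consistent with the align_corners=True resize: corner pixels map exactly onto corner pixels. A worked example with illustrative sizes:

```python
import torch

H, W = 720, 1280           # original frame size
interp_shape = (384, 512)  # (height, width) the model operates at

queries = torch.tensor([[[0., 1279., 719.]]])  # bottom-right corner pixel, frame 0
queries[:, :, 1:] *= queries.new_tensor([
    (interp_shape[1] - 1) / (W - 1),  # x factor: 511 / 1279
    (interp_shape[0] - 1) / (H - 1),  # y factor: 383 / 719
])
print(queries)  # tensor([[[  0., 511., 383.]]]) -> still exactly the corner
```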
@@ -108,6 +128,7 @@ class CoTrackerPredictor(torch.nn.Module):
                     (self.interp_shape[0] - 1) / (H - 1),
                 ]
             )
+        # otherwise generate a grid of query points
         elif grid_size > 0:
             grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
             if segm_mask is not None:
@@ -121,13 +142,16 @@ class CoTrackerPredictor(torch.nn.Module):
             queries = torch.cat(
                 [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
                 dim=2,
-            )
+            ).repeat(B, 1, 1)
 
+        # add support points
         if add_support_grid:
             grid_pts = get_points_on_a_grid(
                 self.support_grid_size, self.interp_shape, device=video.device
             )
             grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
+            grid_pts = grid_pts.repeat(B, 1, 1)
             queries = torch.cat([queries, grid_pts], dim=1)
 
         tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
@@ -174,7 +198,7 @@ class CoTrackerPredictor(torch.nn.Module):
         inv_visibilities = inv_visibilities.flip(1)
         arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
 
-        mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
+        mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
 
         tracks[mask] = inv_tracks[mask]
         visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
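The one-character indexing change above fixes the broadcast axes of the backward-tracking mask: arange has shape (1, T, 1), so the queries must be laid out as (B, 1, N), not (1, B, N). A quick shape check:

```python
import torch

B, T, N = 2, 5, 3
queries = torch.zeros(B, N, 3)           # (B, N, 3) query points
arange = torch.arange(T)[None, :, None]  # (1, T, 1)

print(queries[None, :, :, 0].shape)  # (1, 2, 3): old layout; comparing against (1, 5, 1) fails unless B == 1
print(queries[:, None, :, 0].shape)  # (2, 1, 3): new layout
print((arange < queries[:, None, :, 0]).shape)  # (2, 5, 3) == (B, T, N), as intended
```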
@@ -201,6 +225,7 @@ class CoTrackerOnlinePredictor(torch.nn.Module):
         grid_query_frame: int = 0,
         add_support_grid=False,
     ):
+        B, T, C, H, W = video_chunk.shape
         # Initialize online video processing and save queried points
         # This needs to be done before processing *each new video*
         if is_first_step:
@@ -231,7 +256,7 @@ class CoTrackerOnlinePredictor(torch.nn.Module):
             queries = torch.cat([queries, grid_pts], dim=1)
             self.queries = queries
             return (None, None)
-        B, T, C, H, W = video_chunk.shape
+
         video_chunk = video_chunk.reshape(B * T, C, H, W)
         video_chunk = F.interpolate(
             video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
BIN cotracker/utils/__pycache__/__init__.cpython-38.pyc (new file; binary file not shown)
BIN cotracker/utils/__pycache__/__init__.cpython-39.pyc (new file; binary file not shown)
BIN cotracker/utils/__pycache__/visualizer.cpython-38.pyc (new file; binary file not shown)
BIN cotracker/utils/__pycache__/visualizer.cpython-39.pyc (new file; binary file not shown)
demo.py (3 changes)
@@ -83,11 +83,12 @@ if __name__ == "__main__":
     print("computed")
 
     # save a video with predicted tracks
-    seq_name = args.video_path.split("/")[-1]
+    seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
     vis.visualize(
         video,
         pred_tracks,
         pred_visibility,
         query_frame=0 if args.backward_tracking else args.grid_query_frame,
+        filename=seq_name,
     )
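The seq_name change strips the extension before handing the name to the Visualizer, which appends its own .mp4 suffix when saving (so the old code presumably produced names like apple.mp4.mp4). A quick check of the difference:

```python
import os

video_path = "./assets/apple.mp4"
old = video_path.split("/")[-1]                       # 'apple.mp4'
new = os.path.splitext(video_path.split("/")[-1])[0]  # 'apple'
print(old, new)
```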
demo1.py (new file, +82)
@@ -0,0 +1,82 @@
+import os
+import torch
+
+from cotracker.utils.visualizer import Visualizer, read_video_from_path
+import numpy as np
+from PIL import Image
+import time
+
+device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
+
+start_time = time.time()
+print(f'Using device: {device}')
+print('start loading video')
+video = read_video_from_path('./assets/F1_shorts.mp4')
+print(f'video shape: {video.shape}')
+# video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float().to(device)
+video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
+end_time = time.time()
+print(f'video shape after permute: {video.shape}')
+print("Load video Time taken: {:.2f} seconds".format(end_time - start_time))
+
+from cotracker.predictor import CoTrackerPredictor
+
+model = CoTrackerPredictor(
+    checkpoint=os.path.join('./checkpoints/cotracker2.pth')
+)
+
+# pred_tracks, pred_visibility = model(video, grid_size=30)
+# vis = Visualizer(save_dir='./videos', pad_value=100)
+# vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility, filename='teaser')
+
+grid_query_frame = 20
+
+import torch.nn.functional as F
+# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
+interp_size = (720, 1280)  # (height, width)
+video_interp = F.interpolate(video[0], [interp_size[0], interp_size[1]], mode="bilinear")[None].to(device)
+print(f'video_interp shape: {video_interp.shape}')
+
+start_time = time.time()
+input_mask = './assets/F1_mask.png'
+segm_mask = Image.open(input_mask)
+interp_size = (interp_size[1], interp_size[0])  # PIL resize expects (width, height)
+segm_mask = segm_mask.resize(interp_size, Image.BILINEAR)
+segm_mask = np.array(segm_mask)  # keep the resized mask (reloading the file here would discard the resize)
+segm_mask = torch.tensor(segm_mask).to(device)
+# pred_tracks, pred_visibility = model(video,
+pred_tracks, pred_visibility = model(video_interp,
+                                     grid_query_frame=grid_query_frame, backward_tracking=True,
+                                     segm_mask=segm_mask)
+end_time = time.time()
+
+print("Time taken: {:.2f} seconds".format(end_time - start_time))
+
+start_time = time.time()
+print('start visualizing')
+vis = Visualizer(
+    save_dir='./videos',
+    pad_value=20,
+    linewidth=1,
+    mode='optical_flow'
+)
+print('vis initialized')
+end_time = time.time()
+print("Time taken: {:.2f} seconds".format(end_time - start_time))
+start_time = time.time()
+print('start visualize')
+vis.visualize(
+    video=video_interp,
+    # video=video,
+    tracks=pred_tracks,
+    visibility=pred_visibility,
+    filename='dense2')
+print('done')
+end_time = time.time()
+print("Time taken: {:.2f} seconds".format(end_time - start_time))
@@ -1,3 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
 import os
 import torch
 import gradio as gr

@@ -22,7 +29,12 @@ def cotracker_demo(
         model = model.cuda()
         load_video = load_video.cuda()
 
-        model(video_chunk=load_video, is_first_step=True, grid_size=grid_size)
+        model(
+            video_chunk=load_video,
+            is_first_step=True,
+            grid_size=grid_size,
+            grid_query_frame=grid_query_frame,
+        )
         for ind in range(0, load_video.shape[1] - model.step, model.step):
             pred_tracks, pred_visibility = model(
                 video_chunk=load_video[:, ind : ind + model.step * 2]
File diff suppressed because one or more lines are too long
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 import torch
 import argparse
 import imageio.v3 as iio
@@ -44,6 +45,9 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
+    if not os.path.isfile(args.video_path):
+        raise ValueError("Video file does not exist")
+
     if args.checkpoint is not None:
         model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
     else:
@@ -52,25 +56,33 @@ if __name__ == "__main__":
 
     window_frames = []
 
-    def _process_step(window_frames, is_first_step, grid_size):
+    def _process_step(window_frames, is_first_step, grid_size, grid_query_frame):
         video_chunk = (
             torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
             .float()
             .permute(0, 3, 1, 2)[None]
         )  # (1, T, 3, H, W)
-        return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)
+        return model(
+            video_chunk,
+            is_first_step=is_first_step,
+            grid_size=grid_size,
+            grid_query_frame=grid_query_frame,
+        )
 
     # Iterating over video frames, processing one window at a time:
     is_first_step = True
     for i, frame in enumerate(
         iio.imiter(
-            "https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4",
+            args.video_path,
             plugin="FFMPEG",
         )
     ):
         if i % model.step == 0 and i != 0:
             pred_tracks, pred_visibility = _process_step(
-                window_frames, is_first_step, grid_size=args.grid_size
+                window_frames,
+                is_first_step,
+                grid_size=args.grid_size,
+                grid_query_frame=args.grid_query_frame,
             )
             is_first_step = False
         window_frames.append(frame)
@@ -79,12 +91,13 @@ if __name__ == "__main__":
             window_frames[-(i % model.step) - model.step - 1 :],
             is_first_step,
             grid_size=args.grid_size,
+            grid_query_frame=args.grid_query_frame,
         )
 
     print("Tracks are computed")
 
     # save a video with predicted tracks
-    seq_name = args.video_path.split("/")[-1]
+    seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
     video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
-    vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
+    vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame, filename=seq_name)
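With these changes the online demo reads a local file via args.video_path instead of the hard-coded GitHub URL, and respects --grid_query_frame throughout. A typical invocation (flag names inferred from the args referenced in this diff):

```bash
python online_demo.py \
    --checkpoint checkpoints/cotracker2.pth \
    --video_path assets/apple.mp4 \
    --grid_size 10 \
    --grid_query_frame 0
```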