add some comments
This commit is contained in:
BIN
cotracker/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/__pycache__/predictor.cpython-38.pyc
Normal file
BIN
cotracker/__pycache__/predictor.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/__pycache__/predictor.cpython-39.pyc
Normal file
BIN
cotracker/__pycache__/predictor.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/models/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/models/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/build_cotracker.cpython-38.pyc
Normal file
BIN
cotracker/models/__pycache__/build_cotracker.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/build_cotracker.cpython-39.pyc
Normal file
BIN
cotracker/models/__pycache__/build_cotracker.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/models/core/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/models/core/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/embeddings.cpython-38.pyc
Normal file
BIN
cotracker/models/core/__pycache__/embeddings.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/embeddings.cpython-39.pyc
Normal file
BIN
cotracker/models/core/__pycache__/embeddings.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/model_utils.cpython-38.pyc
Normal file
BIN
cotracker/models/core/__pycache__/model_utils.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/model_utils.cpython-39.pyc
Normal file
BIN
cotracker/models/core/__pycache__/model_utils.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -191,6 +191,7 @@ class BasicEncoder(nn.Module):
|
|||||||
x = self.norm1(x)
|
x = self.norm1(x)
|
||||||
x = self.relu1(x)
|
x = self.relu1(x)
|
||||||
|
|
||||||
|
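# Four residual-block stages (layer1 .. layer4, per the original note; only the first three calls are visible in this hunk)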
|
||||||
a = self.layer1(x)
|
a = self.layer1(x)
|
||||||
b = self.layer2(a)
|
b = self.layer2(a)
|
||||||
c = self.layer3(b)
|
c = self.layer3(b)
|
||||||
|
@@ -41,6 +41,7 @@ class CoTracker2(nn.Module):
|
|||||||
self.hidden_dim = 256
|
self.hidden_dim = 256
|
||||||
self.latent_dim = 128
|
self.latent_dim = 128
|
||||||
self.add_space_attn = add_space_attn
|
self.add_space_attn = add_space_attn
|
||||||
|
|
||||||
self.fnet = BasicEncoder(output_dim=self.latent_dim)
|
self.fnet = BasicEncoder(output_dim=self.latent_dim)
|
||||||
self.num_virtual_tracks = num_virtual_tracks
|
self.num_virtual_tracks = num_virtual_tracks
|
||||||
self.model_resolution = model_resolution
|
self.model_resolution = model_resolution
|
||||||
@@ -107,6 +108,7 @@ class CoTracker2(nn.Module):
|
|||||||
B, S_init, N, __ = track_mask.shape
|
B, S_init, N, __ = track_mask.shape
|
||||||
B, S, *_ = fmaps.shape
|
B, S, *_ = fmaps.shape
|
||||||
|
|
||||||
|
# Zero-pad track_mask along the frame axis so it has as many frames (S) as the feature maps.
|
||||||
track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
|
track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
|
||||||
track_mask_vis = (
|
track_mask_vis = (
|
||||||
torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
||||||
@@ -171,6 +173,7 @@ class CoTracker2(nn.Module):
|
|||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
|
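# Sample the feature maps at sample_coords (bilinear sampling, per the original note; see sample_features5d to confirm)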
|
||||||
sample_track_feats = sample_features5d(fmaps, sample_coords)
|
sample_track_feats = sample_features5d(fmaps, sample_coords)
|
||||||
return sample_track_feats
|
return sample_track_feats
|
||||||
|
|
||||||
@@ -227,22 +230,24 @@ class CoTracker2(nn.Module):
|
|||||||
|
|
||||||
# The first channel is the frame number
|
# The first channel is the frame number
|
||||||
# The rest are the coordinates of points we want to track
|
# The rest are the coordinates of points we want to track
|
||||||
queried_frames = queries[:, :, 0].long()
|
queried_frames = queries[:, :, 0].long() # 获取帧数字
|
||||||
|
|
||||||
queried_coords = queries[..., 1:]
|
queried_coords = queries[..., 1:]
|
||||||
queried_coords = queried_coords / self.stride
|
queried_coords = queried_coords / self.stride # 缩放
|
||||||
|
|
||||||
# We store our predictions here
|
# We store our predictions here
|
||||||
coords_predicted = torch.zeros((B, T, N, 2), device=device)
|
coords_predicted = torch.zeros((B, T, N, 2), device=device) # 等待处理的预测的点
|
||||||
vis_predicted = torch.zeros((B, T, N), device=device)
|
vis_predicted = torch.zeros((B, T, N), device=device)
|
||||||
if is_online:
|
if is_online:
|
||||||
|
# Online mode, first call: the stored predictions are initialised with the all-zero coords and vis tensors created above.
|
||||||
|
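# Online mode, later calls: the stored predictions are zero-padded by `pad` frames along the time axis for the current window.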
|
||||||
if self.online_coords_predicted is None:
|
if self.online_coords_predicted is None:
|
||||||
# Init online predictions with zeros
|
# Init online predictions with zeros
|
||||||
self.online_coords_predicted = coords_predicted
|
self.online_coords_predicted = coords_predicted
|
||||||
self.online_vis_predicted = vis_predicted
|
self.online_vis_predicted = vis_predicted
|
||||||
else:
|
else:
|
||||||
# Pad online predictions with zeros for the current window
|
# Pad online predictions with zeros for the current window
|
||||||
pad = min(step, T - step)
|
pad = min(step, T - step) # 确保填充量不会超过剩余的时间帧数
|
||||||
coords_predicted = F.pad(
|
coords_predicted = F.pad(
|
||||||
self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
|
self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
|
||||||
)
|
)
|
||||||
@@ -250,19 +255,24 @@ class CoTracker2(nn.Module):
|
|||||||
all_coords_predictions, all_vis_predictions = [], []
|
all_coords_predictions, all_vis_predictions = [], []
|
||||||
|
|
||||||
# Pad the video so that an integer number of sliding windows fit into it
|
# Pad the video so that an integer number of sliding windows fit into it
|
||||||
|
# 填充视频,使得一个整数的滑动窗口能够适应它
|
||||||
# TODO: we may drop this requirement because the transformer should not care
|
# TODO: we may drop this requirement because the transformer should not care
|
||||||
# TODO: pad the features instead of the video
|
# TODO: pad the features instead of the video
|
||||||
|
# Number of frames to append: S - T in online mode, otherwise just enough to make T a multiple of S.
|
||||||
pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0
|
pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0
|
||||||
|
# Replicate-pad along the time axis: the last frame is repeated `pad` times.
|
||||||
video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
|
video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
|
||||||
B, -1, C, H, W
|
B, -1, C, H, W
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute convolutional features for the video or for the current chunk in case of online mode
|
# Compute convolutional features for the video or for the current chunk in case of online mode
|
||||||
|
# 计算视频的卷积特征或者是在线计算当前的块
|
||||||
fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
|
fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
|
||||||
B, -1, self.latent_dim, H // self.stride, W // self.stride
|
B, -1, self.latent_dim, H // self.stride, W // self.stride
|
||||||
)
|
)
|
||||||
|
|
||||||
# We compute track features
|
# We compute track features
|
||||||
|
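# get_track_feat samples the feature maps at the queried locations (bilinear sampling internally, per the original note; confirm in get_track_feat)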
|
||||||
track_feat = self.get_track_feat(
|
track_feat = self.get_track_feat(
|
||||||
fmaps,
|
fmaps,
|
||||||
queried_frames - self.online_ind if is_online else queried_frames,
|
queried_frames - self.online_ind if is_online else queried_frames,
|
||||||
@@ -284,14 +294,17 @@ class CoTracker2(nn.Module):
|
|||||||
# We process only the current video chunk in the online mode
|
# We process only the current video chunk in the online mode
|
||||||
indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
|
indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
|
||||||
|
|
||||||
|
# Broadcast the queried coordinates across all S frames of a window to form the initial coordinates.
|
||||||
coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
|
coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
|
||||||
vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
|
vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
|
||||||
for ind in indices:
|
for ind in indices:
|
||||||
# We copy over coords and vis for tracks that are queried
|
# We copy over coords and vis for tracks that are queried
|
||||||
# by the end of the previous window, which is ind + overlap
|
# by the end of the previous window, which is ind + overlap
|
||||||
|
# Handle the frames that overlap with the previous window.
|
||||||
if ind > 0:
|
if ind > 0:
|
||||||
overlap = S - step
|
overlap = S - step
|
||||||
copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1
|
copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1
|
||||||
|
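# Reuse the previous window's predictions for the overlapping frames (divided by the stride), padded by `step` frames.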
|
||||||
coords_prev = torch.nn.functional.pad(
|
coords_prev = torch.nn.functional.pad(
|
||||||
coords_predicted[:, ind : ind + overlap] / self.stride,
|
coords_predicted[:, ind : ind + overlap] / self.stride,
|
||||||
(0, 0, 0, 0, 0, step),
|
(0, 0, 0, 0, 0, step),
|
||||||
@@ -304,16 +317,18 @@ class CoTracker2(nn.Module):
|
|||||||
) # B S N 1
|
) # B S N 1
|
||||||
coords_init = torch.where(
|
coords_init = torch.where(
|
||||||
copy_over.expand_as(coords_init), coords_prev, coords_init
|
copy_over.expand_as(coords_init), coords_prev, coords_init
|
||||||
)
|
)# True就是coords_prev, False 就是coords_init
|
||||||
vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
|
vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
|
||||||
|
|
||||||
# The attention mask is 1 for the spatio-temporal points within
|
# The attention mask is 1 for the spatio-temporal points within
|
||||||
# a track which is updated in the current window
|
# a track which is updated in the current window
|
||||||
|
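# Concretely: True for every track whose query frame is earlier than the end of this window (ind + S), repeated over the S frames.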
|
||||||
attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N
|
attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N
|
||||||
|
|
||||||
# The track mask is 1 for the spatio-temporal points that actually
|
# The track mask is 1 for the spatio-temporal points that actually
|
||||||
# need updating: only after begin queried, and not if contained
|
# need updating: only after begin queried, and not if contained
|
||||||
# in a previous window
|
# in a previous window
|
||||||
|
# Concretely: track_mask is True at window frames that are at or after each track's query frame.
|
||||||
track_mask = (
|
track_mask = (
|
||||||
queried_frames[:, None, :, None]
|
queried_frames[:, None, :, None]
|
||||||
<= torch.arange(ind, ind + S, device=device)[None, :, None, None]
|
<= torch.arange(ind, ind + S, device=device)[None, :, None, None]
|
||||||
@@ -323,6 +338,7 @@ class CoTracker2(nn.Module):
|
|||||||
track_mask[:, :overlap, :, :] = False
|
track_mask[:, :overlap, :, :] = False
|
||||||
|
|
||||||
# Predict the coordinates and visibility for the current window
|
# Predict the coordinates and visibility for the current window
|
||||||
|
# Run forward_window to obtain the updated coords and vis for this window.
|
||||||
coords, vis = self.forward_window(
|
coords, vis = self.forward_window(
|
||||||
fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
|
fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
|
||||||
coords=coords_init,
|
coords=coords_init,
|
||||||
|
@@ -56,6 +56,10 @@ class CoTrackerPredictor(torch.nn.Module):
|
|||||||
|
|
||||||
return tracks, visibilities
|
return tracks, visibilities
|
||||||
|
|
||||||
|
# gpu dense inference time
|
||||||
|
# raft gpu comparison
|
||||||
|
# vision effects
|
||||||
|
# raft integrated
|
||||||
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
|
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
|
||||||
*_, H, W = video.shape
|
*_, H, W = video.shape
|
||||||
grid_step = W // grid_size
|
grid_step = W // grid_size
|
||||||
|
BIN
cotracker/utils/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/utils/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/utils/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/utils/__pycache__/visualizer.cpython-38.pyc
Normal file
BIN
cotracker/utils/__pycache__/visualizer.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/utils/__pycache__/visualizer.cpython-39.pyc
Normal file
BIN
cotracker/utils/__pycache__/visualizer.cpython-39.pyc
Normal file
Binary file not shown.
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user