add some comments

2024-08-05 23:36:58 +02:00
parent 36d1566750
commit 6e7bcd2d26
55 changed files with 3946 additions and 4095 deletions

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -1,166 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import dataclasses
import numpy as np
from dataclasses import Field, MISSING
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
_X = TypeVar("_X")
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
"""
Loads a @dataclass or a collection hierarchy that includes dataclasses
from a json file, recursively.
Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
Raises KeyError if the json has keys that do not map to the dataclass fields.
Args:
f: A file handle opened for reading.
cls: The class of the loaded dataclass.
binary: Set to True if `f` is a file handle, else False.
"""
if binary:
asdict = json.loads(f.read().decode("utf8"))
else:
asdict = json.load(f)
# `cls` is expected to be a List[...] annotation; unwrap to the element type and run the faster "vectorized" converter
cls = get_args(cls)[0]
res = list(_dataclass_list_from_dict_list(asdict, cls))
return res
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
if get_origin(type_) is Union:
args = get_args(type_)
if len(args) == 2 and args[1] == type(None): # noqa E721
return True, args[0]
if type_ is Any:
return True, Any
return False, type_
def _unwrap_type(tp):
# strips Optional wrapper, if any
if get_origin(tp) is Union:
args = get_args(tp)
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
# this is typing.Optional
return args[0] if args[1] is type(None) else args[1] # noqa: E721
return tp
def _get_dataclass_field_default(field: Field) -> Any:
if field.default_factory is not MISSING:
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
# dataclasses._DefaultFactory[typing.Any]]` is not a function.
return field.default_factory()
elif field.default is not MISSING:
return field.default
else:
return None
def _dataclass_list_from_dict_list(dlist, typeannot):
"""
Vectorised version of `_dataclass_from_dict`.
The output should be equivalent to
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
Args:
dlist: list of objects to convert.
typeannot: type of each of those objects.
Returns:
iterator or list over converted objects of the same length as `dlist`.
Raises:
ValueError: it assumes the objects have None's in consistent places across
objects, otherwise it would ignore some values. This generally holds for
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
"""
cls = get_origin(typeannot) or typeannot
if typeannot is Any:
return dlist
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
return dlist
if any(obj is None for obj in dlist):
# filter out Nones and recurse on the resulting list
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
idx, notnone = zip(*idx_notnone)
converted = _dataclass_list_from_dict_list(notnone, typeannot)
res = [None] * len(dlist)
for i, obj in zip(idx, converted):
res[i] = obj
return res
is_optional, contained_type = _resolve_optional(typeannot)
if is_optional:
return _dataclass_list_from_dict_list(dlist, contained_type)
# otherwise, we dispatch by the type of the provided annotation to convert to
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
# For namedtuple, call the function recursively on the lists of corresponding keys
types = cls.__annotations__.values()
dlist_T = zip(*dlist)
res_T = [
_dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
]
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, (list, tuple)):
# For list/tuple, call the function recursively on the lists of corresponding positions
types = get_args(typeannot)
if len(types) == 1: # probably List; replicate for all items
types = types * len(dlist[0])
dlist_T = zip(*dlist)
res_T = (
_dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
)
if issubclass(cls, tuple):
return list(zip(*res_T))
else:
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, dict):
# For the dictionary, call the function recursively on concatenated keys and vertices
key_t, val_t = get_args(typeannot)
all_keys_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.keys()], key_t
)
all_vals_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.values()], val_t
)
indices = np.cumsum([len(obj) for obj in dlist])
assert indices[-1] == len(all_keys_res)
keys = np.split(list(all_keys_res), indices[:-1])
all_vals_res_iter = iter(all_vals_res)
return [cls(zip(k, all_vals_res_iter)) for k in keys]
elif not dataclasses.is_dataclass(typeannot):
return dlist
# dataclass node: 2nd recursion base; call the function recursively on the lists
# of the corresponding fields
assert dataclasses.is_dataclass(cls)
fieldtypes = {
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
for f in dataclasses.fields(typeannot)
}
# NOTE the default object is shared here
key_lists = (
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
for k, (type_, default) in fieldtypes.items()
)
transposed = zip(*key_lists)
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
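
For reference, a minimal usage sketch for `load_dataclass` (illustrative only, not part of this commit; the `FrameAnnotation` class and the file path below are hypothetical placeholders):

import gzip
from dataclasses import dataclass
from typing import List, Optional

from cotracker.datasets.dataclass_utils import load_dataclass

@dataclass
class FrameAnnotation:
    # hypothetical schema matching the json being loaded
    sequence_name: str
    frame_number: int
    camera_name: Optional[str] = None

# gzipped json opened in text mode, so the default binary=False applies
with gzip.open("frame_annotations_valid.jgz", "rt", encoding="utf8") as f:
    annots = load_dataclass(f, List[FrameAnnotation])
print(len(annots), annots[0].sequence_name)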

View File

@@ -1,161 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple
from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass
@dataclass
class ImageAnnotation:
# path to jpg file, relative w.r.t. dataset_root
path: str
# H x W
size: Tuple[int, int]
@dataclass
class DynamicReplicaFrameAnnotation:
"""A dataclass used to load annotations from json."""
# can be used to join with `SequenceAnnotation`
sequence_name: str
# 0-based, continuous frame number within sequence
frame_number: int
# timestamp in seconds from the video start
frame_timestamp: float
image: ImageAnnotation
meta: Optional[Dict[str, Any]] = None
camera_name: Optional[str] = None
trajectories: Optional[str] = None
class DynamicReplicaDataset(data.Dataset):
def __init__(
self,
root,
split="valid",
traj_per_sample=256,
crop_size=None,
sample_len=-1,
only_first_n_samples=-1,
rgbd_input=False,
):
super(DynamicReplicaDataset, self).__init__()
self.root = root
self.sample_len = sample_len
self.split = split
self.traj_per_sample = traj_per_sample
self.rgbd_input = rgbd_input
self.crop_size = crop_size
frame_annotations_file = f"frame_annotations_{split}.jgz"
self.sample_list = []
with gzip.open(
os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
) as zipfile:
frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
seq_annot = defaultdict(list)
for frame_annot in frame_annots_list:
if frame_annot.camera_name == "left":
seq_annot[frame_annot.sequence_name].append(frame_annot)
for seq_name in seq_annot.keys():
seq_len = len(seq_annot[seq_name])
step = self.sample_len if self.sample_len > 0 else seq_len
counter = 0
for ref_idx in range(0, seq_len, step):
sample = seq_annot[seq_name][ref_idx : ref_idx + step]
self.sample_list.append(sample)
counter += 1
if only_first_n_samples > 0 and counter >= only_first_n_samples:
break
def __len__(self):
return len(self.sample_list)
def crop(self, rgbs, trajs):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
H_new = H
W_new = W
# deterministic center crop
y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
trajs[:, :, 0] -= x0
trajs[:, :, 1] -= y0
return rgbs, trajs
def __getitem__(self, index):
sample = self.sample_list[index]
T = len(sample)
rgbs, visibilities, traj_2d = [], [], []
H, W = sample[0].image.size
image_size = (H, W)
for i in range(T):
traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
traj = torch.load(traj_path)
visibilities.append(traj["verts_inds_vis"].numpy())
rgbs.append(traj["img"].numpy())
traj_2d.append(traj["traj_2d"].numpy()[..., :2])
traj_2d = np.stack(traj_2d)
visibility = np.stack(visibilities)
T, N, D = traj_2d.shape
# subsample trajectories for augmentations
visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
traj_2d = traj_2d[:, visible_inds_sampled]
visibility = visibility[:, visible_inds_sampled]
if self.crop_size is not None:
rgbs, traj_2d = self.crop(rgbs, traj_2d)
H, W, _ = rgbs[0].shape
image_size = self.crop_size
visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
visibility[traj_2d[:, :, 0] < 0] = False
visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
visibility[traj_2d[:, :, 1] < 0] = False
# filter out points that are visible in fewer than 10 frames
visible_inds_resampled = visibility.sum(0) > 10
traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
rgbs = np.stack(rgbs, 0)
video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
return CoTrackerData(
video=video,
trajectory=traj_2d,
visibility=visibility,
valid=torch.ones(T, N),
seq_name=sample[0].sequence_name,
)
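
For reference, a hedged usage sketch for `DynamicReplicaDataset` (illustrative only, not part of this commit; the dataset root is a placeholder path and assumes the `frame_annotations_{split}.jgz` layout expected by the constructor):

dataset = DynamicReplicaDataset(
    root="/path/to/dynamic_replica",  # placeholder path
    split="valid",
    sample_len=24,        # split each sequence into 24-frame clips
    traj_per_sample=256,  # keep at most 256 tracks per clip
)
sample = dataset[0]             # a CoTrackerData instance
print(sample.video.shape)       # (T, 3, H, W)
print(sample.trajectory.shape)  # (T, N, 2) after visibility filtering
print(sample.visibility.shape)  # (T, N)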

View File

@@ -1,441 +1,441 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import cv2
import imageio
import numpy as np
from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image
class CoTrackerDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
crop_size=(384, 512),
seq_len=24,
traj_per_sample=768,
sample_vis_1st_frame=False,
use_augs=False,
):
super(CoTrackerDataset, self).__init__()
np.random.seed(0)
torch.manual_seed(0)
self.data_root = data_root
self.seq_len = seq_len
self.traj_per_sample = traj_per_sample
self.sample_vis_1st_frame = sample_vis_1st_frame
self.use_augs = use_augs
self.crop_size = crop_size
# photometric augmentation
self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
self.blur_aug_prob = 0.25
self.color_aug_prob = 0.25
# occlusion augmentation
self.eraser_aug_prob = 0.5
self.eraser_bounds = [2, 100]
self.eraser_max = 10
# occlusion augmentation
self.replace_aug_prob = 0.5
self.replace_bounds = [2, 100]
self.replace_max = 10
# spatial augmentations
self.pad_bounds = [0, 100]
self.crop_size = crop_size
self.resize_lim = [0.25, 2.0] # sample resizes from here
self.resize_delta = 0.2
self.max_crop_offset = 50
self.do_flip = True
self.h_flip_prob = 0.5
self.v_flip_prob = 0.5
def getitem_helper(self, index):
raise NotImplementedError
def __getitem__(self, index):
gotit = False
sample, gotit = self.getitem_helper(index)
if not gotit:
print("warning: sampling failed")
# fake sample, so we can still collate
sample = CoTrackerData(
video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
)
return sample, gotit
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
if eraser:
############ eraser transform (per image after the first) ############
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
for i in range(1, S):
if np.random.rand() < self.eraser_aug_prob:
for _ in range(
np.random.randint(1, self.eraser_max + 1)
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
rgbs[i][y0:y1, x0:x1, :] = mean_color
occ_inds = np.logical_and(
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
)
visibles[i, occ_inds] = 0
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
if replace:
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
]
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
]
############ replace transform (per image after the first) ############
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
for i in range(1, S):
if np.random.rand() < self.replace_aug_prob:
for _ in range(
np.random.randint(1, self.replace_max + 1)
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
wid = x1 - x0
hei = y1 - y0
y00 = np.random.randint(0, H - hei)
x00 = np.random.randint(0, W - wid)
fr = np.random.randint(0, S)
rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
rgbs[i][y0:y1, x0:x1, :] = rep
occ_inds = np.logical_and(
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
)
visibles[i, occ_inds] = 0
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
############ photometric augmentation ############
if np.random.rand() < self.color_aug_prob:
# random per-frame amount of aug
rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
if np.random.rand() < self.blur_aug_prob:
# random per-frame amount of blur
rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
return rgbs, trajs, visibles
def add_spatial_augs(self, rgbs, trajs, visibles):
T, N, __ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
############ spatial transform ############
# padding
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
trajs[:, :, 0] += pad_x0
trajs[:, :, 1] += pad_y0
H, W = rgbs[0].shape[:2]
# scaling + stretching
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
scale_x = scale
scale_y = scale
H_new = H
W_new = W
scale_delta_x = 0.0
scale_delta_y = 0.0
rgbs_scaled = []
for s in range(S):
if s == 1:
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
elif s > 1:
scale_delta_x = (
scale_delta_x * 0.8
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
)
scale_delta_y = (
scale_delta_y * 0.8
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
)
scale_x = scale_x + scale_delta_x
scale_y = scale_y + scale_delta_y
# bring h/w closer
scale_xy = (scale_x + scale_y) * 0.5
scale_x = scale_x * 0.5 + scale_xy * 0.5
scale_y = scale_y * 0.5 + scale_xy * 0.5
# don't get too crazy
scale_x = np.clip(scale_x, 0.2, 2.0)
scale_y = np.clip(scale_y, 0.2, 2.0)
H_new = int(H * scale_y)
W_new = int(W * scale_x)
# make it at least slightly bigger than the crop area,
# so that the random cropping can add diversity
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
# recompute scale in case we clipped
scale_x = (W_new - 1) / float(W - 1)
scale_y = (H_new - 1) / float(H - 1)
rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
trajs[s, :, 0] *= scale_x
trajs[s, :, 1] *= scale_y
rgbs = rgbs_scaled
ok_inds = visibles[0, :] > 0
vis_trajs = trajs[:, ok_inds] # S,?,2
if vis_trajs.shape[1] > 0:
mid_x = np.mean(vis_trajs[0, :, 0])
mid_y = np.mean(vis_trajs[0, :, 1])
else:
mid_y = self.crop_size[0]
mid_x = self.crop_size[1]
x0 = int(mid_x - self.crop_size[1] // 2)
y0 = int(mid_y - self.crop_size[0] // 2)
offset_x = 0
offset_y = 0
for s in range(S):
# on each frame, shift a bit more
if s == 1:
offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
elif s > 1:
offset_x = int(
offset_x * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
)
offset_y = int(
offset_y * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
)
x0 = x0 + offset_x
y0 = y0 + offset_y
H_new, W_new = rgbs[s].shape[:2]
if H_new == self.crop_size[0]:
y0 = 0
else:
y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
if W_new == self.crop_size[1]:
x0 = 0
else:
x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
trajs[s, :, 0] -= x0
trajs[s, :, 1] -= y0
H_new = self.crop_size[0]
W_new = self.crop_size[1]
# flip
h_flipped = False
v_flipped = False
if self.do_flip:
# h flip
if np.random.rand() < self.h_flip_prob:
h_flipped = True
rgbs = [rgb[:, ::-1] for rgb in rgbs]
# v flip
if np.random.rand() < self.v_flip_prob:
v_flipped = True
rgbs = [rgb[::-1] for rgb in rgbs]
if h_flipped:
trajs[:, :, 0] = W_new - trajs[:, :, 0]
if v_flipped:
trajs[:, :, 1] = H_new - trajs[:, :, 1]
return rgbs, trajs
def crop(self, rgbs, trajs):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
############ spatial transform ############
H_new = H
W_new = W
# simple random crop
y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
trajs[:, :, 0] -= x0
trajs[:, :, 1] -= y0
return rgbs, trajs
class KubricMovifDataset(CoTrackerDataset):
def __init__(
self,
data_root,
crop_size=(384, 512),
seq_len=24,
traj_per_sample=768,
sample_vis_1st_frame=False,
use_augs=False,
):
super(KubricMovifDataset, self).__init__(
data_root=data_root,
crop_size=crop_size,
seq_len=seq_len,
traj_per_sample=traj_per_sample,
sample_vis_1st_frame=sample_vis_1st_frame,
use_augs=use_augs,
)
self.pad_bounds = [0, 25]
self.resize_lim = [0.75, 1.25] # sample resizes from here
self.resize_delta = 0.05
self.max_crop_offset = 15
self.seq_names = [
fname
for fname in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, fname))
]
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
def getitem_helper(self, index):
gotit = True
seq_name = self.seq_names[index]
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
rgb_path = os.path.join(self.data_root, seq_name, "frames")
img_paths = sorted(os.listdir(rgb_path))
rgbs = []
for i, img_path in enumerate(img_paths):
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
rgbs = np.stack(rgbs)
annot_dict = np.load(npy_path, allow_pickle=True).item()
traj_2d = annot_dict["coords"]
visibility = annot_dict["visibility"]
# random crop
assert self.seq_len <= len(rgbs)
if self.seq_len < len(rgbs):
start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
rgbs = rgbs[start_ind : start_ind + self.seq_len]
traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
visibility = visibility[:, start_ind : start_ind + self.seq_len]
traj_2d = np.transpose(traj_2d, (1, 0, 2))
visibility = np.transpose(np.logical_not(visibility), (1, 0))
if self.use_augs:
rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
else:
rgbs, traj_2d = self.crop(rgbs, traj_2d)
visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
visibility[traj_2d[:, :, 0] < 0] = False
visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
visibility[traj_2d[:, :, 1] < 0] = False
visibility = torch.from_numpy(visibility)
traj_2d = torch.from_numpy(traj_2d)
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
if self.sample_vis_1st_frame:
visibile_pts_inds = visibile_pts_first_frame_inds
else:
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
:, 0
]
visibile_pts_inds = torch.cat(
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
)
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
if len(point_inds) < self.traj_per_sample:
gotit = False
visible_inds_sampled = visibile_pts_inds[point_inds]
trajs = traj_2d[:, visible_inds_sampled].float()
visibles = visibility[:, visible_inds_sampled]
valids = torch.ones((self.seq_len, self.traj_per_sample))
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
sample = CoTrackerData(
video=rgbs,
trajectory=trajs,
visibility=visibles,
valid=valids,
seq_name=seq_name,
)
return sample, gotit
def __len__(self):
return len(self.seq_names)
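
For reference, a hedged usage sketch for `KubricMovifDataset` (illustrative only, not part of this commit; the data root is a placeholder). `getitem_helper` returns a `(sample, gotit)` pair, and `__getitem__` substitutes a zero-filled placeholder sample whenever fewer than `traj_per_sample` visible tracks are found:

train_dataset = KubricMovifDataset(
    data_root="/path/to/kubric_movi_f",  # placeholder path
    crop_size=(384, 512),
    seq_len=24,
    traj_per_sample=768,
    use_augs=True,  # enable photometric + spatial augmentations
)
sample, gotit = train_dataset[0]
if gotit:
    print(sample.video.shape)       # (24, 3, 384, 512)
    print(sample.trajectory.shape)  # (24, 768, 2)
else:
    print("not enough visible tracks; got a zero-filled placeholder sample")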

View File

@@ -1,209 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import io
import glob
import torch
import pickle
import numpy as np
import mediapy as media
from PIL import Image
from typing import Mapping, Tuple, Union
from cotracker.datasets.utils import CoTrackerData
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""Resize a video to output_size."""
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
# such as a jitted jax.image.resize. It will make things faster.
return media.resize_video(video, output_size)
def sample_queries_first(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, use the first
visible point in each track as the query.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1]
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1]
"""
valid = np.sum(~target_occluded, axis=1) > 0
target_points = target_points[valid, :]
target_occluded = target_occluded[valid, :]
query_points = []
for i in range(target_points.shape[0]):
index = np.where(target_occluded[i] == 0)[0][0]
x, y = target_points[i, index, 0], target_points[i, index, 1]
query_points.append(np.array([index, y, x])) # [t, y, x]
query_points = np.stack(query_points, axis=0)
return {
"video": frames[np.newaxis, ...],
"query_points": query_points[np.newaxis, ...],
"target_points": target_points[np.newaxis, ...],
"occluded": target_occluded[np.newaxis, ...],
}
def sample_queries_strided(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
query_stride: int = 5,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, sample queries
strided every query_stride frames, ignoring points that are not visible
at the selected frames.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
query_stride: When sampling query points, search for un-occluded points
every query_stride frames and convert each one into a query.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
has floats scaled to the range [-1, 1].
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1].
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1].
trackgroup: Index of the original track that each query point was
sampled from. This is useful for visualization.
"""
tracks = []
occs = []
queries = []
trackgroups = []
total = 0
trackgroup = np.arange(target_occluded.shape[0])
for i in range(0, target_occluded.shape[1], query_stride):
mask = target_occluded[:, i] == 0
query = np.stack(
[
i * np.ones(target_occluded.shape[0:1]),
target_points[:, i, 1],
target_points[:, i, 0],
],
axis=-1,
)
queries.append(query[mask])
tracks.append(target_points[mask])
occs.append(target_occluded[mask])
trackgroups.append(trackgroup[mask])
total += np.array(np.sum(target_occluded[:, i] == 0))
return {
"video": frames[np.newaxis, ...],
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
}
class TapVidDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
dataset_type="davis",
resize_to_256=True,
queried_first=True,
):
self.dataset_type = dataset_type
self.resize_to_256 = resize_to_256
self.queried_first = queried_first
if self.dataset_type == "kinetics":
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
points_dataset = []
for pickle_path in all_paths:
with open(pickle_path, "rb") as f:
data = pickle.load(f)
points_dataset = points_dataset + data
self.points_dataset = points_dataset
else:
with open(data_root, "rb") as f:
self.points_dataset = pickle.load(f)
if self.dataset_type == "davis":
self.video_names = list(self.points_dataset.keys())
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
def __getitem__(self, index):
if self.dataset_type == "davis":
video_name = self.video_names[index]
else:
video_name = index
video = self.points_dataset[video_name]
frames = video["video"]
if isinstance(frames[0], bytes):
# TAP-Vid frames are stored as JPEG bytes rather than `np.ndarray`s.
def decode(frame):
byteio = io.BytesIO(frame)
img = Image.open(byteio)
return np.array(img)
frames = np.array([decode(frame) for frame in frames])
target_points = self.points_dataset[video_name]["points"]
if self.resize_to_256:
frames = resize_video(frames, [256, 256])
target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
else:
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
target_occ = self.points_dataset[video_name]["occluded"]
if self.queried_first:
converted = sample_queries_first(target_occ, target_points, frames)
else:
converted = sample_queries_strided(target_occ, target_points, frames)
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
1, 0
) # T, N
query_points = torch.from_numpy(converted["query_points"])[0]  # N, 3 in (t, y, x)
return CoTrackerData(
rgbs,
trajs,
visibles,
seq_name=str(video_name),
query_points=query_points,
)
def __len__(self):
return len(self.points_dataset)
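
For reference, a hedged usage sketch for `TapVidDataset` on TAP-Vid DAVIS (illustrative only, not part of this commit; the pickle path is a placeholder):

dataset = TapVidDataset(
    data_root="/path/to/tapvid_davis/tapvid_davis.pkl",  # placeholder path
    dataset_type="davis",
    resize_to_256=True,
    queried_first=True,  # "query first" evaluation protocol
)
item = dataset[0]               # a CoTrackerData instance
print(item.video.shape)         # (T, 3, 256, 256)
print(item.trajectory.shape)    # (T, N, 2)
print(item.visibility.shape)    # (T, N)
print(item.query_points.shape)  # (N, 3) in (t, y, x)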
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import io
import glob
import torch
import pickle
import numpy as np
import mediapy as media
from PIL import Image
from typing import Mapping, Tuple, Union
from cotracker.datasets.utils import CoTrackerData
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""Resize a video to output_size."""
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
# such as a jitted jax.image.resize. It will make things faster.
return media.resize_video(video, output_size)
def sample_queries_first(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, use the first
visible point in each track as the query.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1]
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1]
"""
valid = np.sum(~target_occluded, axis=1) > 0
target_points = target_points[valid, :]
target_occluded = target_occluded[valid, :]
query_points = []
for i in range(target_points.shape[0]):
index = np.where(target_occluded[i] == 0)[0][0]
x, y = target_points[i, index, 0], target_points[i, index, 1]
query_points.append(np.array([index, y, x])) # [t, y, x]
query_points = np.stack(query_points, axis=0)
return {
"video": frames[np.newaxis, ...],
"query_points": query_points[np.newaxis, ...],
"target_points": target_points[np.newaxis, ...],
"occluded": target_occluded[np.newaxis, ...],
}
def sample_queries_strided(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
query_stride: int = 5,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, sample queries
strided every query_stride frames, ignoring points that are not visible
at the selected frames.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
query_stride: When sampling query points, search for un-occluded points
every query_stride frames and convert each one into a query.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
has floats scaled to the range [-1, 1].
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1].
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1].
trackgroup: Index of the original track that each query point was
sampled from. This is useful for visualization.
"""
tracks = []
occs = []
queries = []
trackgroups = []
total = 0
trackgroup = np.arange(target_occluded.shape[0])
for i in range(0, target_occluded.shape[1], query_stride):
mask = target_occluded[:, i] == 0
query = np.stack(
[
i * np.ones(target_occluded.shape[0:1]),
target_points[:, i, 1],
target_points[:, i, 0],
],
axis=-1,
)
queries.append(query[mask])
tracks.append(target_points[mask])
occs.append(target_occluded[mask])
trackgroups.append(trackgroup[mask])
total += np.array(np.sum(target_occluded[:, i] == 0))
return {
"video": frames[np.newaxis, ...],
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
}
class TapVidDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
dataset_type="davis",
resize_to_256=True,
queried_first=True,
):
self.dataset_type = dataset_type
self.resize_to_256 = resize_to_256
self.queried_first = queried_first
if self.dataset_type == "kinetics":
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
points_dataset = []
for pickle_path in all_paths:
with open(pickle_path, "rb") as f:
data = pickle.load(f)
points_dataset = points_dataset + data
self.points_dataset = points_dataset
else:
with open(data_root, "rb") as f:
self.points_dataset = pickle.load(f)
if self.dataset_type == "davis":
self.video_names = list(self.points_dataset.keys())
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
def __getitem__(self, index):
if self.dataset_type == "davis":
video_name = self.video_names[index]
else:
video_name = index
video = self.points_dataset[video_name]
frames = video["video"]
if isinstance(frames[0], bytes):
# TAP-Vid is stored as JPEG bytes rather than `np.ndarray`s.
def decode(frame):
byteio = io.BytesIO(frame)
img = Image.open(byteio)
return np.array(img)
frames = np.array([decode(frame) for frame in frames])
target_points = self.points_dataset[video_name]["points"]
if self.resize_to_256:
frames = resize_video(frames, [256, 256])
target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
else:
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
target_occ = self.points_dataset[video_name]["occluded"]
if self.queried_first:
converted = sample_queries_first(target_occ, target_points, frames)
else:
converted = sample_queries_strided(target_occ, target_points, frames)
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
1, 0
) # T, N
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
return CoTrackerData(
rgbs,
trajs,
visibles,
seq_name=str(video_name),
query_points=query_points,
)
def __len__(self):
return len(self.points_dataset)
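A minimal usage sketch of the dataset class above (not from the repository); the DAVIS pickle path is a placeholder, and collate_fn is the helper from cotracker.datasets.utils shown later in this commit.

import torch
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.utils import collate_fn

test_dataset = TapVidDataset(
    data_root="./tapvid_davis/tapvid_davis.pkl",   # hypothetical local path
    dataset_type="davis",
    queried_first=True,
)
loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn)
sample = next(iter(loader))
print(sample.video.shape)          # [1, T, 3, 256, 256] with resize_to_256=True
print(sample.query_points.shape)   # [1, N, 3], each query in [t, y, x] order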

View File

@@ -1,106 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import dataclasses
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any, Optional
@dataclass(eq=False)
class CoTrackerData:
"""
Dataclass for storing video tracks data.
"""
video: torch.Tensor # B, S, C, H, W
trajectory: torch.Tensor # B, S, N, 2
visibility: torch.Tensor # B, S, N
# optional data
valid: Optional[torch.Tensor] = None # B, S, N
segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
seq_name: Optional[str] = None
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
def collate_fn(batch):
"""
Collate function for video tracks data.
"""
video = torch.stack([b.video for b in batch], dim=0)
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
visibility = torch.stack([b.visibility for b in batch], dim=0)
query_points = segmentation = None
if batch[0].query_points is not None:
query_points = torch.stack([b.query_points for b in batch], dim=0)
if batch[0].segmentation is not None:
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
seq_name = [b.seq_name for b in batch]
return CoTrackerData(
video=video,
trajectory=trajectory,
visibility=visibility,
segmentation=segmentation,
seq_name=seq_name,
query_points=query_points,
)
def collate_fn_train(batch):
"""
Collate function for video tracks data during training.
"""
gotit = [gotit for _, gotit in batch]
video = torch.stack([b.video for b, _ in batch], dim=0)
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
valid = torch.stack([b.valid for b, _ in batch], dim=0)
seq_name = [b.seq_name for b, _ in batch]
return (
CoTrackerData(
video=video,
trajectory=trajectory,
visibility=visibility,
valid=valid,
seq_name=seq_name,
),
gotit,
)
def try_to_cuda(t: Any) -> Any:
"""
Try to move the input variable `t` to a cuda device.
Args:
t: Input.
Returns:
t_cuda: `t` moved to a cuda device, if supported.
"""
try:
t = t.float().cuda()
except AttributeError:
pass
return t
def dataclass_to_cuda_(obj):
"""
Move all contents of a dataclass to cuda inplace if supported.
Args:
batch: Input dataclass.
Returns:
batch_cuda: `batch` moved to a cuda device, if supported.
"""
for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
return obj
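A short sketch of how these helpers fit together, using dummy tensors (not repo code); the per-sample tensors omit the batch dimension that the field comments above describe.

import torch

T, N, H, W = 8, 4, 64, 64

def dummy_sample(name):
    return CoTrackerData(
        video=torch.zeros(T, 3, H, W),
        trajectory=torch.zeros(T, N, 2),
        visibility=torch.ones(T, N, dtype=torch.bool),
        seq_name=name,
    )

batch = collate_fn([dummy_sample("a"), dummy_sample("b")])
print(batch.video.shape)   # torch.Size([2, 8, 3, 64, 64])
print(batch.seq_name)      # ['a', 'b']
if torch.cuda.is_available():
    dataclass_to_cuda_(batch)   # converts every tensor field with .float().cuda()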

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -1,6 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: dynamic_replica

View File

@@ -1,6 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_first

View File

@@ -1,6 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_strided

View File

@@ -1,6 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_kinetics_first

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -1,138 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from typing import Iterable, Mapping, Tuple, Union
def compute_tapvid_metrics(
query_points: np.ndarray,
gt_occluded: np.ndarray,
gt_tracks: np.ndarray,
pred_occluded: np.ndarray,
pred_tracks: np.ndarray,
query_mode: str,
) -> Mapping[str, np.ndarray]:
"""Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
See the TAP-Vid paper for details on the metric computation. All inputs are
given in raster coordinates. The first three arguments should be the direct
outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
The paper metrics assume these are scaled relative to 256x256 images.
pred_occluded and pred_tracks are your algorithm's predictions.
This function takes a batch of inputs, and computes metrics separately for
each video. The metrics for the full benchmark are a simple mean of the
metrics across the full set of videos. These numbers are between 0 and 1,
but the paper multiplies them by 100 to ease reading.
Args:
query_points: The query points, in the format [t, y, x]. Its size is
[b, n, 3], where b is the batch size and n is the number of queries.
gt_occluded: A boolean array of shape [b, n, t], where t is the number
of frames. True indicates that the point is occluded.
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
in the format [x, y]
pred_occluded: A boolean array of predicted occlusions, in the same
format as gt_occluded.
pred_tracks: An array of track predictions from your algorithm, in the
same format as gt_tracks.
query_mode: Either 'first' or 'strided', depending on how queries are
sampled. If 'first', we assume the prior knowledge that all points
before the query point are occluded, and these are removed from the
evaluation.
Returns:
A dict with the following keys:
occlusion_accuracy: Accuracy at predicting occlusion.
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
predicted to be within the given pixel threshold, ignoring occlusion
prediction.
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
threshold
average_pts_within_thresh: average across pts_within_{x}
average_jaccard: average across jaccard_{x}
"""
metrics = {}
# Fixed bug is described in:
# https://github.com/facebookresearch/co-tracker/issues/20
eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
if query_mode == "first":
# evaluate frames after the query frame
query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
elif query_mode == "strided":
# evaluate all frames except the query frame
query_frame_to_eval_frames = 1 - eye
else:
raise ValueError("Unknown query mode " + query_mode)
query_frame = query_points[..., 0]
query_frame = np.round(query_frame).astype(np.int32)
evaluation_points = query_frame_to_eval_frames[query_frame] > 0
# Occlusion accuracy is simply how often the predicted occlusion equals the
# ground truth.
occ_acc = np.sum(
np.equal(pred_occluded, gt_occluded) & evaluation_points,
axis=(1, 2),
) / np.sum(evaluation_points)
metrics["occlusion_accuracy"] = occ_acc
# Next, convert the predictions and ground truth positions into pixel
# coordinates.
visible = np.logical_not(gt_occluded)
pred_visible = np.logical_not(pred_occluded)
all_frac_within = []
all_jaccard = []
for thresh in [1, 2, 4, 8, 16]:
# True positives are points that are within the threshold and where both
# the prediction and the ground truth are listed as visible.
within_dist = np.sum(
np.square(pred_tracks - gt_tracks),
axis=-1,
) < np.square(thresh)
is_correct = np.logical_and(within_dist, visible)
# Compute the frac_within_threshold, which is the fraction of points
# within the threshold among points that are visible in the ground truth,
# ignoring whether they're predicted to be visible.
count_correct = np.sum(
is_correct & evaluation_points,
axis=(1, 2),
)
count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
frac_correct = count_correct / count_visible_points
metrics["pts_within_" + str(thresh)] = frac_correct
all_frac_within.append(frac_correct)
true_positives = np.sum(
is_correct & pred_visible & evaluation_points, axis=(1, 2)
)
# The denominator of the jaccard metric is the true positives plus
# false positives plus false negatives. However, note that true positives
# plus false negatives is simply the number of points in the ground truth
# which is easier to compute than trying to compute all three quantities.
# Thus we just add the number of points in the ground truth to the number
# of false positives.
#
# False positives are simply points that are predicted to be visible,
# but the ground truth is not visible or too far from the prediction.
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
false_positives = (~visible) & pred_visible
false_positives = false_positives | ((~within_dist) & pred_visible)
false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
jaccard = true_positives / (gt_positives + false_positives)
metrics["jaccard_" + str(thresh)] = jaccard
all_jaccard.append(jaccard)
metrics["average_jaccard"] = np.mean(
np.stack(all_jaccard, axis=1),
axis=1,
)
metrics["average_pts_within_thresh"] = np.mean(
np.stack(all_frac_within, axis=1),
axis=1,
)
return metrics
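A worked toy call (not from the repository) that sanity-checks the function above: when the predictions equal the ground truth, every metric evaluates to 1.0 for the single video in the batch.

import numpy as np

b, n, t = 1, 2, 4
query_points = np.zeros((b, n, 3))             # every track queried at frame 0, [t, y, x]
gt_occluded = np.zeros((b, n, t), dtype=bool)  # always visible
gt_tracks = np.random.rand(b, n, t, 2) * 255.0

out = compute_tapvid_metrics(
    query_points,
    gt_occluded,
    gt_tracks,
    pred_occluded=gt_occluded.copy(),
    pred_tracks=gt_tracks.copy(),
    query_mode="first",
)
print(out["occlusion_accuracy"], out["average_jaccard"])   # [1.] [1.]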

View File

@@ -1,253 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import os
from typing import Optional
import torch
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from cotracker.datasets.utils import dataclass_to_cuda_
from cotracker.utils.visualizer import Visualizer
from cotracker.models.core.model_utils import reduce_masked_mean
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
import logging
class Evaluator:
"""
A class defining the CoTracker evaluator.
"""
def __init__(self, exp_dir) -> None:
# Visualization
self.exp_dir = exp_dir
os.makedirs(exp_dir, exist_ok=True)
self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
self.visualize_dir = os.path.join(exp_dir, "visualisations")
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
if isinstance(pred_trajectory, tuple):
pred_trajectory, pred_visibility = pred_trajectory
else:
pred_visibility = None
if "tapvid" in dataset_name:
B, T, N, D = sample.trajectory.shape
traj = sample.trajectory.clone()
thr = 0.9
if pred_visibility is None:
logging.warning("visibility is NONE")
pred_visibility = torch.zeros_like(sample.visibility)
if not pred_visibility.dtype == torch.bool:
pred_visibility = pred_visibility > thr
query_points = sample.query_points.clone().cpu().numpy()
pred_visibility = pred_visibility[:, :, :N]
pred_trajectory = pred_trajectory[:, :, :N]
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
gt_occluded = (
torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
)
pred_occluded = (
torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
)
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
out_metrics = compute_tapvid_metrics(
query_points,
gt_occluded,
gt_tracks,
pred_occluded,
pred_tracks,
query_mode="strided" if "strided" in dataset_name else "first",
)
metrics[sample.seq_name[0]] = out_metrics
for metric_name in out_metrics.keys():
if "avg" not in metrics:
metrics["avg"] = {}
metrics["avg"][metric_name] = np.mean(
[v[metric_name] for k, v in metrics.items() if k != "avg"]
)
logging.info(f"Metrics: {out_metrics}")
logging.info(f"avg: {metrics['avg']}")
print("metrics", out_metrics)
print("avg", metrics["avg"])
elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
*_, N, _ = sample.trajectory.shape
B, T, N = sample.visibility.shape
H, W = sample.video.shape[-2:]
device = sample.video.device
out_metrics = {}
d_vis_sum = d_occ_sum = d_sum_all = 0.0
thrs = [1, 2, 4, 8, 16]
sx_ = (W - 1) / 255.0
sy_ = (H - 1) / 255.0
sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
sc_pt = torch.from_numpy(sc_py).float().to(device)
__, first_visible_inds = torch.max(sample.visibility, dim=1)
frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))
for thr in thrs:
d_ = (
torch.norm(
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
dim=-1,
)
< thr
).float() # B,S-1,N
d_occ = (
reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
* 100.0
)
d_occ_sum += d_occ
out_metrics[f"accuracy_occ_{thr}"] = d_occ
d_vis = (
reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
)
d_vis_sum += d_vis
out_metrics[f"accuracy_vis_{thr}"] = d_vis
d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
d_sum_all += d_all
out_metrics[f"accuracy_{thr}"] = d_all
d_occ_avg = d_occ_sum / len(thrs)
d_vis_avg = d_vis_sum / len(thrs)
d_all_avg = d_sum_all / len(thrs)
sur_thr = 50
dists = torch.norm(
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
dim=-1,
) # B,S,N
dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N
survival = torch.cumprod(dist_ok, dim=1) # B,S,N
out_metrics["survival"] = torch.mean(survival).item() * 100.0
out_metrics["accuracy_occ"] = d_occ_avg
out_metrics["accuracy_vis"] = d_vis_avg
out_metrics["accuracy"] = d_all_avg
metrics[sample.seq_name[0]] = out_metrics
for metric_name in out_metrics.keys():
if "avg" not in metrics:
metrics["avg"] = {}
metrics["avg"][metric_name] = float(
np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
)
logging.info(f"Metrics: {out_metrics}")
logging.info(f"avg: {metrics['avg']}")
print("metrics", out_metrics)
print("avg", metrics["avg"])
@torch.no_grad()
def evaluate_sequence(
self,
model,
test_dataloader: torch.utils.data.DataLoader,
dataset_name: str,
train_mode=False,
visualize_every: int = 1,
writer: Optional[SummaryWriter] = None,
step: Optional[int] = 0,
):
metrics = {}
vis = Visualizer(
save_dir=self.exp_dir,
fps=7,
)
for ind, sample in enumerate(tqdm(test_dataloader)):
if isinstance(sample, tuple):
sample, gotit = sample
if not all(gotit):
print("batch is None")
continue
if torch.cuda.is_available():
dataclass_to_cuda_(sample)
device = torch.device("cuda")
else:
device = torch.device("cpu")
if (
not train_mode
and hasattr(model, "sequence_len")
and (sample.visibility[:, : model.sequence_len].sum() == 0)
):
print(f"skipping batch {ind}")
continue
if "tapvid" in dataset_name:
queries = sample.query_points.clone().float()
queries = torch.stack(
[
queries[:, :, 0],
queries[:, :, 2],
queries[:, :, 1],
],
dim=2,
).to(device)
else:
queries = torch.cat(
[
torch.zeros_like(sample.trajectory[:, 0, :, :1]),
sample.trajectory[:, 0],
],
dim=2,
).to(device)
pred_tracks = model(sample.video, queries)
if "strided" in dataset_name:
inv_video = sample.video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
pred_trj, pred_vsb = pred_tracks
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
inv_pred_trj = inv_pred_trj.flip(1)
inv_pred_vsb = inv_pred_vsb.flip(1)
mask = pred_trj == 0
pred_trj[mask] = inv_pred_trj[mask]
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
pred_tracks = pred_trj, pred_vsb
if dataset_name == "badja" or dataset_name == "fastcapture":
seq_name = sample.seq_name[0]
else:
seq_name = str(ind)
if ind % visualize_every == 0:
vis.visualize(
sample.video,
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
filename=dataset_name + "_" + seq_name,
writer=writer,
step=step,
)
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
return metrics
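The bidirectional handling of "strided" queries above can be illustrated in isolation. This sketch uses dummy tensors (not repo code) to show how query times are mirrored for the reversed video and how backward predictions fill the zeros left by the forward pass.

import torch

T = 8
queries = torch.tensor([[[2.0, 30.0, 40.0], [6.0, 10.0, 20.0]]])   # [B, N, (t, x, y)]
inv_queries = queries.clone()
inv_queries[:, :, 0] = T - inv_queries[:, :, 0] - 1
print(inv_queries[:, :, 0])   # tensor([[5., 1.]])

pred_trj = torch.zeros(1, T, 2, 2)      # forward pass leaves zeros before each query frame
inv_pred_trj = torch.ones(1, T, 2, 2)   # backward pass, already flipped back in time
mask = pred_trj == 0
pred_trj[mask] = inv_pred_trj[mask]     # merged trajectories no longer contain zero gaps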

View File

@@ -1,169 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from dataclasses import dataclass, field
import hydra
import numpy as np
import torch
from omegaconf import OmegaConf
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.datasets.utils import collate_fn
from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.models.build_cotracker import (
build_cotracker,
)
@dataclass(eq=False)
class DefaultConfig:
# Directory where all outputs of the experiment will be saved.
exp_dir: str = "./outputs"
# Name of the dataset to be used for the evaluation.
dataset_name: str = "tapvid_davis_first"
# The root directory of the dataset.
dataset_root: str = "./"
# Path to the pre-trained model checkpoint to be used for the evaluation.
# The default value is the path to a specific CoTracker model checkpoint.
checkpoint: str = "./checkpoints/cotracker2.pth"
# EvaluationPredictor parameters
# The size (N) of the support grid used in the predictor.
# The total number of points is (N*N).
grid_size: int = 5
# The size (N) of the local support grid.
local_grid_size: int = 8
# A flag indicating whether to evaluate one ground truth point at a time.
single_point: bool = True
# The number of iterative updates for each sliding window.
n_iters: int = 6
seed: int = 0
gpu_idx: int = 0
# Override hydra's working directory to current working dir,
# also disable storing the .hydra logs:
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."},
"output_subdir": None,
}
)
def run_eval(cfg: DefaultConfig):
"""
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
Args:
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
- exp_dir (str): The directory path for the experiment.
- dataset_name (str): The name of the dataset to be used.
- dataset_root (str): The root directory of the dataset.
- checkpoint (str): The path to the CoTracker model's checkpoint.
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
- n_iters (int): The number of iterative updates for each sliding window.
- seed (int): The seed for setting the random state for reproducibility.
- gpu_idx (int): The index of the GPU to be used.
"""
# Creating the experiment directory if it doesn't exist
os.makedirs(cfg.exp_dir, exist_ok=True)
# Saving the experiment configuration to a .yaml file in the experiment directory
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
with open(cfg_file, "w") as f:
OmegaConf.save(config=cfg, f=f)
evaluator = Evaluator(cfg.exp_dir)
cotracker_model = build_cotracker(cfg.checkpoint)
# Creating the EvaluationPredictor object
predictor = EvaluationPredictor(
cotracker_model,
grid_size=cfg.grid_size,
local_grid_size=cfg.local_grid_size,
single_point=cfg.single_point,
n_iters=cfg.n_iters,
)
if torch.cuda.is_available():
predictor.model = predictor.model.cuda()
# Setting the random seeds
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
# Constructing the specified dataset
curr_collate_fn = collate_fn
if "tapvid" in cfg.dataset_name:
dataset_type = cfg.dataset_name.split("_")[1]
if dataset_type == "davis":
data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
elif dataset_type == "kinetics":
data_root = os.path.join(
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
)
test_dataset = TapVidDataset(
dataset_type=dataset_type,
data_root=data_root,
queried_first="strided" not in cfg.dataset_name,
)
elif cfg.dataset_name == "dynamic_replica":
test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
# Creating the DataLoader object
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=14,
collate_fn=curr_collate_fn,
)
# Timing and conducting the evaluation
import time
start = time.time()
evaluate_result = evaluator.evaluate_sequence(
predictor,
test_dataloader,
dataset_name=cfg.dataset_name,
)
end = time.time()
print(end - start)
# Saving the evaluation results to a .json file
evaluate_result = evaluate_result["avg"]
print("evaluate_result", evaluate_result)
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
evaluate_result["time"] = end - start
print(f"Dumping eval results to {result_file}.")
with open(result_file, "w") as f:
json.dump(evaluate_result, f)
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config_eval", node=DefaultConfig)
@hydra.main(config_path="./configs/", config_name="default_config_eval")
def evaluate(cfg: DefaultConfig) -> None:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
run_eval(cfg)
if __name__ == "__main__":
evaluate()
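A hedged sketch of invoking the evaluation without the hydra entry point; the dataset and checkpoint paths are placeholders that must exist locally, and hydra normally supplies the config object rather than the plain dataclass instantiated here.

cfg = DefaultConfig(
    exp_dir="./outputs/davis_first_eval",
    dataset_name="tapvid_davis_first",
    dataset_root="./data",                     # expects ./data/tapvid_davis/tapvid_davis.pkl
    checkpoint="./checkpoints/cotracker2.pth",
)
run_eval(cfg)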

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

Binary file not shown.

Binary file not shown.

View File

@@ -1,33 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from cotracker.models.core.cotracker.cotracker import CoTracker2
def build_cotracker(
checkpoint: str,
):
if checkpoint is None:
return build_cotracker()
model_name = checkpoint.split("/")[-1].split(".")[0]
if model_name == "cotracker":
return build_cotracker(checkpoint=checkpoint)
else:
raise ValueError(f"Unknown model name {model_name}")
def build_cotracker(checkpoint=None):
cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f, map_location="cpu")
if "model" in state_dict:
state_dict = state_dict["model"]
cotracker.load_state_dict(state_dict)
return cotracker
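A brief usage sketch (not part of the file). Note that the second def build_cotracker above shadows the first at import time, so the call below binds to the loader whose checkpoint argument defaults to None; the checkpoint path is the default from the evaluation config and is not guaranteed to exist.

import torch

model = build_cotracker("./checkpoints/cotracker2.pth")
model.eval()
if torch.cuda.is_available():
    model = model.cuda()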

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -1,367 +1,368 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat
from cotracker.models.core.model_utils import bilinear_sampler
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
to_2tuple = _ntuple(2)
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
padding=1,
stride=stride,
padding_mode="zeros",
)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Module):
def __init__(self, input_dim=3, output_dim=128, stride=4):
super(BasicEncoder, self).__init__()
self.stride = stride
self.norm_fn = "instance"
self.in_planes = output_dim // 2
self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
self.conv1 = nn.Conv2d(
input_dim,
self.in_planes,
kernel_size=7,
stride=2,
padding=3,
padding_mode="zeros",
)
self.relu1 = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(output_dim // 2, stride=1)
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
self.layer3 = self._make_layer(output_dim, stride=2)
self.layer4 = self._make_layer(output_dim, stride=2)
self.conv2 = nn.Conv2d(
output_dim * 3 + output_dim // 4,
output_dim * 2,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.InstanceNorm2d)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
_, _, H, W = x.shape
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)
def _bilinear_intepolate(x):
return F.interpolate(
x,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
a = _bilinear_intepolate(a)
b = _bilinear_intepolate(b)
c = _bilinear_intepolate(c)
d = _bilinear_intepolate(d)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
return x
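A shape check with a toy input (not repo code): with the default stride of 4, the encoder fuses four residual stages, resampled to a common resolution, into a single output_dim-channel feature map at 1/4 of the input resolution.

import torch

encoder = BasicEncoder(input_dim=3, output_dim=128, stride=4)
frames = torch.randn(2, 3, 64, 64)   # H and W assumed divisible by the stride
feats = encoder(frames)
print(feats.shape)                   # torch.Size([2, 128, 16, 16])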
class CorrBlock:
def __init__(
self,
fmaps,
num_levels=4,
radius=4,
multiple_track_feats=False,
padding_mode="zeros",
):
B, S, C, H, W = fmaps.shape
self.S, self.C, self.H, self.W = S, C, H, W
self.padding_mode = padding_mode
self.num_levels = num_levels
self.radius = radius
self.fmaps_pyramid = []
self.multiple_track_feats = multiple_track_feats
self.fmaps_pyramid.append(fmaps)
for i in range(self.num_levels - 1):
fmaps_ = fmaps.reshape(B * S, C, H, W)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
_, _, H, W = fmaps_.shape
fmaps = fmaps_.reshape(B, S, C, H, W)
self.fmaps_pyramid.append(fmaps)
def sample(self, coords):
r = self.radius
B, S, N, D = coords.shape
assert D == 2
H, W = self.H, self.W
out_pyramid = []
for i in range(self.num_levels):
corrs = self.corrs_pyramid[i] # B, S, N, H, W
*_, H, W = corrs.shape
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corrs = bilinear_sampler(
corrs.reshape(B * S * N, 1, H, W),
coords_lvl,
padding_mode=self.padding_mode,
)
corrs = corrs.view(B, S, N, -1)
out_pyramid.append(corrs)
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
return out
def corr(self, targets):
B, S, N, C = targets.shape
if self.multiple_track_feats:
targets_split = targets.split(C // self.num_levels, dim=-1)
B, S, N, C = targets_split[0].shape
assert C == self.C
assert S == self.S
fmap1 = targets
self.corrs_pyramid = []
for i, fmaps in enumerate(self.fmaps_pyramid):
*_, H, W = fmaps.shape
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
if self.multiple_track_feats:
fmap1 = targets_split[i]
corrs = torch.matmul(fmap1, fmap2s)
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
corrs = corrs / torch.sqrt(torch.tensor(C).float())
self.corrs_pyramid.append(corrs)
class Attention(nn.Module):
def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
super().__init__()
inner_dim = dim_head * num_heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = num_heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
self.to_out = nn.Linear(inner_dim, query_dim)
def forward(self, x, context=None, attn_bias=None):
B, N1, C = x.shape
h = self.heads
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim=-1)
N2 = context.shape[1]
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
sim = (q @ k.transpose(-2, -1)) * self.scale
if attn_bias is not None:
sim = sim + attn_bias
attn = sim.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
return self.to_out(x)
class AttnBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
attn_class: Callable[..., nn.Module] = Attention,
mlp_ratio=4.0,
**block_kwargs
):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x, mask=None):
attn_bias = mask
if mask is not None:
mask = (
(mask[:, None] * mask[:, :, None])
.unsqueeze(1)
.expand(-1, self.attn.num_heads, -1, -1)
)
max_neg_value = -torch.finfo(x.dtype).max
attn_bias = (~mask) * max_neg_value
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
x = x + self.mlp(self.norm2(x))
return x
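A small shape check for CorrBlock with toy sizes (not repo code): a pyramid of num_levels correlation volumes is built per track, then sampled on a (2r+1) x (2r+1) grid around each coordinate estimate.

import torch

B, S, C, H, W = 1, 8, 128, 32, 32   # batch, frames, channels, feature map size
N, r, L = 16, 3, 4                  # tracks, sampling radius, pyramid levels

fmaps = torch.randn(B, S, C, H, W)
corr_block = CorrBlock(fmaps, num_levels=L, radius=r)

targets = torch.randn(B, S, N, C)          # one feature vector per track and frame
coords = torch.rand(B, S, N, 2) * (W - 1)  # track positions in feature-map pixels
corr_block.corr(targets)                   # builds corrs_pyramid: [B, S, N, H/2^i, W/2^i]
out = corr_block.sample(coords)
print(out.shape)                           # torch.Size([16, 8, 196]) == [B*N, S, L*(2r+1)**2]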
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat
from cotracker.models.core.model_utils import bilinear_sampler
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
to_2tuple = _ntuple(2)
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
padding=1,
stride=stride,
padding_mode="zeros",
)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Module):
def __init__(self, input_dim=3, output_dim=128, stride=4):
super(BasicEncoder, self).__init__()
self.stride = stride
self.norm_fn = "instance"
self.in_planes = output_dim // 2
self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
self.conv1 = nn.Conv2d(
input_dim,
self.in_planes,
kernel_size=7,
stride=2,
padding=3,
padding_mode="zeros",
)
self.relu1 = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(output_dim // 2, stride=1)
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
self.layer3 = self._make_layer(output_dim, stride=2)
self.layer4 = self._make_layer(output_dim, stride=2)
self.conv2 = nn.Conv2d(
output_dim * 3 + output_dim // 4,
output_dim * 2,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.InstanceNorm2d)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
_, _, H, W = x.shape
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
# four stages of residual blocks (multi-scale features)
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)
def _bilinear_interpolate(x):
return F.interpolate(
x,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
a = _bilinear_interpolate(a)
b = _bilinear_interpolate(b)
c = _bilinear_interpolate(c)
d = _bilinear_interpolate(d)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
return x
class CorrBlock:
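"""Multi-scale correlation volume over a window of feature maps.
__init__ builds an average-pooled feature pyramid, corr() fills self.corrs_pyramid
with per-level correlation maps for the queried track features, and sample() reads
out local (2*radius+1)^2 correlation patches around the current track coordinates.
"""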
def __init__(
self,
fmaps,
num_levels=4,
radius=4,
multiple_track_feats=False,
padding_mode="zeros",
):
B, S, C, H, W = fmaps.shape
self.S, self.C, self.H, self.W = S, C, H, W
self.padding_mode = padding_mode
self.num_levels = num_levels
self.radius = radius
self.fmaps_pyramid = []
self.multiple_track_feats = multiple_track_feats
self.fmaps_pyramid.append(fmaps)
for i in range(self.num_levels - 1):
fmaps_ = fmaps.reshape(B * S, C, H, W)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
_, _, H, W = fmaps_.shape
fmaps = fmaps_.reshape(B, S, C, H, W)
self.fmaps_pyramid.append(fmaps)
def sample(self, coords):
r = self.radius
B, S, N, D = coords.shape
assert D == 2
H, W = self.H, self.W
out_pyramid = []
for i in range(self.num_levels):
corrs = self.corrs_pyramid[i] # B, S, N, H, W
*_, H, W = corrs.shape
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corrs = bilinear_sampler(
corrs.reshape(B * S * N, 1, H, W),
coords_lvl,
padding_mode=self.padding_mode,
)
corrs = corrs.view(B, S, N, -1)
out_pyramid.append(corrs)
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
return out
def corr(self, targets):
B, S, N, C = targets.shape
if self.multiple_track_feats:
targets_split = targets.split(C // self.num_levels, dim=-1)
B, S, N, C = targets_split[0].shape
assert C == self.C
assert S == self.S
fmap1 = targets
self.corrs_pyramid = []
for i, fmaps in enumerate(self.fmaps_pyramid):
*_, H, W = fmaps.shape
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
if self.multiple_track_feats:
fmap1 = targets_split[i]
corrs = torch.matmul(fmap1, fmap2s)
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
corrs = corrs / torch.sqrt(torch.tensor(C).float())
self.corrs_pyramid.append(corrs)
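# Illustrative usage sketch (not part of the original file), assuming feature maps
# `fmaps` of shape (B, S, C, H, W) and track features `track_feats` of shape (B, S, N, C):
#   corr_block = CorrBlock(fmaps, num_levels=4, radius=3)
#   corr_block.corr(track_feats)            # fills corrs_pyramid with (B, S, N, H_l, W_l) volumes
#   corr_feats = corr_block.sample(coords)  # coords: (B, S, N, 2) -> (B*N, S, num_levels * (2*3+1)**2)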
class Attention(nn.Module):
def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
super().__init__()
inner_dim = dim_head * num_heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = num_heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
self.to_out = nn.Linear(inner_dim, query_dim)
def forward(self, x, context=None, attn_bias=None):
B, N1, C = x.shape
h = self.heads
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim=-1)
N2 = context.shape[1]
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
sim = (q @ k.transpose(-2, -1)) * self.scale
if attn_bias is not None:
sim = sim + attn_bias
attn = sim.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
return self.to_out(x)
class AttnBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
attn_class: Callable[..., nn.Module] = Attention,
mlp_ratio=4.0,
**block_kwargs
):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x, mask=None):
attn_bias = mask
if mask is not None:
mask = (
(mask[:, None] * mask[:, :, None])
.unsqueeze(1)
.expand(-1, self.attn.num_heads, -1, -1)
)
max_neg_value = -torch.finfo(x.dtype).max
attn_bias = (~mask) * max_neg_value
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
x = x + self.mlp(self.norm2(x))
return x

File diff suppressed because it is too large

View File

@@ -1,61 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import reduce_masked_mean
EPS = 1e-6
def balanced_ce_loss(pred, gt, valid=None):
total_balanced_loss = 0.0
for j in range(len(gt)):
B, S, N = gt[j].shape
# pred and gt are the same shape
for (a, b) in zip(pred[j].size(), gt[j].size()):
assert a == b # some shape mismatch!
# valid must be provided here and have the same shape as pred and gt
for (a, b) in zip(pred[j].size(), valid[j].size()):
assert a == b # some shape mismatch!
pos = (gt[j] > 0.95).float()
neg = (gt[j] < 0.05).float()
label = pos * 2.0 - 1.0
a = -label * pred[j]
b = F.relu(a)
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
pos_loss = reduce_masked_mean(loss, pos * valid[j])
neg_loss = reduce_masked_mean(loss, neg * valid[j])
balanced_loss = pos_loss + neg_loss
total_balanced_loss += balanced_loss / float(N)
return total_balanced_loss
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
"""Loss function defined over sequence of flow predictions"""
total_flow_loss = 0.0
for j in range(len(flow_gt)):
B, S, N, D = flow_gt[j].shape
assert D == 2
B, S1, N = vis[j].shape
B, S2, N = valids[j].shape
assert S == S1
assert S == S2
n_predictions = len(flow_preds[j])
flow_loss = 0.0
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[j][i]
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
i_loss = torch.mean(i_loss, dim=3) # B, S, N
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
flow_loss = flow_loss / n_predictions
total_flow_loss += flow_loss / float(N)
return total_flow_loss
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import reduce_masked_mean
EPS = 1e-6
def balanced_ce_loss(pred, gt, valid=None):
total_balanced_loss = 0.0
for j in range(len(gt)):
B, S, N = gt[j].shape
# pred and gt are the same shape
for (a, b) in zip(pred[j].size(), gt[j].size()):
assert a == b # some shape mismatch!
# valid must be provided here and have the same shape as pred and gt
for (a, b) in zip(pred[j].size(), valid[j].size()):
assert a == b # some shape mismatch!
pos = (gt[j] > 0.95).float()
neg = (gt[j] < 0.05).float()
label = pos * 2.0 - 1.0
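# The next three lines compute a numerically stable binary cross-entropy on logits:
# with a = -label * pred and b = relu(a), b + log(exp(-b) + exp(a - b)) equals
# log(1 + exp(-label * pred)), i.e. softplus, so large-magnitude logits do not overflow.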
a = -label * pred[j]
b = F.relu(a)
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
pos_loss = reduce_masked_mean(loss, pos * valid[j])
neg_loss = reduce_masked_mean(loss, neg * valid[j])
balanced_loss = pos_loss + neg_loss
total_balanced_loss += balanced_loss / float(N)
return total_balanced_loss
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
"""Loss function defined over sequence of flow predictions"""
total_flow_loss = 0.0
for j in range(len(flow_gt)):
B, S, N, D = flow_gt[j].shape
assert D == 2
B, S1, N = vis[j].shape
B, S2, N = valids[j].shape
assert S == S1
assert S == S2
n_predictions = len(flow_preds[j])
flow_loss = 0.0
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[j][i]
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
i_loss = torch.mean(i_loss, dim=3) # B, S, N
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
flow_loss = flow_loss / n_predictions
total_flow_loss += flow_loss / float(N)
return total_flow_loss

View File

@@ -1,120 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple, Union
import torch
def get_2d_sincos_pos_embed(
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
"""
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
Args:
- embed_dim: The embedding dimension.
- grid_size: The grid size.
Returns:
- pos_embed: The generated 2D positional embedding.
"""
if isinstance(grid_size, tuple):
grid_size_h, grid_size_w = grid_size
else:
grid_size_h = grid_size_w = grid_size
grid_h = torch.arange(grid_size_h, dtype=torch.float)
grid_w = torch.arange(grid_size_w, dtype=torch.float)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
grid = torch.stack(grid, dim=0)
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
"""
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
Args:
- embed_dim: The embedding dimension.
- grid: The grid to generate the embedding from.
Returns:
- emb: The generated 2D positional embedding.
"""
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
"""
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
Args:
- embed_dim: The embedding dimension.
- pos: The position to generate the embedding from.
Returns:
- emb: The generated 1D positional embedding.
"""
assert embed_dim % 2 == 0
omega = torch.arange(embed_dim // 2, dtype=torch.double)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb[None].float()
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
"""
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
Args:
- xy: The coordinates to generate the embedding from.
- C: The size of the embedding.
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
Returns:
- pe: The generated 2D positional embedding.
"""
B, N, D = xy.shape
assert D == 2
x = xy[:, :, 0:1]
y = xy[:, :, 1:2]
div_term = (
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, 2*C)
if cat_coords:
pe = torch.cat([xy, pe], dim=2) # (B, N, 2*C + 2)
return pe
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple, Union
import torch
def get_2d_sincos_pos_embed(
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
"""
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
Args:
- embed_dim: The embedding dimension.
- grid_size: The grid size.
Returns:
- pos_embed: The generated 2D positional embedding.
"""
if isinstance(grid_size, tuple):
grid_size_h, grid_size_w = grid_size
else:
grid_size_h = grid_size_w = grid_size
grid_h = torch.arange(grid_size_h, dtype=torch.float)
grid_w = torch.arange(grid_size_w, dtype=torch.float)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
grid = torch.stack(grid, dim=0)
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
"""
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
Args:
- embed_dim: The embedding dimension.
- grid: The grid to generate the embedding from.
Returns:
- emb: The generated 2D positional embedding.
"""
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
"""
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
Args:
- embed_dim: The embedding dimension.
- pos: The position to generate the embedding from.
Returns:
- emb: The generated 1D positional embedding.
"""
assert embed_dim % 2 == 0
omega = torch.arange(embed_dim // 2, dtype=torch.double)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb[None].float()
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
"""
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
Args:
- xy: The coordinates to generate the embedding from.
- C: The size of the embedding.
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
Returns:
- pe: The generated 2D positional embedding.
"""
B, N, D = xy.shape
assert D == 2
x = xy[:, :, 0:1]
y = xy[:, :, 1:2]
div_term = (
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, 2*C)
if cat_coords:
pe = torch.cat([xy, pe], dim=2) # (B, N, 2*C + 2)
return pe
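# Illustrative shapes (not part of the original file):
#   get_2d_sincos_pos_embed(128, (24, 32)) -> (1, 128, 24, 32)
#   get_2d_embedding(xy, 64)               -> (B, N, 2*64 + 2) with cat_coords=True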

View File

@@ -1,256 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from typing import Optional, Tuple
EPS = 1e-6
def smart_cat(tensor1, tensor2, dim):
if tensor1 is None:
return tensor2
return torch.cat([tensor1, tensor2], dim=dim)
def get_points_on_a_grid(
size: int,
extent: Tuple[float, ...],
center: Optional[Tuple[float, ...]] = None,
device: Optional[torch.device] = torch.device("cpu"),
):
r"""Get a grid of points covering a rectangular region
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
:attr:`size` grid of points distributed to cover a rectangular area
specified by `extent`.
The `extent` is a pair of integers :math:`(H,W)` specifying the height
and width of the rectangle.
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
specifying the vertical and horizontal center coordinates. The center
defaults to the middle of the extent.
Points are distributed uniformly within the rectangle leaving a margin
:math:`m=W/64` from the border.
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
points :math:`P_{ij}=(x_i, y_i)` where
.. math::
P_{ij} = \left(
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
\right)
Points are returned in row-major order.
Args:
size (int): grid size.
extent (tuple): height and width of the grid extent.
center (tuple, optional): grid center.
device (str, optional): Defaults to `"cpu"`.
Returns:
Tensor: grid.
"""
if size == 1:
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
if center is None:
center = [extent[0] / 2, extent[1] / 2]
margin = extent[1] / 64
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
grid_y, grid_x = torch.meshgrid(
torch.linspace(*range_y, size, device=device),
torch.linspace(*range_x, size, device=device),
indexing="ij",
)
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
r"""Masked mean
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
over a mask :attr:`mask`, returning
.. math::
\text{output} =
\frac
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
{\epsilon + \sum_{i=1}^N \text{mask}_i}
where :math:`N` is the number of elements in :attr:`input` and
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
division by zero.
`reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
Optionally, the dimension can be kept in the output by setting
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
the same dimension as :attr:`input`.
The interface is similar to `torch.mean()`.
Args:
input (Tensor): input tensor.
mask (Tensor): mask.
dim (int, optional): Dimension to sum over. Defaults to None.
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
Returns:
Tensor: mean tensor.
"""
mask = mask.expand_as(input)
prod = input * mask
if dim is None:
numer = torch.sum(prod)
denom = torch.sum(mask)
else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer / (EPS + denom)
return mean
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
r"""Sample a tensor using bilinear interpolation
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
coordinates :attr:`coords` using bilinear interpolation. It is the same
as `torch.nn.functional.grid_sample()` but with a different coordinate
convention.
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
:math:`B` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of the image, and :math:`W` is the width of the
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
that in this case the order of the components is slightly different
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
left-most image pixel and :math:`W-1` to the center of the right-most
pixel.
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
the left-most pixel and :math:`W` to the right edge of the right-most
pixel.
Similar conventions apply to the :math:`y` for the range
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
:math:`[0,T-1]` and :math:`[0,T]`.
Args:
input (Tensor): batch of input images.
coords (Tensor): batch of coordinates.
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
Returns:
Tensor: sampled points.
"""
sizes = input.shape[2:]
assert len(sizes) in [2, 3]
if len(sizes) == 3:
# t x y -> x y t to match dimensions T H W in grid_sample
coords = coords[..., [1, 2, 0]]
if align_corners:
coords = coords * torch.tensor(
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
)
else:
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
coords -= 1
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
def sample_features4d(input, coords):
r"""Sample spatial features
`sample_features4d(input, coords)` samples the spatial features
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
The field is sampled at coordinates :attr:`coords` using bilinear
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
same convention as :func:`bilinear_sampler` with `align_corners=True`.
The output tensor has one feature per point, and has shape :math:`(B,
R, C)`.
Args:
input (Tensor): spatial features.
coords (Tensor): points.
Returns:
Tensor: sampled features.
"""
B, _, _, _ = input.shape
# B R 2 -> B R 1 2
coords = coords.unsqueeze(2)
# B C R 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 1, 3).view(
B, -1, feats.shape[1] * feats.shape[3]
) # B C R 1 -> B R C
def sample_features5d(input, coords):
r"""Sample spatio-temporal features
`sample_features5d(input, coords)` works in the same way as
:func:`sample_features4d` but for spatio-temporal features and points:
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
Args:
input (Tensor): spatio-temporal features.
coords (Tensor): spatio-temporal points.
Returns:
Tensor: sampled features.
"""
B, T, _, _, _ = input.shape
# B T C H W -> B C T H W
input = input.permute(0, 2, 1, 3, 4)
# B R1 R2 3 -> B R1 R2 1 3
coords = coords.unsqueeze(3)
# B C R1 R2 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 3, 1, 4).view(
B, feats.shape[2], feats.shape[3], feats.shape[1]
) # B C R1 R2 1 -> B R1 R2 C
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from typing import Optional, Tuple
EPS = 1e-6
def smart_cat(tensor1, tensor2, dim):
if tensor1 is None:
return tensor2
return torch.cat([tensor1, tensor2], dim=dim)
def get_points_on_a_grid(
size: int,
extent: Tuple[float, ...],
center: Optional[Tuple[float, ...]] = None,
device: Optional[torch.device] = torch.device("cpu"),
):
r"""Get a grid of points covering a rectangular region
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
:attr:`size` grid of points distributed to cover a rectangular area
specified by `extent`.
The `extent` is a pair of integers :math:`(H,W)` specifying the height
and width of the rectangle.
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
specifying the vertical and horizontal center coordinates. The center
defaults to the middle of the extent.
Points are distributed uniformly within the rectangle leaving a margin
:math:`m=W/64` from the border.
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
points :math:`P_{ij}=(x_i, y_i)` where
.. math::
P_{ij} = \left(
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
\right)
Points are returned in row-major order.
Args:
size (int): grid size.
extent (tuple): height and width of the grid extent.
center (tuple, optional): grid center.
device (str, optional): Defaults to `"cpu"`.
Returns:
Tensor: grid.
"""
if size == 1:
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
if center is None:
center = [extent[0] / 2, extent[1] / 2]
margin = extent[1] / 64
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
grid_y, grid_x = torch.meshgrid(
torch.linspace(*range_y, size, device=device),
torch.linspace(*range_x, size, device=device),
indexing="ij",
)
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
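# Example (not part of the original file): get_points_on_a_grid(5, (384, 512)) returns a
# (1, 25, 2) tensor of (x, y) points covering a 384x512 frame with a 512/64 = 8 pixel margin.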
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
r"""Masked mean
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
over a mask :attr:`mask`, returning
.. math::
\text{output} =
\frac
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
{\epsilon + \sum_{i=1}^N \text{mask}_i}
where :math:`N` is the number of elements in :attr:`input` and
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
division by zero.
`reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
Optionally, the dimension can be kept in the output by setting
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
the same dimension as :attr:`input`.
The interface is similar to `torch.mean()`.
Args:
input (Tensor): input tensor.
mask (Tensor): mask.
dim (int, optional): Dimension to sum over. Defaults to None.
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
Returns:
Tensor: mean tensor.
"""
mask = mask.expand_as(input)
prod = input * mask
if dim is None:
numer = torch.sum(prod)
denom = torch.sum(mask)
else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer / (EPS + denom)
return mean
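# Example (not part of the original file): with input = [1., 2., 3.] and mask = [1., 1., 0.],
# reduce_masked_mean returns (1 + 2) / (2 + EPS), i.e. approximately 1.5.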
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
r"""Sample a tensor using bilinear interpolation
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
coordinates :attr:`coords` using bilinear interpolation. It is the same
as `torch.nn.functional.grid_sample()` but with a different coordinate
convention.
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
:math:`B` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of the image, and :math:`W` is the width of the
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
that in this case the order of the components is slightly different
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
left-most image pixel and :math:`W-1` to the center of the right-most
pixel.
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
the left-most pixel and :math:`W` to the right edge of the right-most
pixel.
Similar conventions apply to the :math:`y` for the range
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
:math:`[0,T-1]` and :math:`[0,T]`.
Args:
input (Tensor): batch of input images.
coords (Tensor): batch of coordinates.
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
Returns:
Tensor: sampled points.
"""
sizes = input.shape[2:]
assert len(sizes) in [2, 3]
if len(sizes) == 3:
# t x y -> x y t to match dimensions T H W in grid_sample
coords = coords[..., [1, 2, 0]]
if align_corners:
coords = coords * torch.tensor(
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
)
else:
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
coords -= 1
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
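# Example (not part of the original file): with align_corners=True a 2x2 map is
# reproduced exactly at its pixel centres:
#   feats = torch.arange(4.0).view(1, 1, 2, 2)        # [[0, 1], [2, 3]]
#   pts = torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]])  # (B, H_o, W_o, 2) in (x, y) order
#   bilinear_sampler(feats, pts)                       # -> tensor([[[[0., 3.]]]])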
def sample_features4d(input, coords):
r"""Sample spatial features
`sample_features4d(input, coords)` samples the spatial features
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
The field is sampled at coordinates :attr:`coords` using bilinear
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
same convention as :func:`bilinear_sampler` with `align_corners=True`.
The output tensor has one feature per point, and has shape :math:`(B,
R, C)`.
Args:
input (Tensor): spatial features.
coords (Tensor): points.
Returns:
Tensor: sampled features.
"""
B, _, _, _ = input.shape
# B R 2 -> B R 1 2
coords = coords.unsqueeze(2)
# B C R 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 1, 3).view(
B, -1, feats.shape[1] * feats.shape[3]
) # B C R 1 -> B R C
def sample_features5d(input, coords):
r"""Sample spatio-temporal features
`sample_features5d(input, coords)` works in the same way as
:func:`sample_features4d` but for spatio-temporal features and points:
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
Args:
input (Tensor): spatio-temporal features.
coords (Tensor): spatio-temporal points.
Returns:
Tensor: sampled features.
"""
B, T, _, _, _ = input.shape
# B T C H W -> B C T H W
input = input.permute(0, 2, 1, 3, 4)
# B R1 R2 3 -> B R1 R2 1 3
coords = coords.unsqueeze(3)
# B C R1 R2 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 3, 1, 4).view(
B, feats.shape[2], feats.shape[3], feats.shape[1]
) # B C R1 R2 1 -> B R1 R2 C
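# Illustrative shapes (not part of the original file):
#   sample_features4d(fmap (B, C, H, W), coords (B, R, 2))          -> (B, R, C)
#   sample_features5d(fmaps (B, T, C, H, W), coords (B, R1, R2, 3)) -> (B, R1, R2, C)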

View File

@@ -1,104 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from typing import Tuple
from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.models.core.model_utils import get_points_on_a_grid
class EvaluationPredictor(torch.nn.Module):
def __init__(
self,
cotracker_model: CoTracker2,
interp_shape: Tuple[int, int] = (384, 512),
grid_size: int = 5,
local_grid_size: int = 8,
single_point: bool = True,
n_iters: int = 6,
) -> None:
super(EvaluationPredictor, self).__init__()
self.grid_size = grid_size
self.local_grid_size = local_grid_size
self.single_point = single_point
self.interp_shape = interp_shape
self.n_iters = n_iters
self.model = cotracker_model
self.model.eval()
def forward(self, video, queries):
queries = queries.clone()
B, T, C, H, W = video.shape
B, N, D = queries.shape
assert D == 3
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
device = video.device
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
if self.single_point:
traj_e = torch.zeros((B, T, N, 2), device=device)
vis_e = torch.zeros((B, T, N), device=device)
for pind in range((N)):
query = queries[:, pind : pind + 1]
t = query[0, 0, 0].long()
traj_e_pind, vis_e_pind = self._process_one_point(video, query)
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
else:
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
queries = torch.cat([queries, xy], dim=1) #
traj_e, vis_e, __ = self.model(
video=video,
queries=queries,
iters=self.n_iters,
)
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
return traj_e, vis_e
def _process_one_point(self, video, query):
t = query[0, 0, 0].long()
device = query.device
if self.local_grid_size > 0:
xy_target = get_points_on_a_grid(
self.local_grid_size,
(50, 50),
[query[0, 0, 2].item(), query[0, 0, 1].item()],
)
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
device
) #
query = torch.cat([query, xy_target], dim=1) #
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
query = torch.cat([query, xy], dim=1) #
# crop the video to start from the queried frame
query[0, 0, 0] = 0
traj_e_pind, vis_e_pind, __ = self.model(
video=video[:, t:], queries=query, iters=self.n_iters
)
return traj_e_pind, vis_e_pind
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from typing import Tuple
from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.models.core.model_utils import get_points_on_a_grid
class EvaluationPredictor(torch.nn.Module):
def __init__(
self,
cotracker_model: CoTracker2,
interp_shape: Tuple[int, int] = (384, 512),
grid_size: int = 5,
local_grid_size: int = 8,
single_point: bool = True,
n_iters: int = 6,
) -> None:
super(EvaluationPredictor, self).__init__()
self.grid_size = grid_size
self.local_grid_size = local_grid_size
self.single_point = single_point
self.interp_shape = interp_shape
self.n_iters = n_iters
self.model = cotracker_model
self.model.eval()
def forward(self, video, queries):
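# video: (B, T, C, H, W); queries: (B, N, 3) in (t, x, y) format, with x and y given in the
# original resolution (they are rescaled to interp_shape just below)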
queries = queries.clone()
B, T, C, H, W = video.shape
B, N, D = queries.shape
assert D == 3
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
device = video.device
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
if self.single_point:
traj_e = torch.zeros((B, T, N, 2), device=device)
vis_e = torch.zeros((B, T, N), device=device)
for pind in range((N)):
query = queries[:, pind : pind + 1]
t = query[0, 0, 0].long()
traj_e_pind, vis_e_pind = self._process_one_point(video, query)
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
else:
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
queries = torch.cat([queries, xy], dim=1) #
traj_e, vis_e, __ = self.model(
video=video,
queries=queries,
iters=self.n_iters,
)
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
return traj_e, vis_e
def _process_one_point(self, video, query):
t = query[0, 0, 0].long()
device = query.device
if self.local_grid_size > 0:
xy_target = get_points_on_a_grid(
self.local_grid_size,
(50, 50),
[query[0, 0, 2].item(), query[0, 0, 1].item()],
)
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
device
) #
query = torch.cat([query, xy_target], dim=1) #
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
query = torch.cat([query, xy], dim=1) #
# crop the video to start from the queried frame
query[0, 0, 0] = 0
traj_e_pind, vis_e_pind, __ = self.model(
video=video[:, t:], queries=query, iters=self.n_iters
)
return traj_e_pind, vis_e_pind

View File

@@ -1,275 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
from cotracker.models.build_cotracker import build_cotracker
class CoTrackerPredictor(torch.nn.Module):
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
super().__init__()
self.support_grid_size = 6
model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution
print(self.interp_shape)
self.model = model
self.model.eval()
@torch.no_grad()
def forward(
self,
video,  # (B, T, 3, H, W): batch size, time steps, RGB channels, height, width
# input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None,
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
grid_size: int = 0,
grid_query_frame: int = 0, # only for dense and regular grid tracks
backward_tracking: bool = False,
):
if queries is None and grid_size == 0:
tracks, visibilities = self._compute_dense_tracks(
video,
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
else:
tracks, visibilities = self._compute_sparse_tracks(
video,
queries,
segm_mask,
grid_size,
add_support_grid=(grid_size == 0 or segm_mask is not None),
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
return tracks, visibilities
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
*_, H, W = video.shape
grid_step = W // grid_size
grid_width = W // grid_step
grid_height = H // grid_step # set the whole video to grid_size number of grids
tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
# (batch_size, grid_number, t,x,y)
grid_pts[0, :, 0] = grid_query_frame
# iterate every grid
for offset in range(grid_step * grid_step):
print(f"step {offset} / {grid_step * grid_step}")
ox = offset % grid_step
oy = offset // grid_step
# initialize
# for example
# grid width = 4, grid height = 4, grid step = 10, ox = 1
# torch.arange(grid_width) = [0,1,2,3]
# torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
# torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
# get the location in the image
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
grid_pts[0, :, 2] = (
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
)
tracks_step, visibilities_step = self._compute_sparse_tracks(
video=video,
queries=grid_pts,
backward_tracking=backward_tracking,
)
tracks = smart_cat(tracks, tracks_step, dim=2)
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
return tracks, visibilities
def _compute_sparse_tracks(
self,
video,
queries,
segm_mask=None,
grid_size=0,
add_support_grid=False,
grid_query_frame=0,
backward_tracking=False,
):
B, T, C, H, W = video.shape
video = video.reshape(B * T, C, H, W)
# resize every frame to the model's working resolution interp_shape via bilinear interpolation
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None:
B, N, D = queries.shape # batch_size, number of points, (t,x,y)
assert D == 3
# after resizing the video, rescale the query (x, y) coordinates by (interp_shape - 1) / (original size - 1)
queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor(
[
(self.interp_shape[1] - 1) / (W - 1),
(self.interp_shape[0] - 1) / (H - 1),
]
)
# otherwise, build the queries from a regular grid of points on the query frame
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
if segm_mask is not None:
segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
point_mask = segm_mask[0, 0][
(grid_pts[0, :, 1]).round().long().cpu(),
(grid_pts[0, :, 0]).round().long().cpu(),
].bool()
grid_pts = grid_pts[:, point_mask]
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
).repeat(B, 1, 1)
# optionally append an auxiliary grid of support points to the queries
if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video.device
)
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
grid_pts = grid_pts.repeat(B, 1, 1)
queries = torch.cat([queries, grid_pts], dim=1)
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
if backward_tracking:
tracks, visibilities = self._compute_backward_tracks(
video, queries, tracks, visibilities
)
if add_support_grid:
queries[:, -self.support_grid_size**2 :, 0] = T - 1
if add_support_grid:
tracks = tracks[:, :, : -self.support_grid_size**2]
visibilities = visibilities[:, :, : -self.support_grid_size**2]
thr = 0.9
visibilities = visibilities > thr
# correct query-point predictions
# see https://github.com/facebookresearch/co-tracker/issues/28
# TODO: batchify
for i in range(len(queries)):
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
arange = torch.arange(0, len(queries_t))
# overwrite the predictions with the query points
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
# correct visibilities, the query points should be visible
visibilities[i, queries_t, arange] = True
tracks *= tracks.new_tensor(
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
)
return tracks, visibilities
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
inv_video = video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
inv_tracks = inv_tracks.flip(1)
inv_visibilities = inv_visibilities.flip(1)
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
return tracks, visibilities
class CoTrackerOnlinePredictor(torch.nn.Module):
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
super().__init__()
self.support_grid_size = 6
model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution
self.step = model.window_len // 2
self.model = model
self.model.eval()
@torch.no_grad()
def forward(
self,
video_chunk,
is_first_step: bool = False,
queries: torch.Tensor = None,
grid_size: int = 10,
grid_query_frame: int = 0,
add_support_grid=False,
):
B, T, C, H, W = video_chunk.shape
# Initialize online video processing and save queried points
# This needs to be done before processing *each new video*
if is_first_step:
self.model.init_video_online_processing()
if queries is not None:
B, N, D = queries.shape
assert D == 3
queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor(
[
(self.interp_shape[1] - 1) / (W - 1),
(self.interp_shape[0] - 1) / (H - 1),
]
)
elif grid_size > 0:
grid_pts = get_points_on_a_grid(
grid_size, self.interp_shape, device=video_chunk.device
)
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
)
if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video_chunk.device
)
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
queries = torch.cat([queries, grid_pts], dim=1)
self.queries = queries
return (None, None)
video_chunk = video_chunk.reshape(B * T, C, H, W)
video_chunk = F.interpolate(
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
)
video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
tracks, visibilities, __ = self.model(
video=video_chunk,
queries=self.queries,
iters=6,
is_online=True,
)
thr = 0.9
return (
tracks
* tracks.new_tensor(
[
(W - 1) / (self.interp_shape[1] - 1),
(H - 1) / (self.interp_shape[0] - 1),
]
),
visibilities > thr,
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
from cotracker.models.build_cotracker import build_cotracker
class CoTrackerPredictor(torch.nn.Module):
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
super().__init__()
self.support_grid_size = 6
model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution
print(self.interp_shape)
self.model = model
self.model.eval()
@torch.no_grad()
def forward(
self,
video,  # (B, T, 3, H, W): batch size, time steps, RGB channels, height, width
# input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None,
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
grid_size: int = 0,
grid_query_frame: int = 0, # only for dense and regular grid tracks
backward_tracking: bool = False,
):
if queries is None and grid_size == 0:
tracks, visibilities = self._compute_dense_tracks(
video,
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
else:
tracks, visibilities = self._compute_sparse_tracks(
video,
queries,
segm_mask,
grid_size,
add_support_grid=(grid_size == 0 or segm_mask is not None),
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
return tracks, visibilities
# TODO: benchmark GPU dense-inference time; compare against RAFT on GPU;
# explore visual-effects applications; try integrating RAFT.
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
*_, H, W = video.shape
grid_step = W // grid_size
grid_width = W // grid_step
grid_height = H // grid_step # set the whole video to grid_size number of grids
tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
# (batch_size, grid_number, t,x,y)
grid_pts[0, :, 0] = grid_query_frame
# iterate every grid
for offset in range(grid_step * grid_step):
print(f"step {offset} / {grid_step * grid_step}")
ox = offset % grid_step
oy = offset // grid_step
# initialize
# for example
# grid width = 4, grid height = 4, grid step = 10, ox = 1
# torch.arange(grid_width) = [0,1,2,3]
# torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
# torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
# get the location in the image
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
grid_pts[0, :, 2] = (
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
)
tracks_step, visibilities_step = self._compute_sparse_tracks(
video=video,
queries=grid_pts,
backward_tracking=backward_tracking,
)
tracks = smart_cat(tracks, tracks_step, dim=2)
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
return tracks, visibilities
def _compute_sparse_tracks(
self,
video,
queries,
segm_mask=None,
grid_size=0,
add_support_grid=False,
grid_query_frame=0,
backward_tracking=False,
):
B, T, C, H, W = video.shape
video = video.reshape(B * T, C, H, W)
# resize every frame to the model's working resolution interp_shape via bilinear interpolation
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None:
B, N, D = queries.shape # batch_size, number of points, (t,x,y)
assert D == 3
# after resizing the video, rescale the query (x, y) coordinates by (interp_shape - 1) / (original size - 1)
queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor(
[
(self.interp_shape[1] - 1) / (W - 1),
(self.interp_shape[0] - 1) / (H - 1),
]
)
# otherwise, build the queries from a regular grid of points on the query frame
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
if segm_mask is not None:
segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
point_mask = segm_mask[0, 0][
(grid_pts[0, :, 1]).round().long().cpu(),
(grid_pts[0, :, 0]).round().long().cpu(),
].bool()
grid_pts = grid_pts[:, point_mask]
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
).repeat(B, 1, 1)
# optionally append an auxiliary grid of support points to the queries
if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video.device
)
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
grid_pts = grid_pts.repeat(B, 1, 1)
queries = torch.cat([queries, grid_pts], dim=1)
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
if backward_tracking:
tracks, visibilities = self._compute_backward_tracks(
video, queries, tracks, visibilities
)
if add_support_grid:
queries[:, -self.support_grid_size**2 :, 0] = T - 1
if add_support_grid:
tracks = tracks[:, :, : -self.support_grid_size**2]
visibilities = visibilities[:, :, : -self.support_grid_size**2]
thr = 0.9
visibilities = visibilities > thr
# correct query-point predictions
# see https://github.com/facebookresearch/co-tracker/issues/28
# TODO: batchify
for i in range(len(queries)):
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
arange = torch.arange(0, len(queries_t))
# overwrite the predictions with the query points
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
# correct visibilities, the query points should be visible
visibilities[i, queries_t, arange] = True
tracks *= tracks.new_tensor(
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
)
return tracks, visibilities
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
inv_video = video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
inv_tracks = inv_tracks.flip(1)
inv_visibilities = inv_visibilities.flip(1)
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
return tracks, visibilities
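# Illustrative usage sketch (not part of the original file), assuming a video tensor
# of shape (B, T, 3, H, W):
#   predictor = CoTrackerPredictor(checkpoint="./checkpoints/cotracker2.pth")
#   tracks, visibilities = predictor(video, grid_size=10)
#   # tracks: (B, T, N, 2) pixel coordinates, visibilities: (B, T, N) booleans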
class CoTrackerOnlinePredictor(torch.nn.Module):
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
super().__init__()
self.support_grid_size = 6
model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution
self.step = model.window_len // 2
self.model = model
self.model.eval()
@torch.no_grad()
def forward(
self,
video_chunk,
is_first_step: bool = False,
queries: torch.Tensor = None,
grid_size: int = 10,
grid_query_frame: int = 0,
add_support_grid=False,
):
B, T, C, H, W = video_chunk.shape
# Initialize online video processing and save queried points
# This needs to be done before processing *each new video*
if is_first_step:
self.model.init_video_online_processing()
if queries is not None:
B, N, D = queries.shape
assert D == 3
queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor(
[
(self.interp_shape[1] - 1) / (W - 1),
(self.interp_shape[0] - 1) / (H - 1),
]
)
elif grid_size > 0:
grid_pts = get_points_on_a_grid(
grid_size, self.interp_shape, device=video_chunk.device
)
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
)
if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video_chunk.device
)
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
queries = torch.cat([queries, grid_pts], dim=1)
self.queries = queries
return (None, None)
video_chunk = video_chunk.reshape(B * T, C, H, W)
video_chunk = F.interpolate(
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
)
video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
tracks, visibilities, __ = self.model(
video=video_chunk,
queries=self.queries,
iters=6,
is_online=True,
)
thr = 0.9
return (
tracks
* tracks.new_tensor(
[
(W - 1) / (self.interp_shape[1] - 1),
(H - 1) / (self.interp_shape[0] - 1),
]
),
visibilities > thr,
)
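# Illustrative online-usage sketch (not part of the original file): call once with
# is_first_step=True to register the queries, then feed successive chunks of the same video:
#   online = CoTrackerOnlinePredictor(checkpoint="./checkpoints/cotracker2.pth")
#   online(first_chunk, is_first_step=True, grid_size=10)   # returns (None, None)
#   tracks, visibilities = online(next_chunk)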

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,343 +1,343 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import imageio
import torch
from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
def read_video_from_path(path):
try:
reader = imageio.get_reader(path)
except Exception as e:
print("Error opening video file: ", e)
return None
frames = []
for i, im in enumerate(reader):
frames.append(np.array(im))
return np.stack(frames)
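# Returns the stacked frames as an array of shape (num_frames, H, W, 3) (uint8 for typical
# videos), or None if the file could not be opened.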
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
# Create a draw object
draw = ImageDraw.Draw(rgb)
# Calculate the bounding box of the circle
left_up_point = (coord[0] - radius, coord[1] - radius)
right_down_point = (coord[0] + radius, coord[1] + radius)
# Draw the circle
draw.ellipse(
[left_up_point, right_down_point],
fill=tuple(color) if visible else None,
outline=tuple(color),
)
return rgb
def draw_line(rgb, coord_y, coord_x, color, linewidth):
draw = ImageDraw.Draw(rgb)
draw.line(
(coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
fill=tuple(color),
width=linewidth,
)
return rgb
def add_weighted(rgb, alpha, original, beta, gamma):
return (rgb * alpha + original * beta + gamma).astype("uint8")
class Visualizer:
def __init__(
self,
save_dir: str = "./results",
grayscale: bool = False,
pad_value: int = 0,
fps: int = 10,
mode: str = "rainbow", # 'cool', 'optical_flow'
linewidth: int = 2,
show_first_frame: int = 10,
tracks_leave_trace: int = 0, # -1 for infinite
):
self.mode = mode
self.save_dir = save_dir
if mode == "rainbow":
self.color_map = cm.get_cmap("gist_rainbow")
elif mode == "cool":
self.color_map = cm.get_cmap(mode)
self.show_first_frame = show_first_frame
self.grayscale = grayscale
self.tracks_leave_trace = tracks_leave_trace
self.pad_value = pad_value
self.linewidth = linewidth
self.fps = fps
def visualize(
self,
video: torch.Tensor, # (B,T,C,H,W)
tracks: torch.Tensor, # (B,T,N,2)
visibility: torch.Tensor = None, # (B, T, N, 1) bool
gt_tracks: torch.Tensor = None, # (B,T,N,2)
segm_mask: torch.Tensor = None, # (B,1,H,W)
filename: str = "video",
writer=None, # tensorboard Summary Writer, used for visualization during training
step: int = 0,
query_frame: int = 0,
save_video: bool = True,
compensate_for_camera_motion: bool = False,
):
if compensate_for_camera_motion:
assert segm_mask is not None
if segm_mask is not None:
coords = tracks[0, query_frame].round().long()
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
video = F.pad(
video,
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
"constant",
255,
)
tracks = tracks + self.pad_value
if self.grayscale:
transform = transforms.Grayscale()
video = transform(video)
video = video.repeat(1, 1, 3, 1, 1)
res_video = self.draw_tracks_on_video(
video=video,
tracks=tracks,
visibility=visibility,
segm_mask=segm_mask,
gt_tracks=gt_tracks,
query_frame=query_frame,
compensate_for_camera_motion=compensate_for_camera_motion,
)
if save_video:
self.save_video(res_video, filename=filename, writer=writer, step=step)
return res_video
def save_video(self, video, filename, writer=None, step=0):
if writer is not None:
writer.add_video(
filename,
video.to(torch.uint8),
global_step=step,
fps=self.fps,
)
else:
os.makedirs(self.save_dir, exist_ok=True)
wide_list = list(video.unbind(1))
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
# Prepare the video file path
save_path = os.path.join(self.save_dir, f"{filename}.mp4")
# Create a writer object
video_writer = imageio.get_writer(save_path, fps=self.fps)
# Write frames to the video file
for frame in wide_list[2:-1]:
video_writer.append_data(frame)
video_writer.close()
print(f"Video saved to {save_path}")
def draw_tracks_on_video(
self,
video: torch.Tensor,
tracks: torch.Tensor,
visibility: torch.Tensor = None,
segm_mask: torch.Tensor = None,
gt_tracks=None,
query_frame: int = 0,
compensate_for_camera_motion=False,
):
B, T, C, H, W = video.shape
_, _, N, D = tracks.shape
assert D == 2
assert C == 3
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
if gt_tracks is not None:
gt_tracks = gt_tracks[0].detach().cpu().numpy()
res_video = []
# process input video
for rgb in video:
res_video.append(rgb.copy())
vector_colors = np.zeros((T, N, 3))
if self.mode == "optical_flow":
import flow_vis
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
elif segm_mask is None:
if self.mode == "rainbow":
y_min, y_max = (
tracks[query_frame, :, 1].min(),
tracks[query_frame, :, 1].max(),
)
norm = plt.Normalize(y_min, y_max)
for n in range(N):
color = self.color_map(norm(tracks[query_frame, n, 1]))
color = np.array(color[:3])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with time
for t in range(T):
color = np.array(self.color_map(t / T)[:3])[None] * 255
vector_colors[t] = np.repeat(color, N, axis=0)
else:
if self.mode == "rainbow":
vector_colors[:, segm_mask <= 0, :] = 255
y_min, y_max = (
tracks[0, segm_mask > 0, 1].min(),
tracks[0, segm_mask > 0, 1].max(),
)
norm = plt.Normalize(y_min, y_max)
for n in range(N):
if segm_mask[n] > 0:
color = self.color_map(norm(tracks[0, n, 1]))
color = np.array(color[:3])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with segm class
segm_mask = segm_mask.cpu()
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
vector_colors = np.repeat(color[None], T, axis=0)
# draw tracks
if self.tracks_leave_trace != 0:
for t in range(query_frame + 1, T):
first_ind = (
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
)
curr_tracks = tracks[first_ind : t + 1]
curr_colors = vector_colors[first_ind : t + 1]
if compensate_for_camera_motion:
diff = (
tracks[first_ind : t + 1, segm_mask <= 0]
- tracks[t : t + 1, segm_mask <= 0]
).mean(1)[:, None]
curr_tracks = curr_tracks - diff
curr_tracks = curr_tracks[:, segm_mask > 0]
curr_colors = curr_colors[:, segm_mask > 0]
res_video[t] = self._draw_pred_tracks(
res_video[t],
curr_tracks,
curr_colors,
)
if gt_tracks is not None:
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
# draw points
for t in range(query_frame, T):
img = Image.fromarray(np.uint8(res_video[t]))
for i in range(N):
coord = (tracks[t, i, 0], tracks[t, i, 1])
visible = True
if visibility is not None:
visible = visibility[0, t, i]
if coord[0] != 0 and coord[1] != 0:
if not compensate_for_camera_motion or (
compensate_for_camera_motion and segm_mask[i] > 0
):
img = draw_circle(
img,
coord=coord,
radius=int(self.linewidth * 2),
color=vector_colors[t, i].astype(int),
visible=visible,
)
)
res_video[t] = np.array(img)
# construct the final rgb sequence
if self.show_first_frame > 0:
res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
def _draw_pred_tracks(
self,
rgb: np.ndarray, # H x W x 3
tracks: np.ndarray, # T x 2
vector_colors: np.ndarray,
alpha: float = 0.5,
):
T, N, _ = tracks.shape
rgb = Image.fromarray(np.uint8(rgb))
for s in range(T - 1):
vector_color = vector_colors[s]
original = rgb.copy()
alpha = (s / T) ** 2
for i in range(N):
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
if coord_y[0] != 0 and coord_y[1] != 0:
rgb = draw_line(
rgb,
coord_y,
coord_x,
vector_color[i].astype(int),
self.linewidth,
)
if self.tracks_leave_trace > 0:
rgb = Image.fromarray(
np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
)
rgb = np.array(rgb)
return rgb
def _draw_gt_tracks(
self,
rgb: np.ndarray, # H x W x 3,
gt_tracks: np.ndarray, # T x 2
):
T, N, _ = gt_tracks.shape
color = np.array((211, 0, 0))
rgb = Image.fromarray(np.uint8(rgb))
for t in range(T):
for i in range(N):
# use a separate name so the full gt_tracks array is not overwritten
gt_track = gt_tracks[t][i]
# draw a red cross
if gt_track[0] > 0 and gt_track[1] > 0:
length = self.linewidth * 3
coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
rgb = draw_line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
)
coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
rgb = draw_line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
)
rgb = np.array(rgb)
return rgb
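# ---------------------------------------------------------------------------
# Hypothetical usage sketch (editor's addition, not part of the original file).
# It assumes a small video at "./assets/example.mp4" and uses random placeholder
# tracks purely to exercise the Visualizer API; in practice `tracks` and
# `visibility` would come from one of the predictors above, shaped (B, T, N, 2)
# and (B, T, N).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    frames = read_video_from_path("./assets/example.mp4")  # (T, H, W, 3) array or None
    if frames is not None:
        video = torch.from_numpy(frames).permute(0, 3, 1, 2)[None].float()  # (1, T, 3, H, W)
        B, T, _, H, W = video.shape
        N = 16  # number of placeholder points
        tracks = torch.rand(B, T, N, 2) * torch.tensor([W - 1, H - 1]).float()  # (x, y) pixels
        visibility = torch.ones(B, T, N, dtype=torch.bool)
        vis = Visualizer(save_dir="./results", fps=10, linewidth=2)
        vis.visualize(video=video, tracks=tracks, visibility=visibility, filename="example")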

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "2.0.0"