add some comments
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
BIN  cotracker/__pycache__/__init__.cpython-38.pyc   (new file; binary not shown)
BIN  cotracker/__pycache__/__init__.cpython-39.pyc   (new file; binary not shown)
BIN  cotracker/__pycache__/predictor.cpython-38.pyc  (new file; binary not shown)
BIN  cotracker/__pycache__/predictor.cpython-39.pyc  (new file; binary not shown)
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -1,166 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import json
import dataclasses
import numpy as np
from dataclasses import Field, MISSING
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple

_X = TypeVar("_X")


def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
    """
    Loads to a @dataclass or collection hierarchy including dataclasses
    from a json recursively.
    Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
    raises KeyError if json has keys not mapping to the dataclass fields.

    Args:
        f: Either a path to a file, or a file opened for writing.
        cls: The class of the loaded dataclass.
        binary: Set to True if `f` is a file handle, else False.
    """
    if binary:
        asdict = json.loads(f.read().decode("utf8"))
    else:
        asdict = json.load(f)

    # in the list case, run a faster "vectorized" version
    cls = get_args(cls)[0]
    res = list(_dataclass_list_from_dict_list(asdict, cls))

    return res


def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
    """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
    if get_origin(type_) is Union:
        args = get_args(type_)
        if len(args) == 2 and args[1] == type(None):  # noqa E721
            return True, args[0]
    if type_ is Any:
        return True, Any

    return False, type_


def _unwrap_type(tp):
    # strips Optional wrapper, if any
    if get_origin(tp) is Union:
        args = get_args(tp)
        if len(args) == 2 and any(a is type(None) for a in args):  # noqa: E721
            # this is typing.Optional
            return args[0] if args[1] is type(None) else args[1]  # noqa: E721
    return tp


def _get_dataclass_field_default(field: Field) -> Any:
    if field.default_factory is not MISSING:
        # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
        #  dataclasses._DefaultFactory[typing.Any]]` is not a function.
        return field.default_factory()
    elif field.default is not MISSING:
        return field.default
    else:
        return None


def _dataclass_list_from_dict_list(dlist, typeannot):
    """
    Vectorised version of `_dataclass_from_dict`.
    The output should be equivalent to
    `[_dataclass_from_dict(d, typeannot) for d in dlist]`.

    Args:
        dlist: list of objects to convert.
        typeannot: type of each of those objects.
    Returns:
        iterator or list over converted objects of the same length as `dlist`.

    Raises:
        ValueError: it assumes the objects have None's in consistent places across
            objects, otherwise it would ignore some values. This generally holds for
            auto-generated annotations, but otherwise use `_dataclass_from_dict`.
    """

    cls = get_origin(typeannot) or typeannot

    if typeannot is Any:
        return dlist
    if all(obj is None for obj in dlist):  # 1st recursion base: all None nodes
        return dlist
    if any(obj is None for obj in dlist):
        # filter out Nones and recurse on the resulting list
        idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
        idx, notnone = zip(*idx_notnone)
        converted = _dataclass_list_from_dict_list(notnone, typeannot)
        res = [None] * len(dlist)
        for i, obj in zip(idx, converted):
            res[i] = obj
        return res

    is_optional, contained_type = _resolve_optional(typeannot)
    if is_optional:
        return _dataclass_list_from_dict_list(dlist, contained_type)

    # otherwise, we dispatch by the type of the provided annotation to convert to
    if issubclass(cls, tuple) and hasattr(cls, "_fields"):  # namedtuple
        # For namedtuple, call the function recursively on the lists of corresponding keys
        types = cls.__annotations__.values()
        dlist_T = zip(*dlist)
        res_T = [
            _dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
        ]
        return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
    elif issubclass(cls, (list, tuple)):
        # For list/tuple, call the function recursively on the lists of corresponding positions
        types = get_args(typeannot)
        if len(types) == 1:  # probably List; replicate for all items
            types = types * len(dlist[0])
        dlist_T = zip(*dlist)
        res_T = (
            _dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
        )
        if issubclass(cls, tuple):
            return list(zip(*res_T))
        else:
            return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
    elif issubclass(cls, dict):
        # For the dictionary, call the function recursively on concatenated keys and vertices
        key_t, val_t = get_args(typeannot)
        all_keys_res = _dataclass_list_from_dict_list(
            [k for obj in dlist for k in obj.keys()], key_t
        )
        all_vals_res = _dataclass_list_from_dict_list(
            [k for obj in dlist for k in obj.values()], val_t
        )
        indices = np.cumsum([len(obj) for obj in dlist])
        assert indices[-1] == len(all_keys_res)

        keys = np.split(list(all_keys_res), indices[:-1])
        all_vals_res_iter = iter(all_vals_res)
        return [cls(zip(k, all_vals_res_iter)) for k in keys]
    elif not dataclasses.is_dataclass(typeannot):
        return dlist

    # dataclass node: 2nd recursion base; call the function recursively on the lists
    # of the corresponding fields
    assert dataclasses.is_dataclass(cls)
    fieldtypes = {
        f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
        for f in dataclasses.fields(typeannot)
    }

    # NOTE the default object is shared here
    key_lists = (
        _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
        for k, (type_, default) in fieldtypes.items()
    )
    transposed = zip(*key_lists)
    return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
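# Editor's usage sketch for load_dataclass (not part of the diff above). The `Point`
# dataclass and the inline JSON string are made-up examples; only the import path of
# load_dataclass is taken from the code shown elsewhere in this commit.
import io
from dataclasses import dataclass
from typing import List, Optional

from cotracker.datasets.dataclass_utils import load_dataclass


@dataclass
class Point:
    x: float
    y: float
    label: Optional[str] = None


# load_dataclass expects a file-like object and a List[...] annotation;
# it returns a list of hydrated dataclass instances.
json_str = '[{"x": 1.0, "y": 2.0}, {"x": 3.0, "y": 4.0, "label": "corner"}]'
points = load_dataclass(io.StringIO(json_str), List[Point])
# -> [Point(x=1.0, y=2.0, label=None), Point(x=3.0, y=4.0, label='corner')]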
@@ -1,161 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple

from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass


@dataclass
class ImageAnnotation:
    # path to jpg file, relative w.r.t. dataset_root
    path: str
    # H x W
    size: Tuple[int, int]


@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    # can be used to join with `SequenceAnnotation`
    sequence_name: str
    # 0-based, continuous frame number within sequence
    frame_number: int
    # timestamp in seconds from the video start
    frame_timestamp: float

    image: ImageAnnotation
    meta: Optional[Dict[str, Any]] = None

    camera_name: Optional[str] = None
    trajectories: Optional[str] = None


class DynamicReplicaDataset(data.Dataset):
    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
        super(DynamicReplicaDataset, self).__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size
        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)

        for seq_name in seq_annot.keys():
            seq_len = len(seq_annot[seq_name])

            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0

            for ref_idx in range(0, seq_len, step):
                sample = seq_annot[seq_name][ref_idx : ref_idx + step]
                self.sample_list.append(sample)
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        H_new = H
        W_new = W

        # simple random crop
        y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
        x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs

    def __getitem__(self, index):
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        image_size = (H, W)

        for i in range(T):
            traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
            traj = torch.load(traj_path)

            visibilities.append(traj["verts_inds_vis"].numpy())

            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, N, D = traj_2d.shape
        # subsample trajectories for augmentations
        visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]

        traj_2d = traj_2d[:, visible_inds_sampled]
        visibility = visibility[:, visible_inds_sampled]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            H, W, _ = rgbs[0].shape
            image_size = self.crop_size

        visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        # filter out points that're visible for less than 10 frames
        visible_inds_resampled = visibility.sum(0) > 10
        traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
        visibility = torch.from_numpy(visibility[:, visible_inds_resampled])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            valid=torch.ones(T, N),
            seq_name=sample[0].sequence_name,
        )
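# Editor's usage sketch for DynamicReplicaDataset (not part of the diff above). The
# root path is a placeholder: it must point at a downloaded Dynamic Replica split that
# contains frame_annotations_<split>.jgz plus the per-frame trajectory tensors it
# references. Assumes the DynamicReplicaDataset class defined above is in scope.
dataset = DynamicReplicaDataset(
    root="./dynamic_replica_data",  # placeholder path
    split="valid",
    traj_per_sample=256,
    sample_len=24,
    crop_size=(384, 512),
)
sample = dataset[0]             # a CoTrackerData instance
print(sample.video.shape)       # (sample_len, 3, H, W) after the center crop
print(sample.trajectory.shape)  # (sample_len, n_kept_tracks, 2)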
@@ -1,441 +1,441 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import cv2

import imageio
import numpy as np

from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image


class CoTrackerDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(CoTrackerDataset, self).__init__()
        np.random.seed(0)
        torch.manual_seed(0)
        self.data_root = data_root
        self.seq_len = seq_len
        self.traj_per_sample = traj_per_sample
        self.sample_vis_1st_frame = sample_vis_1st_frame
        self.use_augs = use_augs
        self.crop_size = crop_size

        # photometric augmentation
        self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
        self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))

        self.blur_aug_prob = 0.25
        self.color_aug_prob = 0.25

        # occlusion augmentation
        self.eraser_aug_prob = 0.5
        self.eraser_bounds = [2, 100]
        self.eraser_max = 10

        # occlusion augmentation
        self.replace_aug_prob = 0.5
        self.replace_bounds = [2, 100]
        self.replace_max = 10

        # spatial augmentations
        self.pad_bounds = [0, 100]
        self.crop_size = crop_size
        self.resize_lim = [0.25, 2.0]  # sample resizes from here
        self.resize_delta = 0.2
        self.max_crop_offset = 50

        self.do_flip = True
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.5

    def getitem_helper(self, index):
        return NotImplementedError

    def __getitem__(self, index):
        gotit = False

        sample, gotit = self.getitem_helper(index)
        if not gotit:
            print("warning: sampling failed")
            # fake sample, so we can still collate
            sample = CoTrackerData(
                video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
                trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
                visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
                valid=torch.zeros((self.seq_len, self.traj_per_sample)),
            )

        return sample, gotit

    def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        if eraser:
            ############ eraser transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            for i in range(1, S):
                if np.random.rand() < self.eraser_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.eraser_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
                        rgbs[i][y0:y1, x0:x1, :] = mean_color

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        if replace:
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
            ]
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
            ]

            ############ replace transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
            for i in range(1, S):
                if np.random.rand() < self.replace_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.replace_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        wid = x1 - x0
                        hei = y1 - y0
                        y00 = np.random.randint(0, H - hei)
                        x00 = np.random.randint(0, W - wid)
                        fr = np.random.randint(0, S)
                        rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
                        rgbs[i][y0:y1, x0:x1, :] = rep

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        ############ photometric augmentation ############
        if np.random.rand() < self.color_aug_prob:
            # random per-frame amount of aug
            rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]

        if np.random.rand() < self.blur_aug_prob:
            # random per-frame amount of blur
            rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]

        return rgbs, trajs, visibles

    def add_spatial_augs(self, rgbs, trajs, visibles):
        T, N, __ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        rgbs = [rgb.astype(np.float32) for rgb in rgbs]

        ############ spatial transform ############

        # padding
        pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])

        rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
        trajs[:, :, 0] += pad_x0
        trajs[:, :, 1] += pad_y0
        H, W = rgbs[0].shape[:2]

        # scaling + stretching
        scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
        scale_x = scale
        scale_y = scale
        H_new = H
        W_new = W

        scale_delta_x = 0.0
        scale_delta_y = 0.0

        rgbs_scaled = []
        for s in range(S):
            if s == 1:
                scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
                scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
            elif s > 1:
                scale_delta_x = (
                    scale_delta_x * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
                scale_delta_y = (
                    scale_delta_y * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
            scale_x = scale_x + scale_delta_x
            scale_y = scale_y + scale_delta_y

            # bring h/w closer
            scale_xy = (scale_x + scale_y) * 0.5
            scale_x = scale_x * 0.5 + scale_xy * 0.5
            scale_y = scale_y * 0.5 + scale_xy * 0.5

            # don't get too crazy
            scale_x = np.clip(scale_x, 0.2, 2.0)
            scale_y = np.clip(scale_y, 0.2, 2.0)

            H_new = int(H * scale_y)
            W_new = int(W * scale_x)

            # make it at least slightly bigger than the crop area,
            # so that the random cropping can add diversity
            H_new = np.clip(H_new, self.crop_size[0] + 10, None)
            W_new = np.clip(W_new, self.crop_size[1] + 10, None)
            # recompute scale in case we clipped
            scale_x = (W_new - 1) / float(W - 1)
            scale_y = (H_new - 1) / float(H - 1)
            rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
            trajs[s, :, 0] *= scale_x
            trajs[s, :, 1] *= scale_y
        rgbs = rgbs_scaled

        ok_inds = visibles[0, :] > 0
        vis_trajs = trajs[:, ok_inds]  # S,?,2

        if vis_trajs.shape[1] > 0:
            mid_x = np.mean(vis_trajs[0, :, 0])
            mid_y = np.mean(vis_trajs[0, :, 1])
        else:
            mid_y = self.crop_size[0]
            mid_x = self.crop_size[1]

        x0 = int(mid_x - self.crop_size[1] // 2)
        y0 = int(mid_y - self.crop_size[0] // 2)

        offset_x = 0
        offset_y = 0

        for s in range(S):
            # on each frame, shift a bit more
            if s == 1:
                offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
                offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
            elif s > 1:
                offset_x = int(
                    offset_x * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
                offset_y = int(
                    offset_y * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
            x0 = x0 + offset_x
            y0 = y0 + offset_y

            H_new, W_new = rgbs[s].shape[:2]
            if H_new == self.crop_size[0]:
                y0 = 0
            else:
                y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)

            if W_new == self.crop_size[1]:
                x0 = 0
            else:
                x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)

            rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
            trajs[s, :, 0] -= x0
            trajs[s, :, 1] -= y0

        H_new = self.crop_size[0]
        W_new = self.crop_size[1]

        # flip
        h_flipped = False
        v_flipped = False
        if self.do_flip:
            # h flip
            if np.random.rand() < self.h_flip_prob:
                h_flipped = True
                rgbs = [rgb[:, ::-1] for rgb in rgbs]
            # v flip
            if np.random.rand() < self.v_flip_prob:
                v_flipped = True
                rgbs = [rgb[::-1] for rgb in rgbs]
        if h_flipped:
            trajs[:, :, 0] = W_new - trajs[:, :, 0]
        if v_flipped:
            trajs[:, :, 1] = H_new - trajs[:, :, 1]

        return rgbs, trajs

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        ############ spatial transform ############

        H_new = H
        W_new = W

        # simple random crop
        y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
        x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs


class KubricMovifDataset(CoTrackerDataset):
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(KubricMovifDataset, self).__init__(
            data_root=data_root,
            crop_size=crop_size,
            seq_len=seq_len,
            traj_per_sample=traj_per_sample,
            sample_vis_1st_frame=sample_vis_1st_frame,
            use_augs=use_augs,
        )

        self.pad_bounds = [0, 25]
        self.resize_lim = [0.75, 1.25]  # sample resizes from here
        self.resize_delta = 0.05
        self.max_crop_offset = 15
        self.seq_names = [
            fname
            for fname in os.listdir(data_root)
            if os.path.isdir(os.path.join(data_root, fname))
        ]
        print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))

    def getitem_helper(self, index):
        gotit = True
        seq_name = self.seq_names[index]

        npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
        rgb_path = os.path.join(self.data_root, seq_name, "frames")

        img_paths = sorted(os.listdir(rgb_path))
        rgbs = []
        for i, img_path in enumerate(img_paths):
            rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))

        rgbs = np.stack(rgbs)
        annot_dict = np.load(npy_path, allow_pickle=True).item()
        traj_2d = annot_dict["coords"]
        visibility = annot_dict["visibility"]

        # random crop
        assert self.seq_len <= len(rgbs)
        if self.seq_len < len(rgbs):
            start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]

            rgbs = rgbs[start_ind : start_ind + self.seq_len]
            traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
            visibility = visibility[:, start_ind : start_ind + self.seq_len]

        traj_2d = np.transpose(traj_2d, (1, 0, 2))
        visibility = np.transpose(np.logical_not(visibility), (1, 0))
        if self.use_augs:
            rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
            rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
        else:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)

        visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        visibility = torch.from_numpy(visibility)
        traj_2d = torch.from_numpy(traj_2d)

        visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]

        if self.sample_vis_1st_frame:
            visibile_pts_inds = visibile_pts_first_frame_inds
        else:
            visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
                :, 0
            ]
            visibile_pts_inds = torch.cat(
                (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
            )
        point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
        if len(point_inds) < self.traj_per_sample:
            gotit = False

        visible_inds_sampled = visibile_pts_inds[point_inds]

        trajs = traj_2d[:, visible_inds_sampled].float()
        visibles = visibility[:, visible_inds_sampled]
        valids = torch.ones((self.seq_len, self.traj_per_sample))

        rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
        sample = CoTrackerData(
            video=rgbs,
            trajectory=trajs,
            visibility=visibles,
            valid=valids,
            seq_name=seq_name,
        )
        return sample, gotit

    def __len__(self):
        return len(self.seq_names)
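# Editor's usage sketch for KubricMovifDataset (not part of the diff above). The
# data_root is a placeholder pointing at per-sequence folders that each hold a
# "frames" directory and a <seq_name>.npy annotation file, as read by getitem_helper.
# Because __getitem__ returns (sample, gotit) pairs, a small collate function can
# drop clips that did not yield enough visible tracks; the one below is illustrative,
# not the collate used by the repo's own training code.
import torch


def collate_drop_failed(batch):
    # keep only the samples whose helper reported success
    kept = [sample for sample, gotit in batch if gotit]
    if not kept:
        return None  # caller should skip this batch
    video = torch.stack([s.video for s in kept])            # (B, seq_len, 3, H, W)
    trajectory = torch.stack([s.trajectory for s in kept])  # (B, seq_len, traj_per_sample, 2)
    visibility = torch.stack([s.visibility for s in kept])  # (B, seq_len, traj_per_sample)
    return video, trajectory, visibility


train_dataset = KubricMovifDataset(
    data_root="./kubric_movi_f",  # placeholder path
    crop_size=(384, 512),
    seq_len=24,
    traj_per_sample=768,
    use_augs=True,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, shuffle=True, collate_fn=collate_drop_failed
)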
@@ -1,209 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import io
import glob
import torch
import pickle
import numpy as np
import mediapy as media

from PIL import Image
from typing import Mapping, Tuple, Union

from cotracker.datasets.utils import CoTrackerData

DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]


def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
    """Resize a video to output_size."""
    # If you have a GPU, consider replacing this with a GPU-enabled resize op,
    # such as a jitted jax.image.resize. It will make things faster.
    return media.resize_video(video, output_size)


def sample_queries_first(
    target_occluded: np.ndarray,
    target_points: np.ndarray,
    frames: np.ndarray,
) -> Mapping[str, np.ndarray]:
    """Package a set of frames and tracks for use in TAPNet evaluations.
    Given a set of frames and tracks with no query points, use the first
    visible point in each track as the query.
    Args:
      target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
        where True indicates occluded.
      target_points: Position, of shape [n_tracks, n_frames, 2], where each point
        is [x,y] scaled between 0 and 1.
      frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
        -1 and 1.
    Returns:
      A dict with the keys:
        video: Video tensor of shape [1, n_frames, height, width, 3]
        query_points: Query points of shape [1, n_queries, 3] where
          each point is [t, y, x] scaled to the range [-1, 1]
        target_points: Target points of shape [1, n_queries, n_frames, 2] where
          each point is [x, y] scaled to the range [-1, 1]
    """
    valid = np.sum(~target_occluded, axis=1) > 0
    target_points = target_points[valid, :]
    target_occluded = target_occluded[valid, :]

    query_points = []
    for i in range(target_points.shape[0]):
        index = np.where(target_occluded[i] == 0)[0][0]
        x, y = target_points[i, index, 0], target_points[i, index, 1]
        query_points.append(np.array([index, y, x]))  # [t, y, x]
    query_points = np.stack(query_points, axis=0)

    return {
        "video": frames[np.newaxis, ...],
        "query_points": query_points[np.newaxis, ...],
        "target_points": target_points[np.newaxis, ...],
        "occluded": target_occluded[np.newaxis, ...],
    }

def sample_queries_strided(
|
||||
target_occluded: np.ndarray,
|
||||
target_points: np.ndarray,
|
||||
frames: np.ndarray,
|
||||
query_stride: int = 5,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||
|
||||
Given a set of frames and tracks with no query points, sample queries
|
||||
strided every query_stride frames, ignoring points that are not visible
|
||||
at the selected frames.
|
||||
|
||||
Args:
|
||||
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||
where True indicates occluded.
|
||||
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||
is [x,y] scaled between 0 and 1.
|
||||
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||
-1 and 1.
|
||||
query_stride: When sampling query points, search for un-occluded points
|
||||
every query_stride frames and convert each one into a query.
|
||||
|
||||
Returns:
|
||||
A dict with the keys:
|
||||
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
|
||||
has floats scaled to the range [-1, 1].
|
||||
query_points: Query points of shape [1, n_queries, 3] where
|
||||
each point is [t, y, x] scaled to the range [-1, 1].
|
||||
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||
each point is [x, y] scaled to the range [-1, 1].
|
||||
trackgroup: Index of the original track that each query point was
|
||||
sampled from. This is useful for visualization.
|
||||
"""
|
||||
tracks = []
|
||||
occs = []
|
||||
queries = []
|
||||
trackgroups = []
|
||||
total = 0
|
||||
trackgroup = np.arange(target_occluded.shape[0])
|
||||
for i in range(0, target_occluded.shape[1], query_stride):
|
||||
mask = target_occluded[:, i] == 0
|
||||
query = np.stack(
|
||||
[
|
||||
i * np.ones(target_occluded.shape[0:1]),
|
||||
target_points[:, i, 1],
|
||||
target_points[:, i, 0],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
queries.append(query[mask])
|
||||
tracks.append(target_points[mask])
|
||||
occs.append(target_occluded[mask])
|
||||
trackgroups.append(trackgroup[mask])
|
||||
total += np.array(np.sum(target_occluded[:, i] == 0))
|
||||
|
||||
return {
|
||||
"video": frames[np.newaxis, ...],
|
||||
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
|
||||
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
|
||||
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
|
||||
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
|
||||
}
|
||||
|
||||
|
||||
class TapVidDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
dataset_type="davis",
|
||||
resize_to_256=True,
|
||||
queried_first=True,
|
||||
):
|
||||
self.dataset_type = dataset_type
|
||||
self.resize_to_256 = resize_to_256
|
||||
self.queried_first = queried_first
|
||||
if self.dataset_type == "kinetics":
|
||||
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
|
||||
points_dataset = []
|
||||
for pickle_path in all_paths:
|
||||
with open(pickle_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
points_dataset = points_dataset + data
|
||||
self.points_dataset = points_dataset
|
||||
else:
|
||||
with open(data_root, "rb") as f:
|
||||
self.points_dataset = pickle.load(f)
|
||||
if self.dataset_type == "davis":
|
||||
self.video_names = list(self.points_dataset.keys())
|
||||
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.dataset_type == "davis":
|
||||
video_name = self.video_names[index]
|
||||
else:
|
||||
video_name = index
|
||||
video = self.points_dataset[video_name]
|
||||
frames = video["video"]
|
||||
|
||||
if isinstance(frames[0], bytes):
|
||||
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
|
||||
def decode(frame):
|
||||
byteio = io.BytesIO(frame)
|
||||
img = Image.open(byteio)
|
||||
return np.array(img)
|
||||
|
||||
frames = np.array([decode(frame) for frame in frames])
|
||||
|
||||
target_points = self.points_dataset[video_name]["points"]
|
||||
if self.resize_to_256:
|
||||
frames = resize_video(frames, [256, 256])
|
||||
target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
|
||||
else:
|
||||
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
|
||||
|
||||
target_occ = self.points_dataset[video_name]["occluded"]
|
||||
if self.queried_first:
|
||||
converted = sample_queries_first(target_occ, target_points, frames)
|
||||
else:
|
||||
converted = sample_queries_strided(target_occ, target_points, frames)
|
||||
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
|
||||
|
||||
trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
|
||||
|
||||
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
|
||||
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
|
||||
1, 0
|
||||
) # T, N
|
||||
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
|
||||
return CoTrackerData(
|
||||
rgbs,
|
||||
trajs,
|
||||
visibles,
|
||||
seq_name=str(video_name),
|
||||
query_points=query_points,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.points_dataset)
|
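
# --- Illustrative usage sketch (added for clarity; not part of the original diff).
# A minimal way to load one TAP-Vid DAVIS sample in "queried first" mode; the pickle
# path below is an assumption and must point at a downloaded tapvid_davis.pkl.
def _example_load_tapvid_davis(pkl_path="./tapvid_davis/tapvid_davis.pkl"):
    dataset = TapVidDataset(
        data_root=pkl_path,
        dataset_type="davis",
        resize_to_256=True,
        queried_first=True,
    )
    sample = dataset[0]  # a CoTrackerData instance
    # video: (T, 3, 256, 256), trajectory: (T, N, 2), visibility: (T, N),
    # query_points: (N, 3) in [t, y, x] order, as produced by sample_queries_first.
    return sample.video.shape, sample.trajectory.shape, sample.query_points.shape
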
@@ -1,106 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
import dataclasses
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(eq=False)
class CoTrackerData:
    """
    Dataclass for storing video tracks data.
    """

    video: torch.Tensor  # B, S, C, H, W
    trajectory: torch.Tensor  # B, S, N, 2
    visibility: torch.Tensor  # B, S, N
    # optional data
    valid: Optional[torch.Tensor] = None  # B, S, N
    segmentation: Optional[torch.Tensor] = None  # B, S, 1, H, W
    seq_name: Optional[str] = None
    query_points: Optional[torch.Tensor] = None  # TapVID evaluation format


def collate_fn(batch):
    """
    Collate function for video tracks data.
    """
    video = torch.stack([b.video for b in batch], dim=0)
    trajectory = torch.stack([b.trajectory for b in batch], dim=0)
    visibility = torch.stack([b.visibility for b in batch], dim=0)
    query_points = segmentation = None
    if batch[0].query_points is not None:
        query_points = torch.stack([b.query_points for b in batch], dim=0)
    if batch[0].segmentation is not None:
        segmentation = torch.stack([b.segmentation for b in batch], dim=0)
    seq_name = [b.seq_name for b in batch]

    return CoTrackerData(
        video=video,
        trajectory=trajectory,
        visibility=visibility,
        segmentation=segmentation,
        seq_name=seq_name,
        query_points=query_points,
    )


def collate_fn_train(batch):
    """
    Collate function for video tracks data during training.
    """
    gotit = [gotit for _, gotit in batch]
    video = torch.stack([b.video for b, _ in batch], dim=0)
    trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
    visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
    valid = torch.stack([b.valid for b, _ in batch], dim=0)
    seq_name = [b.seq_name for b, _ in batch]
    return (
        CoTrackerData(
            video=video,
            trajectory=trajectory,
            visibility=visibility,
            valid=valid,
            seq_name=seq_name,
        ),
        gotit,
    )


def try_to_cuda(t: Any) -> Any:
    """
    Try to move the input variable `t` to a cuda device.

    Args:
        t: Input.

    Returns:
        t_cuda: `t` moved to a cuda device, if supported.
    """
    try:
        t = t.float().cuda()
    except AttributeError:
        pass
    return t


def dataclass_to_cuda_(obj):
    """
    Move all contents of a dataclass to cuda in place, if supported.

    Args:
        obj: Input dataclass.

    Returns:
        obj: the same dataclass with its tensor fields moved to a cuda device.
    """
    for f in dataclasses.fields(obj):
        setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
    return obj
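
# --- Illustrative sketch (added for clarity; not part of the original diff).
# Individual dataset items hold tensors without the leading batch dimension
# (e.g. video is S, C, H, W); collate_fn stacks them, so the field comments above
# (B, S, ...) describe a collated batch. A minimal self-contained check with random
# tensors; all shapes below are arbitrary assumptions.
def _example_collate_random_batch(S=8, N=16, H=64, W=64):
    items = [
        CoTrackerData(
            video=torch.rand(S, 3, H, W),
            trajectory=torch.rand(S, N, 2),
            visibility=torch.ones(S, N, dtype=torch.bool),
            seq_name=f"seq_{i}",
        )
        for i in range(2)
    ]
    batch = collate_fn(items)
    # batch.video: (2, S, 3, H, W); batch.trajectory: (2, S, N, 2)
    return batch.video.shape, batch.trajectory.shape
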
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -1,6 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: dynamic_replica
@@ -1,6 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_first
@@ -1,6 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_strided
@@ -1,6 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_kinetics_first
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -1,138 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from typing import Iterable, Mapping, Tuple, Union


def compute_tapvid_metrics(
    query_points: np.ndarray,
    gt_occluded: np.ndarray,
    gt_tracks: np.ndarray,
    pred_occluded: np.ndarray,
    pred_tracks: np.ndarray,
    query_mode: str,
) -> Mapping[str, np.ndarray]:
    """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)

    See the TAP-Vid paper for details on the metric computation. All inputs are
    given in raster coordinates. The first three arguments should be the direct
    outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
    The paper metrics assume these are scaled relative to 256x256 images.
    pred_occluded and pred_tracks are your algorithm's predictions.

    This function takes a batch of inputs, and computes metrics separately for
    each video. The metrics for the full benchmark are a simple mean of the
    metrics across the full set of videos. These numbers are between 0 and 1,
    but the paper multiplies them by 100 to ease reading.

    Args:
      query_points: The query points, in the format [t, y, x]. Its size is
        [b, n, 3], where b is the batch size and n is the number of queries.
      gt_occluded: A boolean array of shape [b, n, t], where t is the number
        of frames. True indicates that the point is occluded.
      gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
        in the format [x, y].
      pred_occluded: A boolean array of predicted occlusions, in the same
        format as gt_occluded.
      pred_tracks: An array of track predictions from your algorithm, in the
        same format as gt_tracks.
      query_mode: Either 'first' or 'strided', depending on how queries are
        sampled. If 'first', we assume the prior knowledge that all points
        before the query point are occluded, and these are removed from the
        evaluation.

    Returns:
      A dict with the following keys:
        occlusion_accuracy: Accuracy at predicting occlusion.
        pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
          predicted to be within the given pixel threshold, ignoring occlusion
          prediction.
        jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
          threshold.
        average_pts_within_thresh: average across pts_within_{x}
        average_jaccard: average across jaccard_{x}
    """

    metrics = {}
    # Fixed bug is described in:
    # https://github.com/facebookresearch/co-tracker/issues/20
    eye = np.eye(gt_tracks.shape[2], dtype=np.int32)

    if query_mode == "first":
        # evaluate frames after the query frame
        query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
    elif query_mode == "strided":
        # evaluate all frames except the query frame
        query_frame_to_eval_frames = 1 - eye
    else:
        raise ValueError("Unknown query mode " + query_mode)

    query_frame = query_points[..., 0]
    query_frame = np.round(query_frame).astype(np.int32)
    evaluation_points = query_frame_to_eval_frames[query_frame] > 0

    # Occlusion accuracy is simply how often the predicted occlusion equals the
    # ground truth.
    occ_acc = np.sum(
        np.equal(pred_occluded, gt_occluded) & evaluation_points,
        axis=(1, 2),
    ) / np.sum(evaluation_points)
    metrics["occlusion_accuracy"] = occ_acc

    # Next, convert the predictions and ground truth positions into pixel
    # coordinates.
    visible = np.logical_not(gt_occluded)
    pred_visible = np.logical_not(pred_occluded)
    all_frac_within = []
    all_jaccard = []
    for thresh in [1, 2, 4, 8, 16]:
        # True positives are points that are within the threshold and where both
        # the prediction and the ground truth are listed as visible.
        within_dist = np.sum(
            np.square(pred_tracks - gt_tracks),
            axis=-1,
        ) < np.square(thresh)
        is_correct = np.logical_and(within_dist, visible)

        # Compute the frac_within_threshold, which is the fraction of points
        # within the threshold among points that are visible in the ground truth,
        # ignoring whether they're predicted to be visible.
        count_correct = np.sum(
            is_correct & evaluation_points,
            axis=(1, 2),
        )
        count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
        frac_correct = count_correct / count_visible_points
        metrics["pts_within_" + str(thresh)] = frac_correct
        all_frac_within.append(frac_correct)

        true_positives = np.sum(
            is_correct & pred_visible & evaluation_points, axis=(1, 2)
        )

        # The denominator of the jaccard metric is the true positives plus
        # false positives plus false negatives. However, note that true positives
        # plus false negatives is simply the number of points in the ground truth
        # which is easier to compute than trying to compute all three quantities.
        # Thus we just add the number of points in the ground truth to the number
        # of false positives.
        #
        # False positives are simply points that are predicted to be visible,
        # but the ground truth is not visible or too far from the prediction.
        gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
        false_positives = (~visible) & pred_visible
        false_positives = false_positives | ((~within_dist) & pred_visible)
        false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
        jaccard = true_positives / (gt_positives + false_positives)
        metrics["jaccard_" + str(thresh)] = jaccard
        all_jaccard.append(jaccard)
    metrics["average_jaccard"] = np.mean(
        np.stack(all_jaccard, axis=1),
        axis=1,
    )
    metrics["average_pts_within_thresh"] = np.mean(
        np.stack(all_frac_within, axis=1),
        axis=1,
    )
    return metrics
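
# --- Illustrative usage sketch (added for clarity; not part of the original diff).
# A quick self-contained call with random inputs, just to show the expected shapes
# ([b, n, 3] queries, [b, n, t] occlusion flags, [b, n, t, 2] tracks in 256x256
# raster coordinates); all sizes and noise levels below are arbitrary assumptions.
def _example_compute_tapvid_metrics(b=1, n=4, t=12, rng=np.random.default_rng(0)):
    query_points = np.stack(
        [rng.integers(0, t, (b, n)), rng.uniform(0, 255, (b, n)), rng.uniform(0, 255, (b, n))],
        axis=-1,
    )
    gt_occluded = rng.random((b, n, t)) < 0.2
    gt_tracks = rng.uniform(0, 255, (b, n, t, 2))
    pred_occluded = rng.random((b, n, t)) < 0.2
    pred_tracks = gt_tracks + rng.normal(0, 2.0, gt_tracks.shape)
    return compute_tapvid_metrics(
        query_points, gt_occluded, gt_tracks, pred_occluded, pred_tracks, query_mode="first"
    )
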
@@ -1,253 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
import os
from typing import Optional
import torch
from tqdm import tqdm
import numpy as np

from torch.utils.tensorboard import SummaryWriter
from cotracker.datasets.utils import dataclass_to_cuda_
from cotracker.utils.visualizer import Visualizer
from cotracker.models.core.model_utils import reduce_masked_mean
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

import logging


class Evaluator:
    """
    A class defining the CoTracker evaluator.
    """

    def __init__(self, exp_dir) -> None:
        # Visualization
        self.exp_dir = exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
        self.visualize_dir = os.path.join(exp_dir, "visualisations")

    def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
        if isinstance(pred_trajectory, tuple):
            pred_trajectory, pred_visibility = pred_trajectory
        else:
            pred_visibility = None
        if "tapvid" in dataset_name:
            B, T, N, D = sample.trajectory.shape
            traj = sample.trajectory.clone()
            thr = 0.9

            if pred_visibility is None:
                logging.warning("visibility is NONE")
                pred_visibility = torch.zeros_like(sample.visibility)

            if not pred_visibility.dtype == torch.bool:
                pred_visibility = pred_visibility > thr

            query_points = sample.query_points.clone().cpu().numpy()

            pred_visibility = pred_visibility[:, :, :N]
            pred_trajectory = pred_trajectory[:, :, :N]

            gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
            gt_occluded = (
                torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
            )

            pred_occluded = (
                torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
            )
            pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()

            out_metrics = compute_tapvid_metrics(
                query_points,
                gt_occluded,
                gt_tracks,
                pred_occluded,
                pred_tracks,
                query_mode="strided" if "strided" in dataset_name else "first",
            )

            metrics[sample.seq_name[0]] = out_metrics
            for metric_name in out_metrics.keys():
                if "avg" not in metrics:
                    metrics["avg"] = {}
                metrics["avg"][metric_name] = np.mean(
                    [v[metric_name] for k, v in metrics.items() if k != "avg"]
                )

            logging.info(f"Metrics: {out_metrics}")
            logging.info(f"avg: {metrics['avg']}")
            print("metrics", out_metrics)
            print("avg", metrics["avg"])
        elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
            *_, N, _ = sample.trajectory.shape
            B, T, N = sample.visibility.shape
            H, W = sample.video.shape[-2:]
            device = sample.video.device

            out_metrics = {}

            d_vis_sum = d_occ_sum = d_sum_all = 0.0
            thrs = [1, 2, 4, 8, 16]
            sx_ = (W - 1) / 255.0
            sy_ = (H - 1) / 255.0
            sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
            sc_pt = torch.from_numpy(sc_py).float().to(device)
            __, first_visible_inds = torch.max(sample.visibility, dim=1)

            frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
            start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))

            for thr in thrs:
                d_ = (
                    torch.norm(
                        pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
                        dim=-1,
                    )
                    < thr
                ).float()  # B,S-1,N
                d_occ = (
                    reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
                    * 100.0
                )
                d_occ_sum += d_occ
                out_metrics[f"accuracy_occ_{thr}"] = d_occ

                d_vis = (
                    reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
                )
                d_vis_sum += d_vis
                out_metrics[f"accuracy_vis_{thr}"] = d_vis

                d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
                d_sum_all += d_all
                out_metrics[f"accuracy_{thr}"] = d_all

            d_occ_avg = d_occ_sum / len(thrs)
            d_vis_avg = d_vis_sum / len(thrs)
            d_all_avg = d_sum_all / len(thrs)

            sur_thr = 50
            dists = torch.norm(
                pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
                dim=-1,
            )  # B,S,N
            dist_ok = 1 - (dists > sur_thr).float() * sample.visibility  # B,S,N
            survival = torch.cumprod(dist_ok, dim=1)  # B,S,N
            out_metrics["survival"] = torch.mean(survival).item() * 100.0

            out_metrics["accuracy_occ"] = d_occ_avg
            out_metrics["accuracy_vis"] = d_vis_avg
            out_metrics["accuracy"] = d_all_avg

            metrics[sample.seq_name[0]] = out_metrics
            for metric_name in out_metrics.keys():
                if "avg" not in metrics:
                    metrics["avg"] = {}
                metrics["avg"][metric_name] = float(
                    np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
                )

            logging.info(f"Metrics: {out_metrics}")
            logging.info(f"avg: {metrics['avg']}")
            print("metrics", out_metrics)
            print("avg", metrics["avg"])

    @torch.no_grad()
    def evaluate_sequence(
        self,
        model,
        test_dataloader: torch.utils.data.DataLoader,
        dataset_name: str,
        train_mode=False,
        visualize_every: int = 1,
        writer: Optional[SummaryWriter] = None,
        step: Optional[int] = 0,
    ):
        metrics = {}

        vis = Visualizer(
            save_dir=self.exp_dir,
            fps=7,
        )

        for ind, sample in enumerate(tqdm(test_dataloader)):
            if isinstance(sample, tuple):
                sample, gotit = sample
                if not all(gotit):
                    print("batch is None")
                    continue
            if torch.cuda.is_available():
                dataclass_to_cuda_(sample)
                device = torch.device("cuda")
            else:
                device = torch.device("cpu")

            if (
                not train_mode
                and hasattr(model, "sequence_len")
                and (sample.visibility[:, : model.sequence_len].sum() == 0)
            ):
                print(f"skipping batch {ind}")
                continue

            if "tapvid" in dataset_name:
                queries = sample.query_points.clone().float()

                queries = torch.stack(
                    [
                        queries[:, :, 0],
                        queries[:, :, 2],
                        queries[:, :, 1],
                    ],
                    dim=2,
                ).to(device)
            else:
                queries = torch.cat(
                    [
                        torch.zeros_like(sample.trajectory[:, 0, :, :1]),
                        sample.trajectory[:, 0],
                    ],
                    dim=2,
                ).to(device)

            pred_tracks = model(sample.video, queries)
            if "strided" in dataset_name:
                inv_video = sample.video.flip(1).clone()
                inv_queries = queries.clone()
                inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1

                pred_trj, pred_vsb = pred_tracks
                inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)

                inv_pred_trj = inv_pred_trj.flip(1)
                inv_pred_vsb = inv_pred_vsb.flip(1)

                mask = pred_trj == 0

                pred_trj[mask] = inv_pred_trj[mask]
                pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]

                pred_tracks = pred_trj, pred_vsb

            if dataset_name == "badja" or dataset_name == "fastcapture":
                seq_name = sample.seq_name[0]
            else:
                seq_name = str(ind)
            if ind % visualize_every == 0:
                vis.visualize(
                    sample.video,
                    pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
                    filename=dataset_name + "_" + seq_name,
                    writer=writer,
                    step=step,
                )

            self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
        return metrics
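
# --- Illustrative note (added for clarity; not part of the original diff).
# In evaluate_sequence above, the torch.stack over columns [0, 2, 1] converts
# TAP-Vid query points from the reader's [t, y, x] layout into the [t, x, y]
# layout that the non-TAP-Vid branch also builds from (x, y) trajectories.
# A tiny self-contained check of that reordering:
def _example_query_reorder():
    q_tyx = torch.tensor([[[3.0, 10.0, 20.0]]])  # B=1, N=1, (t, y, x)
    q_txy = torch.stack([q_tyx[:, :, 0], q_tyx[:, :, 2], q_tyx[:, :, 1]], dim=2)
    assert torch.equal(q_txy, torch.tensor([[[3.0, 20.0, 10.0]]]))
    return q_txy
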
@@ -1,169 +1,169 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
||||
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
|
||||
from cotracker.datasets.utils import collate_fn
|
||||
|
||||
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
||||
|
||||
from cotracker.evaluation.core.evaluator import Evaluator
|
||||
from cotracker.models.build_cotracker import (
|
||||
build_cotracker,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class DefaultConfig:
|
||||
# Directory where all outputs of the experiment will be saved.
|
||||
exp_dir: str = "./outputs"
|
||||
|
||||
# Name of the dataset to be used for the evaluation.
|
||||
dataset_name: str = "tapvid_davis_first"
|
||||
# The root directory of the dataset.
|
||||
dataset_root: str = "./"
|
||||
|
||||
# Path to the pre-trained model checkpoint to be used for the evaluation.
|
||||
# The default value is the path to a specific CoTracker model checkpoint.
|
||||
checkpoint: str = "./checkpoints/cotracker2.pth"
|
||||
|
||||
# EvaluationPredictor parameters
|
||||
# The size (N) of the support grid used in the predictor.
|
||||
# The total number of points is (N*N).
|
||||
grid_size: int = 5
|
||||
# The size (N) of the local support grid.
|
||||
local_grid_size: int = 8
|
||||
# A flag indicating whether to evaluate one ground truth point at a time.
|
||||
single_point: bool = True
|
||||
# The number of iterative updates for each sliding window.
|
||||
n_iters: int = 6
|
||||
|
||||
seed: int = 0
|
||||
gpu_idx: int = 0
|
||||
|
||||
# Override hydra's working directory to current working dir,
|
||||
# also disable storing the .hydra logs:
|
||||
hydra: dict = field(
|
||||
default_factory=lambda: {
|
||||
"run": {"dir": "."},
|
||||
"output_subdir": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def run_eval(cfg: DefaultConfig):
|
||||
"""
|
||||
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
|
||||
|
||||
Args:
|
||||
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
|
||||
- exp_dir (str): The directory path for the experiment.
|
||||
- dataset_name (str): The name of the dataset to be used.
|
||||
- dataset_root (str): The root directory of the dataset.
|
||||
- checkpoint (str): The path to the CoTracker model's checkpoint.
|
||||
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
|
||||
- n_iters (int): The number of iterative updates for each sliding window.
|
||||
- seed (int): The seed for setting the random state for reproducibility.
|
||||
- gpu_idx (int): The index of the GPU to be used.
|
||||
"""
|
||||
# Creating the experiment directory if it doesn't exist
|
||||
os.makedirs(cfg.exp_dir, exist_ok=True)
|
||||
|
||||
# Saving the experiment configuration to a .yaml file in the experiment directory
|
||||
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
||||
with open(cfg_file, "w") as f:
|
||||
OmegaConf.save(config=cfg, f=f)
|
||||
|
||||
evaluator = Evaluator(cfg.exp_dir)
|
||||
cotracker_model = build_cotracker(cfg.checkpoint)
|
||||
|
||||
# Creating the EvaluationPredictor object
|
||||
predictor = EvaluationPredictor(
|
||||
cotracker_model,
|
||||
grid_size=cfg.grid_size,
|
||||
local_grid_size=cfg.local_grid_size,
|
||||
single_point=cfg.single_point,
|
||||
n_iters=cfg.n_iters,
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
predictor.model = predictor.model.cuda()
|
||||
|
||||
# Setting the random seeds
|
||||
torch.manual_seed(cfg.seed)
|
||||
np.random.seed(cfg.seed)
|
||||
|
||||
# Constructing the specified dataset
|
||||
curr_collate_fn = collate_fn
|
||||
if "tapvid" in cfg.dataset_name:
|
||||
dataset_type = cfg.dataset_name.split("_")[1]
|
||||
if dataset_type == "davis":
|
||||
data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
|
||||
elif dataset_type == "kinetics":
|
||||
data_root = os.path.join(
|
||||
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
|
||||
)
|
||||
test_dataset = TapVidDataset(
|
||||
dataset_type=dataset_type,
|
||||
data_root=data_root,
|
||||
queried_first="strided" not in cfg.dataset_name,
|
||||
)
|
||||
elif cfg.dataset_name == "dynamic_replica":
|
||||
test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
|
||||
|
||||
# Creating the DataLoader object
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=14,
|
||||
collate_fn=curr_collate_fn,
|
||||
)
|
||||
|
||||
# Timing and conducting the evaluation
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
evaluate_result = evaluator.evaluate_sequence(
|
||||
predictor,
|
||||
test_dataloader,
|
||||
dataset_name=cfg.dataset_name,
|
||||
)
|
||||
end = time.time()
|
||||
print(end - start)
|
||||
|
||||
# Saving the evaluation results to a .json file
|
||||
evaluate_result = evaluate_result["avg"]
|
||||
print("evaluate_result", evaluate_result)
|
||||
result_file = os.path.join(cfg.exp_dir, "result_eval_.json")
|
||||
evaluate_result["time"] = end - start
|
||||
print(f"Dumping eval results to {result_file}.")
|
||||
with open(result_file, "w") as f:
|
||||
json.dump(evaluate_result, f)
|
||||
|
||||
|
||||
cs = hydra.core.config_store.ConfigStore.instance()
|
||||
cs.store(name="default_config_eval", node=DefaultConfig)
|
||||
|
||||
|
||||
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
||||
def evaluate(cfg: DefaultConfig) -> None:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
||||
run_eval(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluate()
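# Illustrative invocation sketch (assumed script path; standard Hydra override
# syntax, any of the DefaultConfig fields above can be overridden this way):
#
#   python cotracker/evaluation/evaluate.py \
#       exp_dir=./outputs \
#       dataset_name=tapvid_davis_first \
#       checkpoint=./checkpoints/cotracker2.pth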
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
BIN
cotracker/models/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/models/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/models/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/build_cotracker.cpython-38.pyc
Normal file
BIN
cotracker/models/__pycache__/build_cotracker.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/build_cotracker.cpython-39.pyc
Normal file
BIN
cotracker/models/__pycache__/build_cotracker.cpython-39.pyc
Normal file
Binary file not shown.
@@ -1,33 +1,33 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
|
||||
|
||||
def build_cotracker(
|
||||
checkpoint: str,
|
||||
):
|
||||
if checkpoint is None:
|
||||
return build_cotracker()
|
||||
model_name = checkpoint.split("/")[-1].split(".")[0]
|
||||
if model_name == "cotracker":
|
||||
return build_cotracker(checkpoint=checkpoint)
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {model_name}")
|
||||
|
||||
|
||||
def build_cotracker(checkpoint=None):
|
||||
cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
|
||||
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")
|
||||
if "model" in state_dict:
|
||||
state_dict = state_dict["model"]
|
||||
cotracker.load_state_dict(state_dict)
|
||||
return cotracker
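# Usage sketch (illustrative; the checkpoint path is the default from the
# evaluation config). Note that the second `def build_cotracker` above shadows
# the first one at import time, so this call builds CoTracker2 and loads the
# given checkpoint directly:
#
#   model = build_cotracker("./checkpoints/cotracker2.pth").eval()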
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
BIN
cotracker/models/core/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/models/core/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/models/core/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/embeddings.cpython-38.pyc
Normal file
BIN
cotracker/models/core/__pycache__/embeddings.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/embeddings.cpython-39.pyc
Normal file
BIN
cotracker/models/core/__pycache__/embeddings.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/model_utils.cpython-38.pyc
Normal file
BIN
cotracker/models/core/__pycache__/model_utils.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/model_utils.cpython-39.pyc
Normal file
BIN
cotracker/models/core/__pycache__/model_utils.cpython-39.pyc
Normal file
Binary file not shown.
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,367 +1,368 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
import collections
|
||||
from torch import Tensor
|
||||
from itertools import repeat
|
||||
|
||||
from cotracker.models.core.model_utils import bilinear_sampler
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.0,
|
||||
use_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
bias = to_2tuple(bias)
|
||||
drop_probs = to_2tuple(drop)
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||
|
||||
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=stride,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
num_groups = planes // 8
|
||||
|
||||
if norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
|
||||
elif norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2d(planes)
|
||||
self.norm2 = nn.BatchNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.BatchNorm2d(planes)
|
||||
|
||||
elif norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2d(planes)
|
||||
self.norm2 = nn.InstanceNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.InstanceNorm2d(planes)
|
||||
|
||||
elif norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm2 = nn.Sequential()
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.Sequential()
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.norm1(self.conv1(y)))
|
||||
y = self.relu(self.norm2(self.conv2(y)))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
|
||||
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
||||
super(BasicEncoder, self).__init__()
|
||||
self.stride = stride
|
||||
self.norm_fn = "instance"
|
||||
self.in_planes = output_dim // 2
|
||||
|
||||
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
||||
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
input_dim,
|
||||
self.in_planes,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
||||
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
||||
self.layer3 = self._make_layer(output_dim, stride=2)
|
||||
self.layer4 = self._make_layer(output_dim, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
output_dim * 3 + output_dim // 4,
|
||||
output_dim * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, (nn.InstanceNorm2d)):
|
||||
if m.weight is not None:
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, dim, stride=1):
|
||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
# Four stages of residual blocks
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
d = self.layer4(c)
|
||||
|
||||
def _bilinear_intepolate(x):
|
||||
return F.interpolate(
|
||||
x,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
a = _bilinear_intepolate(a)
|
||||
b = _bilinear_intepolate(b)
|
||||
c = _bilinear_intepolate(c)
|
||||
d = _bilinear_intepolate(d)
|
||||
|
||||
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
||||
x = self.norm2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.conv3(x)
|
||||
return x
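# Illustrative shape check (arbitrary sizes; stride=4 and output_dim=128 are the
# defaults above):
#
#   enc = BasicEncoder()
#   feats = enc(torch.randn(1, 3, 64, 64))
#   # feats.shape should be torch.Size([1, 128, 16, 16]), i.e. H/4 x W/4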
|
||||
|
||||
|
||||
class CorrBlock:
|
||||
def __init__(
|
||||
self,
|
||||
fmaps,
|
||||
num_levels=4,
|
||||
radius=4,
|
||||
multiple_track_feats=False,
|
||||
padding_mode="zeros",
|
||||
):
|
||||
B, S, C, H, W = fmaps.shape
|
||||
self.S, self.C, self.H, self.W = S, C, H, W
|
||||
self.padding_mode = padding_mode
|
||||
self.num_levels = num_levels
|
||||
self.radius = radius
|
||||
self.fmaps_pyramid = []
|
||||
self.multiple_track_feats = multiple_track_feats
|
||||
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
for i in range(self.num_levels - 1):
|
||||
fmaps_ = fmaps.reshape(B * S, C, H, W)
|
||||
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
||||
_, _, H, W = fmaps_.shape
|
||||
fmaps = fmaps_.reshape(B, S, C, H, W)
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
|
||||
def sample(self, coords):
|
||||
r = self.radius
|
||||
B, S, N, D = coords.shape
|
||||
assert D == 2
|
||||
|
||||
H, W = self.H, self.W
|
||||
out_pyramid = []
|
||||
for i in range(self.num_levels):
|
||||
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
||||
*_, H, W = corrs.shape
|
||||
|
||||
dx = torch.linspace(-r, r, 2 * r + 1)
|
||||
dy = torch.linspace(-r, r, 2 * r + 1)
|
||||
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
|
||||
|
||||
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
|
||||
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
||||
coords_lvl = centroid_lvl + delta_lvl
|
||||
|
||||
corrs = bilinear_sampler(
|
||||
corrs.reshape(B * S * N, 1, H, W),
|
||||
coords_lvl,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
corrs = corrs.view(B, S, N, -1)
|
||||
out_pyramid.append(corrs)
|
||||
|
||||
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
||||
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
|
||||
return out
|
||||
|
||||
def corr(self, targets):
|
||||
B, S, N, C = targets.shape
|
||||
if self.multiple_track_feats:
|
||||
targets_split = targets.split(C // self.num_levels, dim=-1)
|
||||
B, S, N, C = targets_split[0].shape
|
||||
|
||||
assert C == self.C
|
||||
assert S == self.S
|
||||
|
||||
fmap1 = targets
|
||||
|
||||
self.corrs_pyramid = []
|
||||
for i, fmaps in enumerate(self.fmaps_pyramid):
|
||||
*_, H, W = fmaps.shape
|
||||
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
|
||||
if self.multiple_track_feats:
|
||||
fmap1 = targets_split[i]
|
||||
corrs = torch.matmul(fmap1, fmap2s)
|
||||
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
|
||||
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
||||
self.corrs_pyramid.append(corrs)
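# Typical call order for CorrBlock (illustrative sketch; shapes follow the
# assertions above):
#
#   fcorr_fn = CorrBlock(fmaps, num_levels=4, radius=3)  # fmaps: (B, S, C, H, W)
#   fcorr_fn.corr(track_feats)                           # track_feats: (B, S, N, C)
#   fcorrs = fcorr_fn.sample(coords)                     # coords: (B, S, N, 2)
#   # fcorrs: (B*N, S, num_levels * (2*radius + 1) ** 2)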
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * num_heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = num_heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
||||
self.to_out = nn.Linear(inner_dim, query_dim)
|
||||
|
||||
def forward(self, x, context=None, attn_bias=None):
|
||||
B, N1, C = x.shape
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
|
||||
context = default(context, x)
|
||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||
|
||||
N2 = context.shape[1]
|
||||
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
||||
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
||||
|
||||
sim = (q @ k.transpose(-2, -1)) * self.scale
|
||||
|
||||
if attn_bias is not None:
|
||||
sim = sim + attn_bias
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
|
||||
return self.to_out(x)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
attn_class: Callable[..., nn.Module] = Attention,
|
||||
mlp_ratio=4.0,
|
||||
**block_kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=approx_gelu,
|
||||
drop=0,
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
attn_bias = mask
|
||||
if mask is not None:
|
||||
mask = (
|
||||
(mask[:, None] * mask[:, :, None])
|
||||
.unsqueeze(1)
|
||||
.expand(-1, self.attn.num_heads, -1, -1)
|
||||
)
|
||||
max_neg_value = -torch.finfo(x.dtype).max
|
||||
attn_bias = (~mask) * max_neg_value
|
||||
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
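# Illustrative shape check (arbitrary sizes): AttnBlock preserves the token
# shape, so a random batch of tokens should come back with the same size.
#
#   block = AttnBlock(hidden_size=384, num_heads=8)
#   tokens = torch.randn(2, 16, 384)  # (batch, tokens, hidden_size)
#   assert block(tokens).shape == tokens.shape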
|
||||
|
File diff suppressed because it is too large
@@ -1,61 +1,61 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def balanced_ce_loss(pred, gt, valid=None):
|
||||
total_balanced_loss = 0.0
|
||||
for j in range(len(gt)):
|
||||
B, S, N = gt[j].shape
|
||||
# pred and gt are the same shape
|
||||
for (a, b) in zip(pred[j].size(), gt[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
# if valid is not None:
|
||||
for (a, b) in zip(pred[j].size(), valid[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
|
||||
pos = (gt[j] > 0.95).float()
|
||||
neg = (gt[j] < 0.05).float()
|
||||
|
||||
label = pos * 2.0 - 1.0
|
||||
a = -label * pred[j]
|
||||
b = F.relu(a)
|
||||
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
|
||||
|
||||
pos_loss = reduce_masked_mean(loss, pos * valid[j])
|
||||
neg_loss = reduce_masked_mean(loss, neg * valid[j])
|
||||
|
||||
balanced_loss = pos_loss + neg_loss
|
||||
total_balanced_loss += balanced_loss / float(N)
|
||||
return total_balanced_loss
|
||||
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
|
||||
"""Loss function defined over sequence of flow predictions"""
|
||||
total_flow_loss = 0.0
|
||||
for j in range(len(flow_gt)):
|
||||
B, S, N, D = flow_gt[j].shape
|
||||
assert D == 2
|
||||
B, S1, N = vis[j].shape
|
||||
B, S2, N = valids[j].shape
|
||||
assert S == S1
|
||||
assert S == S2
|
||||
n_predictions = len(flow_preds[j])
|
||||
flow_loss = 0.0
|
||||
for i in range(n_predictions):
|
||||
i_weight = gamma ** (n_predictions - i - 1)
|
||||
flow_pred = flow_preds[j][i]
|
||||
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
|
||||
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
||||
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
|
||||
flow_loss = flow_loss / n_predictions
|
||||
total_flow_loss += flow_loss / float(N)
|
||||
return total_flow_loss
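# Note on the weighting above (illustrative): i_weight = gamma ** (n_predictions - i - 1)
# decays geometrically for earlier refinement iterations. For example, with
# gamma=0.8 and n_predictions=4 the weights are [0.512, 0.64, 0.8, 1.0] for
# i = 0..3, so the final prediction contributes the most to the loss.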
|
||||
@@ -1,120 +1,120 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple, Union
|
||||
import torch
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
||||
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- grid_size: The grid size.
|
||||
Returns:
|
||||
- pos_embed: The generated 2D positional embedding.
|
||||
"""
|
||||
if isinstance(grid_size, tuple):
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
else:
|
||||
grid_size_h = grid_size_w = grid_size
|
||||
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
||||
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
||||
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, grid: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- grid: The grid to generate the embedding from.
|
||||
|
||||
Returns:
|
||||
- emb: The generated 2D positional embedding.
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, pos: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- pos: The position to generate the embedding from.
|
||||
|
||||
Returns:
|
||||
- emb: The generated 1D positional embedding.
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = torch.sin(out) # (M, D/2)
|
||||
emb_cos = torch.cos(out) # (M, D/2)
|
||||
|
||||
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||
return emb[None].float()
|
||||
|
||||
|
||||
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- xy: The coordinates to generate the embedding from.
|
||||
- C: The size of the embedding.
|
||||
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
||||
|
||||
Returns:
|
||||
- pe: The generated 2D positional embedding.
|
||||
"""
|
||||
B, N, D = xy.shape
|
||||
assert D == 2
|
||||
|
||||
x = xy[:, :, 0:1]
|
||||
y = xy[:, :, 1:2]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
||||
if cat_coords:
|
||||
pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
|
||||
return pe
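# Illustrative shape check (arbitrary sizes):
#
#   pe = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(8, 8))
#   # pe.shape should be torch.Size([1, 64, 8, 8]), i.e. one embed_dim-channel
#   # map laid out over the 8x8 grid.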
|
||||
@@ -1,256 +1,256 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def smart_cat(tensor1, tensor2, dim):
|
||||
if tensor1 is None:
|
||||
return tensor2
|
||||
return torch.cat([tensor1, tensor2], dim=dim)
|
||||
|
||||
|
||||
def get_points_on_a_grid(
|
||||
size: int,
|
||||
extent: Tuple[float, ...],
|
||||
center: Optional[Tuple[float, ...]] = None,
|
||||
device: Optional[torch.device] = torch.device("cpu"),
|
||||
):
|
||||
r"""Get a grid of points covering a rectangular region
|
||||
|
||||
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
|
||||
:attr:`size` grid of points distributed to cover a rectangular area
|
||||
specified by `extent`.
|
||||
|
||||
The `extent` is a pair of integers :math:`(H,W)` specifying the height
|
||||
and width of the rectangle.
|
||||
|
||||
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
|
||||
specifying the vertical and horizontal center coordinates. The center
|
||||
defaults to the middle of the extent.
|
||||
|
||||
Points are distributed uniformly within the rectangle leaving a margin
|
||||
:math:`m=W/64` from the border.
|
||||
|
||||
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
|
||||
points :math:`P_{ij}=(x_i, y_i)` where
|
||||
|
||||
.. math::
|
||||
P_{ij} = \left(
|
||||
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
|
||||
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
|
||||
\right)
|
||||
|
||||
Points are returned in row-major order.
|
||||
|
||||
Args:
|
||||
size (int): grid size.
|
||||
extent (tuple): height and width of the grid extent.
|
||||
center (tuple, optional): grid center.
|
||||
device (str, optional): Defaults to `"cpu"`.
|
||||
|
||||
Returns:
|
||||
Tensor: grid.
|
||||
"""
|
||||
if size == 1:
|
||||
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
|
||||
|
||||
if center is None:
|
||||
center = [extent[0] / 2, extent[1] / 2]
|
||||
|
||||
margin = extent[1] / 64
|
||||
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
|
||||
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
|
||||
grid_y, grid_x = torch.meshgrid(
|
||||
torch.linspace(*range_y, size, device=device),
|
||||
torch.linspace(*range_x, size, device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
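# Illustrative example (arbitrary frame size): a 5x5 support grid over a
# 480x854 frame.
#
#   pts = get_points_on_a_grid(5, (480, 854))
#   # pts.shape should be torch.Size([1, 25, 2]), each row an (x, y) coordinate.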
|
||||
|
||||
|
||||
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
|
||||
r"""Masked mean
|
||||
|
||||
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
|
||||
over a mask :attr:`mask`, returning
|
||||
|
||||
.. math::
|
||||
\text{output} =
|
||||
\frac
|
||||
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
|
||||
{\epsilon + \sum_{i=1}^N \text{mask}_i}
|
||||
|
||||
where :math:`N` is the number of elements in :attr:`input` and
|
||||
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
|
||||
division by zero.
|
||||
|
||||
`reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
|
||||
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
|
||||
Optionally, the dimension can be kept in the output by setting
|
||||
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
|
||||
the same dimension as :attr:`input`.
|
||||
|
||||
The interface is similar to `torch.mean()`.
|
||||
|
||||
Args:
|
||||
input (Tensor): input tensor.
|
||||
mask (Tensor): mask.
|
||||
dim (int, optional): Dimension to sum over. Defaults to None.
|
||||
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tensor: mean tensor.
|
||||
"""
|
||||
|
||||
mask = mask.expand_as(input)
|
||||
|
||||
prod = input * mask
|
||||
|
||||
if dim is None:
|
||||
numer = torch.sum(prod)
|
||||
denom = torch.sum(mask)
|
||||
else:
|
||||
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
||||
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
|
||||
|
||||
mean = numer / (EPS + denom)
|
||||
return mean
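# Illustrative example: only unmasked entries contribute to the mean.
#
#   x = torch.tensor([1.0, 2.0, 3.0, 4.0])
#   m = torch.tensor([1.0, 1.0, 0.0, 0.0])
#   reduce_masked_mean(x, m)  # ~1.5, the mean of the first two entries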
|
||||
|
||||
|
||||
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
||||
r"""Sample a tensor using bilinear interpolation
|
||||
|
||||
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
||||
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
||||
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
||||
convention.
|
||||
|
||||
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
||||
:math:`B` is the batch size, :math:`C` is the number of channels,
|
||||
:math:`H` is the height of the image, and :math:`W` is the width of the
|
||||
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
||||
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
||||
|
||||
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
||||
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
||||
that in this case the order of the components is slightly different
|
||||
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
||||
|
||||
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
||||
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
||||
    left-most image pixel and :math:`W-1` to the center of the right-most
|
||||
pixel.
|
||||
|
||||
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
||||
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
||||
    the left-most pixel and :math:`W` to the right edge of the right-most
|
||||
pixel.
|
||||
|
||||
    Similar conventions apply to :math:`y` for the range
|
||||
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
||||
:math:`[0,T-1]` and :math:`[0,T]`.
|
||||
|
||||
Args:
|
||||
input (Tensor): batch of input images.
|
||||
coords (Tensor): batch of coordinates.
|
||||
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
||||
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled points.
|
||||
"""
|
||||
|
||||
sizes = input.shape[2:]
|
||||
|
||||
assert len(sizes) in [2, 3]
|
||||
|
||||
if len(sizes) == 3:
|
||||
# t x y -> x y t to match dimensions T H W in grid_sample
|
||||
coords = coords[..., [1, 2, 0]]
|
||||
|
||||
if align_corners:
|
||||
coords = coords * torch.tensor(
|
||||
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
||||
)
|
||||
else:
|
||||
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
|
||||
|
||||
coords -= 1
|
||||
|
||||
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
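# Usage sketch (shapes are illustrative assumptions): sample a (B, C, H, W)
# feature map at pixel coordinates given in (x, y) order, as described above.
#
# >>> feat = torch.randn(1, 16, 32, 32)
# >>> xy = torch.tensor([[[[4.0, 7.0]]]])  # (B, H_o, W_o, 2) with H_o = W_o = 1
# >>> bilinear_sampler(feat, xy).shape
# torch.Size([1, 16, 1, 1])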
|
||||
|
||||
|
||||
def sample_features4d(input, coords):
|
||||
r"""Sample spatial features
|
||||
|
||||
`sample_features4d(input, coords)` samples the spatial features
|
||||
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
||||
|
||||
The field is sampled at coordinates :attr:`coords` using bilinear
|
||||
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
||||
    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
||||
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
||||
|
||||
The output tensor has one feature per point, and has shape :math:`(B,
|
||||
R, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatial features.
|
||||
coords (Tensor): points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, _, _, _ = input.shape
|
||||
|
||||
# B R 2 -> B R 1 2
|
||||
coords = coords.unsqueeze(2)
|
||||
|
||||
# B C R 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 1, 3).view(
|
||||
B, -1, feats.shape[1] * feats.shape[3]
|
||||
) # B C R 1 -> B R C
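# Usage sketch (assumed shapes): one 128-dim feature vector per query point.
#
# >>> fmap = torch.randn(2, 128, 48, 64)
# >>> pts = torch.rand(2, 10, 2) * 40  # (B, R, 2) pixel coordinates (x, y)
# >>> sample_features4d(fmap, pts).shape
# torch.Size([2, 10, 128])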
|
||||
|
||||
|
||||
def sample_features5d(input, coords):
|
||||
r"""Sample spatio-temporal features
|
||||
|
||||
`sample_features5d(input, coords)` works in the same way as
|
||||
:func:`sample_features4d` but for spatio-temporal features and points:
|
||||
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
||||
    a :math:`(B, R1, R2, 3)` tensor of spatio-temporal points :math:`(t_i,
|
||||
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatio-temporal features.
|
||||
coords (Tensor): spatio-temporal points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, T, _, _, _ = input.shape
|
||||
|
||||
# B T C H W -> B C T H W
|
||||
input = input.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# B R1 R2 3 -> B R1 R2 1 3
|
||||
coords = coords.unsqueeze(3)
|
||||
|
||||
# B C R1 R2 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 3, 1, 4).view(
|
||||
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
||||
) # B C R1 R2 1 -> B R1 R2 C
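# Usage sketch (assumed shapes): sample a (B, T, C, H, W) feature volume at
# (t, x, y) triplets, R2 points for each of R1 tracks.
#
# >>> fmaps = torch.randn(1, 8, 128, 48, 64)
# >>> tyx = torch.zeros(1, 5, 9, 3)  # (B, R1, R2, 3) in (t, x, y) order
# >>> sample_features5d(fmaps, tyx).shape
# torch.Size([1, 5, 9, 128])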
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def smart_cat(tensor1, tensor2, dim):
|
||||
if tensor1 is None:
|
||||
return tensor2
|
||||
return torch.cat([tensor1, tensor2], dim=dim)
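# Minimal illustration (assumed tensors): smart_cat lets callers accumulate
# chunked results without special-casing the first iteration.
#
# >>> acc = None
# >>> acc = smart_cat(acc, torch.ones(1, 2), dim=0)  # first chunk is returned as-is
# >>> acc = smart_cat(acc, torch.ones(1, 2), dim=0)  # later chunks are concatenated
# >>> acc.shape
# torch.Size([2, 2])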
|
||||
|
||||
|
||||
def get_points_on_a_grid(
|
||||
size: int,
|
||||
extent: Tuple[float, ...],
|
||||
center: Optional[Tuple[float, ...]] = None,
|
||||
device: Optional[torch.device] = torch.device("cpu"),
|
||||
):
|
||||
r"""Get a grid of points covering a rectangular region
|
||||
|
||||
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
|
||||
    :attr:`size` grid of points distributed to cover a rectangular area
|
||||
specified by `extent`.
|
||||
|
||||
    The `extent` is a pair :math:`(H,W)` specifying the height
|
||||
and width of the rectangle.
|
||||
|
||||
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
|
||||
specifying the vertical and horizontal center coordinates. The center
|
||||
defaults to the middle of the extent.
|
||||
|
||||
Points are distributed uniformly within the rectangle leaving a margin
|
||||
:math:`m=W/64` from the border.
|
||||
|
||||
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
|
||||
points :math:`P_{ij}=(x_i, y_i)` where
|
||||
|
||||
.. math::
|
||||
P_{ij} = \left(
|
||||
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
|
||||
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
|
||||
\right)
|
||||
|
||||
Points are returned in row-major order.
|
||||
|
||||
Args:
|
||||
size (int): grid size.
|
||||
        extent (tuple): height and width of the grid extent.
|
||||
center (tuple, optional): grid center.
|
||||
device (str, optional): Defaults to `"cpu"`.
|
||||
|
||||
Returns:
|
||||
Tensor: grid.
|
||||
"""
|
||||
if size == 1:
|
||||
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
|
||||
|
||||
if center is None:
|
||||
center = [extent[0] / 2, extent[1] / 2]
|
||||
|
||||
margin = extent[1] / 64
|
||||
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
|
||||
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
|
||||
grid_y, grid_x = torch.meshgrid(
|
||||
torch.linspace(*range_y, size, device=device),
|
||||
torch.linspace(*range_x, size, device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
|
||||
|
||||
|
||||
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
|
||||
r"""Masked mean
|
||||
|
||||
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
|
||||
over a mask :attr:`mask`, returning
|
||||
|
||||
.. math::
|
||||
\text{output} =
|
||||
\frac
|
||||
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
|
||||
{\epsilon + \sum_{i=1}^N \text{mask}_i}
|
||||
|
||||
where :math:`N` is the number of elements in :attr:`input` and
|
||||
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
|
||||
division by zero.
|
||||
|
||||
    `reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
|
||||
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
|
||||
Optionally, the dimension can be kept in the output by setting
|
||||
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
|
||||
the same dimension as :attr:`input`.
|
||||
|
||||
The interface is similar to `torch.mean()`.
|
||||
|
||||
Args:
|
||||
        input (Tensor): input tensor.
|
||||
mask (Tensor): mask.
|
||||
dim (int, optional): Dimension to sum over. Defaults to None.
|
||||
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tensor: mean tensor.
|
||||
"""
|
||||
|
||||
mask = mask.expand_as(input)
|
||||
|
||||
prod = input * mask
|
||||
|
||||
if dim is None:
|
||||
numer = torch.sum(prod)
|
||||
denom = torch.sum(mask)
|
||||
else:
|
||||
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
||||
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
|
||||
|
||||
mean = numer / (EPS + denom)
|
||||
return mean
|
||||
|
||||
|
||||
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
||||
r"""Sample a tensor using bilinear interpolation
|
||||
|
||||
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
||||
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
||||
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
||||
convention.
|
||||
|
||||
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
||||
:math:`B` is the batch size, :math:`C` is the number of channels,
|
||||
:math:`H` is the height of the image, and :math:`W` is the width of the
|
||||
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
||||
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
||||
|
||||
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
||||
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
||||
that in this case the order of the components is slightly different
|
||||
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
||||
|
||||
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
||||
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
||||
    left-most image pixel and :math:`W-1` to the center of the right-most
|
||||
pixel.
|
||||
|
||||
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
||||
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
||||
    the left-most pixel and :math:`W` to the right edge of the right-most
|
||||
pixel.
|
||||
|
||||
    Similar conventions apply to :math:`y` for the range
|
||||
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
||||
:math:`[0,T-1]` and :math:`[0,T]`.
|
||||
|
||||
Args:
|
||||
input (Tensor): batch of input images.
|
||||
coords (Tensor): batch of coordinates.
|
||||
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
||||
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled points.
|
||||
"""
|
||||
|
||||
sizes = input.shape[2:]
|
||||
|
||||
assert len(sizes) in [2, 3]
|
||||
|
||||
if len(sizes) == 3:
|
||||
# t x y -> x y t to match dimensions T H W in grid_sample
|
||||
coords = coords[..., [1, 2, 0]]
|
||||
|
||||
if align_corners:
|
||||
coords = coords * torch.tensor(
|
||||
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
||||
)
|
||||
else:
|
||||
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
|
||||
|
||||
coords -= 1
|
||||
|
||||
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
||||
|
||||
|
||||
def sample_features4d(input, coords):
|
||||
r"""Sample spatial features
|
||||
|
||||
`sample_features4d(input, coords)` samples the spatial features
|
||||
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
||||
|
||||
The field is sampled at coordinates :attr:`coords` using bilinear
|
||||
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
||||
    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
||||
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
||||
|
||||
The output tensor has one feature per point, and has shape :math:`(B,
|
||||
R, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatial features.
|
||||
coords (Tensor): points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, _, _, _ = input.shape
|
||||
|
||||
# B R 2 -> B R 1 2
|
||||
coords = coords.unsqueeze(2)
|
||||
|
||||
# B C R 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 1, 3).view(
|
||||
B, -1, feats.shape[1] * feats.shape[3]
|
||||
) # B C R 1 -> B R C
|
||||
|
||||
|
||||
def sample_features5d(input, coords):
|
||||
r"""Sample spatio-temporal features
|
||||
|
||||
`sample_features5d(input, coords)` works in the same way as
|
||||
:func:`sample_features4d` but for spatio-temporal features and points:
|
||||
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
||||
    a :math:`(B, R1, R2, 3)` tensor of spatio-temporal points :math:`(t_i,
|
||||
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatio-temporal features.
|
||||
coords (Tensor): spatio-temporal points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, T, _, _, _ = input.shape
|
||||
|
||||
# B T C H W -> B C T H W
|
||||
input = input.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# B R1 R2 3 -> B R1 R2 1 3
|
||||
coords = coords.unsqueeze(3)
|
||||
|
||||
# B C R1 R2 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 3, 1, 4).view(
|
||||
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
||||
) # B C R1 R2 1 -> B R1 R2 C
|
||||
|
@@ -1,104 +1,104 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
from cotracker.models.core.model_utils import get_points_on_a_grid
|
||||
|
||||
|
||||
class EvaluationPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cotracker_model: CoTracker2,
|
||||
interp_shape: Tuple[int, int] = (384, 512),
|
||||
grid_size: int = 5,
|
||||
local_grid_size: int = 8,
|
||||
single_point: bool = True,
|
||||
n_iters: int = 6,
|
||||
) -> None:
|
||||
super(EvaluationPredictor, self).__init__()
|
||||
self.grid_size = grid_size
|
||||
self.local_grid_size = local_grid_size
|
||||
self.single_point = single_point
|
||||
self.interp_shape = interp_shape
|
||||
self.n_iters = n_iters
|
||||
|
||||
self.model = cotracker_model
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, video, queries):
|
||||
queries = queries.clone()
|
||||
B, T, C, H, W = video.shape
|
||||
B, N, D = queries.shape
|
||||
|
||||
assert D == 3
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
device = video.device
|
||||
|
||||
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
|
||||
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
|
||||
|
||||
if self.single_point:
|
||||
traj_e = torch.zeros((B, T, N, 2), device=device)
|
||||
vis_e = torch.zeros((B, T, N), device=device)
|
||||
for pind in range((N)):
|
||||
query = queries[:, pind : pind + 1]
|
||||
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
traj_e_pind, vis_e_pind = self._process_one_point(video, query)
|
||||
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
|
||||
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
||||
else:
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
queries = torch.cat([queries, xy], dim=1) #
|
||||
|
||||
traj_e, vis_e, __ = self.model(
|
||||
video=video,
|
||||
queries=queries,
|
||||
iters=self.n_iters,
|
||||
)
|
||||
|
||||
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
|
||||
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
|
||||
return traj_e, vis_e
|
||||
|
||||
def _process_one_point(self, video, query):
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
device = query.device
|
||||
if self.local_grid_size > 0:
|
||||
xy_target = get_points_on_a_grid(
|
||||
self.local_grid_size,
|
||||
(50, 50),
|
||||
[query[0, 0, 2].item(), query[0, 0, 1].item()],
|
||||
)
|
||||
|
||||
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
|
||||
device
|
||||
) #
|
||||
query = torch.cat([query, xy_target], dim=1) #
|
||||
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
query = torch.cat([query, xy], dim=1) #
|
||||
# crop the video to start from the queried frame
|
||||
query[0, 0, 0] = 0
|
||||
traj_e_pind, vis_e_pind, __ = self.model(
|
||||
video=video[:, t:], queries=query, iters=self.n_iters
|
||||
)
|
||||
|
||||
return traj_e_pind, vis_e_pind
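# Usage sketch (checkpoint path, shapes and values are illustrative assumptions;
# build_cotracker comes from cotracker.models.build_cotracker, as used elsewhere
# in this repository):
#
# >>> from cotracker.models.build_cotracker import build_cotracker
# >>> model = build_cotracker("./checkpoints/cotracker2.pth")
# >>> predictor = EvaluationPredictor(model, grid_size=5, single_point=True)
# >>> video = torch.randn(1, 24, 3, 384, 512)          # (B, T, 3, H, W)
# >>> queries = torch.tensor([[[0.0, 100.0, 120.0]]])  # (B, N, 3) as (t, x, y)
# >>> traj, vis = predictor(video, queries)            # (B, T, N, 2), (B, T, N)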
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
from cotracker.models.core.model_utils import get_points_on_a_grid
|
||||
|
||||
|
||||
class EvaluationPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cotracker_model: CoTracker2,
|
||||
interp_shape: Tuple[int, int] = (384, 512),
|
||||
grid_size: int = 5,
|
||||
local_grid_size: int = 8,
|
||||
single_point: bool = True,
|
||||
n_iters: int = 6,
|
||||
) -> None:
|
||||
super(EvaluationPredictor, self).__init__()
|
||||
self.grid_size = grid_size
|
||||
self.local_grid_size = local_grid_size
|
||||
self.single_point = single_point
|
||||
self.interp_shape = interp_shape
|
||||
self.n_iters = n_iters
|
||||
|
||||
self.model = cotracker_model
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, video, queries):
|
||||
queries = queries.clone()
|
||||
B, T, C, H, W = video.shape
|
||||
B, N, D = queries.shape
|
||||
|
||||
assert D == 3
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
device = video.device
|
||||
|
||||
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
|
||||
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
|
||||
|
||||
if self.single_point:
|
||||
traj_e = torch.zeros((B, T, N, 2), device=device)
|
||||
vis_e = torch.zeros((B, T, N), device=device)
|
||||
for pind in range((N)):
|
||||
query = queries[:, pind : pind + 1]
|
||||
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
traj_e_pind, vis_e_pind = self._process_one_point(video, query)
|
||||
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
|
||||
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
||||
else:
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
queries = torch.cat([queries, xy], dim=1) #
|
||||
|
||||
traj_e, vis_e, __ = self.model(
|
||||
video=video,
|
||||
queries=queries,
|
||||
iters=self.n_iters,
|
||||
)
|
||||
|
||||
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
|
||||
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
|
||||
return traj_e, vis_e
|
||||
|
||||
def _process_one_point(self, video, query):
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
device = query.device
|
||||
if self.local_grid_size > 0:
|
||||
xy_target = get_points_on_a_grid(
|
||||
self.local_grid_size,
|
||||
(50, 50),
|
||||
[query[0, 0, 2].item(), query[0, 0, 1].item()],
|
||||
)
|
||||
|
||||
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
|
||||
device
|
||||
) #
|
||||
query = torch.cat([query, xy_target], dim=1) #
|
||||
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
query = torch.cat([query, xy], dim=1) #
|
||||
# crop the video to start from the queried frame
|
||||
query[0, 0, 0] = 0
|
||||
traj_e_pind, vis_e_pind, __ = self.model(
|
||||
video=video[:, t:], queries=query, iters=self.n_iters
|
||||
)
|
||||
|
||||
return traj_e_pind, vis_e_pind
|
||||
|
@@ -1,275 +1,279 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
|
||||
from cotracker.models.build_cotracker import build_cotracker
|
||||
|
||||
|
||||
class CoTrackerPredictor(torch.nn.Module):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
print(self.interp_shape)
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video, # (B, T, 3, H, W) Batch_size, time, rgb, height, width
|
||||
# input prompt types:
|
||||
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
|
||||
# *backward_tracking=True* will compute tracks in both directions.
|
||||
# - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
|
||||
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
|
||||
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
|
||||
queries: torch.Tensor = None,
|
||||
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
|
||||
grid_size: int = 0,
|
||||
grid_query_frame: int = 0, # only for dense and regular grid tracks
|
||||
backward_tracking: bool = False,
|
||||
):
|
||||
if queries is None and grid_size == 0:
|
||||
tracks, visibilities = self._compute_dense_tracks(
|
||||
video,
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
else:
|
||||
tracks, visibilities = self._compute_sparse_tracks(
|
||||
video,
|
||||
queries,
|
||||
segm_mask,
|
||||
grid_size,
|
||||
add_support_grid=(grid_size == 0 or segm_mask is not None),
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
|
||||
return tracks, visibilities
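    # Usage sketch for the three prompt modes documented above (the checkpoint
    # path matches the default argument; video shapes are illustrative):
    #
    # >>> predictor = CoTrackerPredictor("./checkpoints/cotracker2.pth")
    # >>> video = torch.randn(1, 48, 3, 480, 640)       # (B, T, 3, H, W)
    # >>> tracks, vis = predictor(video, grid_size=10)  # regular 10x10 grid
    # >>> q = torch.tensor([[[0.0, 300.0, 200.0]]])     # (B, N, 3) as (t, x, y)
    # >>> tracks, vis = predictor(video, queries=q)     # user-specified points
    # >>> tracks, vis = predictor(video)                # dense tracks (slow)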
|
||||
|
||||
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
|
||||
*_, H, W = video.shape
|
||||
grid_step = W // grid_size
|
||||
grid_width = W // grid_step
|
||||
        grid_height = H // grid_step  # cover the whole frame with cells of grid_step pixels (~grid_size cells per row)
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
|
||||
# (batch_size, grid_number, t,x,y)
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
# iterate every grid
|
||||
for offset in range(grid_step * grid_step):
|
||||
print(f"step {offset} / {grid_step * grid_step}")
|
||||
ox = offset % grid_step
|
||||
oy = offset // grid_step
|
||||
# initialize
|
||||
# for example
|
||||
# grid width = 4, grid height = 4, grid step = 10, ox = 1
|
||||
# torch.arange(grid_width) = [0,1,2,3]
|
||||
# torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
|
||||
# torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
|
||||
# get the location in the image
|
||||
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
||||
grid_pts[0, :, 2] = (
|
||||
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
|
||||
)
|
||||
tracks_step, visibilities_step = self._compute_sparse_tracks(
|
||||
video=video,
|
||||
queries=grid_pts,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
tracks = smart_cat(tracks, tracks_step, dim=2)
|
||||
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_sparse_tracks(
|
||||
self,
|
||||
video,
|
||||
queries,
|
||||
segm_mask=None,
|
||||
grid_size=0,
|
||||
add_support_grid=False,
|
||||
grid_query_frame=0,
|
||||
backward_tracking=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
        # resize the video to the model resolution (interp_shape) with bilinear interpolation
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape # batch_size, number of points, (t,x,y)
|
||||
assert D == 3
|
||||
            # rescale the (x, y) query coordinates by (interp_shape - 1) / (original size - 1)
            # so that they match the resized video
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
        # otherwise, build a regular grid of query points
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
|
||||
if segm_mask is not None:
|
||||
segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
|
||||
point_mask = segm_mask[0, 0][
|
||||
(grid_pts[0, :, 1]).round().long().cpu(),
|
||||
(grid_pts[0, :, 0]).round().long().cpu(),
|
||||
].bool()
|
||||
grid_pts = grid_pts[:, point_mask]
|
||||
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
).repeat(B, 1, 1)
|
||||
|
||||
        # optionally add a coarse grid of support points
|
||||
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
grid_pts = grid_pts.repeat(B, 1, 1)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
|
||||
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
|
||||
|
||||
if backward_tracking:
|
||||
tracks, visibilities = self._compute_backward_tracks(
|
||||
video, queries, tracks, visibilities
|
||||
)
|
||||
if add_support_grid:
|
||||
queries[:, -self.support_grid_size**2 :, 0] = T - 1
|
||||
if add_support_grid:
|
||||
tracks = tracks[:, :, : -self.support_grid_size**2]
|
||||
visibilities = visibilities[:, :, : -self.support_grid_size**2]
|
||||
thr = 0.9
|
||||
visibilities = visibilities > thr
|
||||
|
||||
# correct query-point predictions
|
||||
# see https://github.com/facebookresearch/co-tracker/issues/28
|
||||
|
||||
# TODO: batchify
|
||||
for i in range(len(queries)):
|
||||
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
|
||||
arange = torch.arange(0, len(queries_t))
|
||||
|
||||
# overwrite the predictions with the query points
|
||||
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
|
||||
|
||||
# correct visibilities, the query points should be visible
|
||||
visibilities[i, queries_t, arange] = True
|
||||
|
||||
tracks *= tracks.new_tensor(
|
||||
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
|
||||
)
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
|
||||
inv_video = video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
|
||||
|
||||
inv_tracks = inv_tracks.flip(1)
|
||||
inv_visibilities = inv_visibilities.flip(1)
|
||||
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
|
||||
|
||||
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
tracks[mask] = inv_tracks[mask]
|
||||
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
||||
return tracks, visibilities
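    # Note on the backward pass above: the clip is reversed in time, query frames
    # are remapped as t' = T - 1 - t, the reversed predictions are flipped back,
    # and they overwrite the forward predictions only for timesteps strictly
    # before each query frame (arange < query time).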
|
||||
|
||||
|
||||
class CoTrackerOnlinePredictor(torch.nn.Module):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
self.step = model.window_len // 2
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video_chunk,
|
||||
is_first_step: bool = False,
|
||||
queries: torch.Tensor = None,
|
||||
grid_size: int = 10,
|
||||
grid_query_frame: int = 0,
|
||||
add_support_grid=False,
|
||||
):
|
||||
B, T, C, H, W = video_chunk.shape
|
||||
# Initialize online video processing and save queried points
|
||||
# This needs to be done before processing *each new video*
|
||||
if is_first_step:
|
||||
self.model.init_video_online_processing()
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape
|
||||
assert D == 3
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
)
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
self.queries = queries
|
||||
return (None, None)
|
||||
|
||||
video_chunk = video_chunk.reshape(B * T, C, H, W)
|
||||
video_chunk = F.interpolate(
|
||||
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
|
||||
)
|
||||
video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
tracks, visibilities, __ = self.model(
|
||||
video=video_chunk,
|
||||
queries=self.queries,
|
||||
iters=6,
|
||||
is_online=True,
|
||||
)
|
||||
thr = 0.9
|
||||
return (
|
||||
tracks
|
||||
* tracks.new_tensor(
|
||||
[
|
||||
(W - 1) / (self.interp_shape[1] - 1),
|
||||
(H - 1) / (self.interp_shape[0] - 1),
|
||||
]
|
||||
),
|
||||
visibilities > thr,
|
||||
)
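    # Usage sketch for online (sliding-window) tracking; the chunking pattern is
    # an assumption for illustration, with window length 2 * self.step:
    #
    # >>> online = CoTrackerOnlinePredictor("./checkpoints/cotracker2.pth")
    # >>> online(video_chunk=video, is_first_step=True, grid_size=10)
    # >>> for ind in range(0, video.shape[1] - online.step, online.step):
    # ...     tracks, vis = online(video_chunk=video[:, ind : ind + online.step * 2])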
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
|
||||
from cotracker.models.build_cotracker import build_cotracker
|
||||
|
||||
|
||||
class CoTrackerPredictor(torch.nn.Module):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
print(self.interp_shape)
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video, # (B, T, 3, H, W) Batch_size, time, rgb, height, width
|
||||
# input prompt types:
|
||||
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
|
||||
# *backward_tracking=True* will compute tracks in both directions.
|
||||
# - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
|
||||
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
|
||||
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
|
||||
queries: torch.Tensor = None,
|
||||
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
|
||||
grid_size: int = 0,
|
||||
grid_query_frame: int = 0, # only for dense and regular grid tracks
|
||||
backward_tracking: bool = False,
|
||||
):
|
||||
if queries is None and grid_size == 0:
|
||||
tracks, visibilities = self._compute_dense_tracks(
|
||||
video,
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
else:
|
||||
tracks, visibilities = self._compute_sparse_tracks(
|
||||
video,
|
||||
queries,
|
||||
segm_mask,
|
||||
grid_size,
|
||||
add_support_grid=(grid_size == 0 or segm_mask is not None),
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
    # TODO notes: measure dense inference time on GPU, compare against RAFT on
    # GPU, explore visual-effects use cases, and integrate RAFT.
|
||||
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
|
||||
*_, H, W = video.shape
|
||||
grid_step = W // grid_size
|
||||
grid_width = W // grid_step
|
||||
        grid_height = H // grid_step  # cover the whole frame with cells of grid_step pixels (~grid_size cells per row)
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
|
||||
# (batch_size, grid_number, t,x,y)
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
# iterate every grid
|
||||
for offset in range(grid_step * grid_step):
|
||||
print(f"step {offset} / {grid_step * grid_step}")
|
||||
ox = offset % grid_step
|
||||
oy = offset // grid_step
|
||||
# initialize
|
||||
# for example
|
||||
# grid width = 4, grid height = 4, grid step = 10, ox = 1
|
||||
# torch.arange(grid_width) = [0,1,2,3]
|
||||
# torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
|
||||
# torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
|
||||
# get the location in the image
|
||||
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
||||
grid_pts[0, :, 2] = (
|
||||
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
|
||||
)
|
||||
tracks_step, visibilities_step = self._compute_sparse_tracks(
|
||||
video=video,
|
||||
queries=grid_pts,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
tracks = smart_cat(tracks, tracks_step, dim=2)
|
||||
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_sparse_tracks(
|
||||
self,
|
||||
video,
|
||||
queries,
|
||||
segm_mask=None,
|
||||
grid_size=0,
|
||||
add_support_grid=False,
|
||||
grid_query_frame=0,
|
||||
backward_tracking=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
        # resize the video to the model resolution (interp_shape) with bilinear interpolation
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape # batch_size, number of points, (t,x,y)
|
||||
assert D == 3
|
||||
            # rescale the (x, y) query coordinates by (interp_shape - 1) / (original size - 1)
            # so that they match the resized video
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
        # otherwise, build a regular grid of query points
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
|
||||
if segm_mask is not None:
|
||||
segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
|
||||
point_mask = segm_mask[0, 0][
|
||||
(grid_pts[0, :, 1]).round().long().cpu(),
|
||||
(grid_pts[0, :, 0]).round().long().cpu(),
|
||||
].bool()
|
||||
grid_pts = grid_pts[:, point_mask]
|
||||
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
).repeat(B, 1, 1)
|
||||
|
||||
        # optionally add a coarse grid of support points
|
||||
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
grid_pts = grid_pts.repeat(B, 1, 1)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
|
||||
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
|
||||
|
||||
if backward_tracking:
|
||||
tracks, visibilities = self._compute_backward_tracks(
|
||||
video, queries, tracks, visibilities
|
||||
)
|
||||
if add_support_grid:
|
||||
queries[:, -self.support_grid_size**2 :, 0] = T - 1
|
||||
if add_support_grid:
|
||||
tracks = tracks[:, :, : -self.support_grid_size**2]
|
||||
visibilities = visibilities[:, :, : -self.support_grid_size**2]
|
||||
thr = 0.9
|
||||
visibilities = visibilities > thr
|
||||
|
||||
# correct query-point predictions
|
||||
# see https://github.com/facebookresearch/co-tracker/issues/28
|
||||
|
||||
# TODO: batchify
|
||||
for i in range(len(queries)):
|
||||
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
|
||||
arange = torch.arange(0, len(queries_t))
|
||||
|
||||
# overwrite the predictions with the query points
|
||||
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
|
||||
|
||||
# correct visibilities, the query points should be visible
|
||||
visibilities[i, queries_t, arange] = True
|
||||
|
||||
tracks *= tracks.new_tensor(
|
||||
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
|
||||
)
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
|
||||
inv_video = video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
|
||||
|
||||
inv_tracks = inv_tracks.flip(1)
|
||||
inv_visibilities = inv_visibilities.flip(1)
|
||||
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
|
||||
|
||||
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
tracks[mask] = inv_tracks[mask]
|
||||
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
||||
return tracks, visibilities
|
||||
|
||||
|
||||
class CoTrackerOnlinePredictor(torch.nn.Module):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
self.step = model.window_len // 2
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video_chunk,
|
||||
is_first_step: bool = False,
|
||||
queries: torch.Tensor = None,
|
||||
grid_size: int = 10,
|
||||
grid_query_frame: int = 0,
|
||||
add_support_grid=False,
|
||||
):
|
||||
B, T, C, H, W = video_chunk.shape
|
||||
# Initialize online video processing and save queried points
|
||||
# This needs to be done before processing *each new video*
|
||||
if is_first_step:
|
||||
self.model.init_video_online_processing()
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape
|
||||
assert D == 3
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
)
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
self.queries = queries
|
||||
return (None, None)
|
||||
|
||||
video_chunk = video_chunk.reshape(B * T, C, H, W)
|
||||
video_chunk = F.interpolate(
|
||||
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
|
||||
)
|
||||
video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
tracks, visibilities, __ = self.model(
|
||||
video=video_chunk,
|
||||
queries=self.queries,
|
||||
iters=6,
|
||||
is_online=True,
|
||||
)
|
||||
thr = 0.9
|
||||
return (
|
||||
tracks
|
||||
* tracks.new_tensor(
|
||||
[
|
||||
(W - 1) / (self.interp_shape[1] - 1),
|
||||
(H - 1) / (self.interp_shape[0] - 1),
|
||||
]
|
||||
),
|
||||
visibilities > thr,
|
||||
)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
BIN
cotracker/utils/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/utils/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/utils/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/utils/__pycache__/visualizer.cpython-38.pyc
Normal file
BIN
cotracker/utils/__pycache__/visualizer.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/utils/__pycache__/visualizer.cpython-39.pyc
Normal file
BIN
cotracker/utils/__pycache__/visualizer.cpython-39.pyc
Normal file
Binary file not shown.
@@ -1,343 +1,343 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import os
|
||||
import numpy as np
|
||||
import imageio
|
||||
import torch
|
||||
|
||||
from matplotlib import cm
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
def read_video_from_path(path):
|
||||
try:
|
||||
reader = imageio.get_reader(path)
|
||||
except Exception as e:
|
||||
print("Error opening video file: ", e)
|
||||
return None
|
||||
frames = []
|
||||
for i, im in enumerate(reader):
|
||||
frames.append(np.array(im))
|
||||
return np.stack(frames)
|
||||
|
||||
|
||||
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
|
||||
# Create a draw object
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
# Calculate the bounding box of the circle
|
||||
left_up_point = (coord[0] - radius, coord[1] - radius)
|
||||
right_down_point = (coord[0] + radius, coord[1] + radius)
|
||||
# Draw the circle
|
||||
draw.ellipse(
|
||||
[left_up_point, right_down_point],
|
||||
fill=tuple(color) if visible else None,
|
||||
outline=tuple(color),
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def draw_line(rgb, coord_y, coord_x, color, linewidth):
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
draw.line(
|
||||
(coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
|
||||
fill=tuple(color),
|
||||
width=linewidth,
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def add_weighted(rgb, alpha, original, beta, gamma):
|
||||
return (rgb * alpha + original * beta + gamma).astype("uint8")
|
||||
|
||||
|
||||
class Visualizer:
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str = "./results",
|
||||
grayscale: bool = False,
|
||||
pad_value: int = 0,
|
||||
fps: int = 10,
|
||||
mode: str = "rainbow", # 'cool', 'optical_flow'
|
||||
linewidth: int = 2,
|
||||
show_first_frame: int = 10,
|
||||
tracks_leave_trace: int = 0, # -1 for infinite
|
||||
):
|
||||
self.mode = mode
|
||||
self.save_dir = save_dir
|
||||
if mode == "rainbow":
|
||||
self.color_map = cm.get_cmap("gist_rainbow")
|
||||
elif mode == "cool":
|
||||
self.color_map = cm.get_cmap(mode)
|
||||
self.show_first_frame = show_first_frame
|
||||
self.grayscale = grayscale
|
||||
self.tracks_leave_trace = tracks_leave_trace
|
||||
self.pad_value = pad_value
|
||||
self.linewidth = linewidth
|
||||
self.fps = fps
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
video: torch.Tensor, # (B,T,C,H,W)
|
||||
tracks: torch.Tensor, # (B,T,N,2)
|
||||
visibility: torch.Tensor = None, # (B, T, N, 1) bool
|
||||
gt_tracks: torch.Tensor = None, # (B,T,N,2)
|
||||
segm_mask: torch.Tensor = None, # (B,1,H,W)
|
||||
filename: str = "video",
|
||||
writer=None, # tensorboard Summary Writer, used for visualization during training
|
||||
step: int = 0,
|
||||
query_frame: int = 0,
|
||||
save_video: bool = True,
|
||||
compensate_for_camera_motion: bool = False,
|
||||
):
|
||||
if compensate_for_camera_motion:
|
||||
assert segm_mask is not None
|
||||
if segm_mask is not None:
|
||||
coords = tracks[0, query_frame].round().long()
|
||||
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
|
||||
|
||||
video = F.pad(
|
||||
video,
|
||||
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
|
||||
"constant",
|
||||
255,
|
||||
)
|
||||
tracks = tracks + self.pad_value
|
||||
|
||||
if self.grayscale:
|
||||
transform = transforms.Grayscale()
|
||||
video = transform(video)
|
||||
video = video.repeat(1, 1, 3, 1, 1)
|
||||
|
||||
res_video = self.draw_tracks_on_video(
|
||||
video=video,
|
||||
tracks=tracks,
|
||||
visibility=visibility,
|
||||
segm_mask=segm_mask,
|
||||
gt_tracks=gt_tracks,
|
||||
query_frame=query_frame,
|
||||
compensate_for_camera_motion=compensate_for_camera_motion,
|
||||
)
|
||||
if save_video:
|
||||
self.save_video(res_video, filename=filename, writer=writer, step=step)
|
||||
return res_video
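    # Usage sketch (argument values are illustrative; tracks and visibility are
    # assumed to come from a CoTrackerPredictor call):
    #
    # >>> viz = Visualizer(save_dir="./results", pad_value=120, linewidth=3)
    # >>> viz.visualize(video, tracks, visibility, filename="demo")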
|
||||
|
||||
def save_video(self, video, filename, writer=None, step=0):
|
||||
if writer is not None:
|
||||
writer.add_video(
|
||||
filename,
|
||||
video.to(torch.uint8),
|
||||
global_step=step,
|
||||
fps=self.fps,
|
||||
)
|
||||
else:
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
wide_list = list(video.unbind(1))
|
||||
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
|
||||
|
||||
# Prepare the video file path
|
||||
save_path = os.path.join(self.save_dir, f"{filename}.mp4")
|
||||
|
||||
# Create a writer object
|
||||
video_writer = imageio.get_writer(save_path, fps=self.fps)
|
||||
|
||||
# Write frames to the video file
|
||||
for frame in wide_list[2:-1]:
|
||||
video_writer.append_data(frame)
|
||||
|
||||
video_writer.close()
|
||||
|
||||
print(f"Video saved to {save_path}")
|
||||
|
||||
def draw_tracks_on_video(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
tracks: torch.Tensor,
|
||||
visibility: torch.Tensor = None,
|
||||
segm_mask: torch.Tensor = None,
|
||||
gt_tracks=None,
|
||||
query_frame: int = 0,
|
||||
compensate_for_camera_motion=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
_, _, N, D = tracks.shape
|
||||
|
||||
assert D == 2
|
||||
assert C == 3
|
||||
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
|
||||
tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
|
||||
if gt_tracks is not None:
|
||||
gt_tracks = gt_tracks[0].detach().cpu().numpy()
|
||||
|
||||
res_video = []
|
||||
|
||||
# process input video
|
||||
for rgb in video:
|
||||
res_video.append(rgb.copy())
|
||||
vector_colors = np.zeros((T, N, 3))
|
||||
|
||||
if self.mode == "optical_flow":
|
||||
import flow_vis
|
||||
|
||||
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
|
||||
elif segm_mask is None:
|
||||
if self.mode == "rainbow":
|
||||
y_min, y_max = (
|
||||
tracks[query_frame, :, 1].min(),
|
||||
tracks[query_frame, :, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
color = self.color_map(norm(tracks[query_frame, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
else:
|
||||
# color changes with time
|
||||
for t in range(T):
|
||||
color = np.array(self.color_map(t / T)[:3])[None] * 255
|
||||
vector_colors[t] = np.repeat(color, N, axis=0)
|
||||
else:
|
||||
if self.mode == "rainbow":
|
||||
vector_colors[:, segm_mask <= 0, :] = 255
|
||||
|
||||
y_min, y_max = (
|
||||
tracks[0, segm_mask > 0, 1].min(),
|
||||
tracks[0, segm_mask > 0, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
if segm_mask[n] > 0:
|
||||
color = self.color_map(norm(tracks[0, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
|
||||
else:
|
||||
# color changes with segm class
|
||||
segm_mask = segm_mask.cpu()
|
||||
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
|
||||
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
|
||||
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
|
||||
vector_colors = np.repeat(color[None], T, axis=0)
|
||||
|
||||
# draw tracks
|
||||
if self.tracks_leave_trace != 0:
|
||||
for t in range(query_frame + 1, T):
|
||||
first_ind = (
|
||||
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
|
||||
)
|
||||
curr_tracks = tracks[first_ind : t + 1]
|
||||
curr_colors = vector_colors[first_ind : t + 1]
|
||||
if compensate_for_camera_motion:
|
||||
diff = (
|
||||
tracks[first_ind : t + 1, segm_mask <= 0]
|
||||
- tracks[t : t + 1, segm_mask <= 0]
|
||||
).mean(1)[:, None]
|
||||
|
||||
curr_tracks = curr_tracks - diff
|
||||
curr_tracks = curr_tracks[:, segm_mask > 0]
|
||||
curr_colors = curr_colors[:, segm_mask > 0]
|
||||
|
||||
res_video[t] = self._draw_pred_tracks(
|
||||
res_video[t],
|
||||
curr_tracks,
|
||||
curr_colors,
|
||||
)
|
||||
if gt_tracks is not None:
|
||||
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
|
||||
|
||||
# draw points
|
||||
for t in range(query_frame, T):
|
||||
img = Image.fromarray(np.uint8(res_video[t]))
|
||||
for i in range(N):
|
||||
coord = (tracks[t, i, 0], tracks[t, i, 1])
|
||||
                visible = True
                if visibility is not None:
                    visible = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        img = draw_circle(
                            img,
                            coord=coord,
                            radius=int(self.linewidth * 2),
                            color=vector_colors[t, i].astype(int),
                            visible=visible,
|
||||
)
|
||||
res_video[t] = np.array(img)
|
||||
|
||||
# construct the final rgb sequence
|
||||
if self.show_first_frame > 0:
|
||||
res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
|
||||
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
|
||||
|
||||
def _draw_pred_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3
|
||||
tracks: np.ndarray, # T x 2
|
||||
vector_colors: np.ndarray,
|
||||
alpha: float = 0.5,
|
||||
):
|
||||
T, N, _ = tracks.shape
|
||||
rgb = Image.fromarray(np.uint8(rgb))
|
||||
for s in range(T - 1):
|
||||
vector_color = vector_colors[s]
|
||||
original = rgb.copy()
|
||||
alpha = (s / T) ** 2
|
||||
for i in range(N):
|
||||
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
|
||||
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
|
||||
if coord_y[0] != 0 and coord_y[1] != 0:
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
vector_color[i].astype(int),
|
||||
self.linewidth,
|
||||
)
|
||||
if self.tracks_leave_trace > 0:
|
||||
rgb = Image.fromarray(
|
||||
np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
|
||||
)
|
||||
rgb = np.array(rgb)
|
||||
return rgb
|
||||
|
||||
def _draw_gt_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3,
|
||||
gt_tracks: np.ndarray, # T x 2
|
||||
):
|
||||
T, N, _ = gt_tracks.shape
|
||||
color = np.array((211, 0, 0))
|
||||
rgb = Image.fromarray(np.uint8(rgb))
|
||||
for t in range(T):
|
||||
for i in range(N):
|
||||
                gt_track = gt_tracks[t][i]
                # draw a red cross at each valid ground-truth location
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
)
|
||||
rgb = np.array(rgb)
|
||||
return rgb
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import os
|
||||
import numpy as np
|
||||
import imageio
|
||||
import torch
|
||||
|
||||
from matplotlib import cm
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
def read_video_from_path(path):
|
||||
try:
|
||||
reader = imageio.get_reader(path)
|
||||
except Exception as e:
|
||||
print("Error opening video file: ", e)
|
||||
return None
|
||||
frames = []
|
||||
for i, im in enumerate(reader):
|
||||
frames.append(np.array(im))
|
||||
return np.stack(frames)
|
||||
|
||||
|
||||
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
|
||||
# Create a draw object
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
# Calculate the bounding box of the circle
|
||||
left_up_point = (coord[0] - radius, coord[1] - radius)
|
||||
right_down_point = (coord[0] + radius, coord[1] + radius)
|
||||
# Draw the circle
|
||||
draw.ellipse(
|
||||
[left_up_point, right_down_point],
|
||||
fill=tuple(color) if visible else None,
|
||||
outline=tuple(color),
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def draw_line(rgb, coord_y, coord_x, color, linewidth):
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
draw.line(
|
||||
(coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
|
||||
fill=tuple(color),
|
||||
width=linewidth,
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def add_weighted(rgb, alpha, original, beta, gamma):
|
||||
return (rgb * alpha + original * beta + gamma).astype("uint8")
|
||||
|
||||
|
||||
class Visualizer:
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str = "./results",
|
||||
grayscale: bool = False,
|
||||
pad_value: int = 0,
|
||||
fps: int = 10,
|
||||
mode: str = "rainbow", # 'cool', 'optical_flow'
|
||||
linewidth: int = 2,
|
||||
show_first_frame: int = 10,
|
||||
tracks_leave_trace: int = 0, # -1 for infinite
|
||||
):
|
||||
self.mode = mode
|
||||
self.save_dir = save_dir
|
||||
if mode == "rainbow":
|
||||
self.color_map = cm.get_cmap("gist_rainbow")
|
||||
elif mode == "cool":
|
||||
self.color_map = cm.get_cmap(mode)
|
||||
self.show_first_frame = show_first_frame
|
||||
self.grayscale = grayscale
|
||||
self.tracks_leave_trace = tracks_leave_trace
|
||||
self.pad_value = pad_value
|
||||
self.linewidth = linewidth
|
||||
self.fps = fps
|
||||
|
||||
    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
    ):
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
        return res_video

    def save_video(self, video, filename, writer=None, step=0):
        if writer is not None:
            writer.add_video(
                filename,
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            wide_list = list(video.unbind(1))
            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]

            # Prepare the video file path
            save_path = os.path.join(self.save_dir, f"{filename}.mp4")

            # Create a writer object
            video_writer = imageio.get_writer(save_path, fps=self.fps)

            # Write frames to the video file
            for frame in wide_list[2:-1]:
                video_writer.append_data(frame)

            video_writer.close()

            print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
    ):
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        res_video = []

        # process input video
        for rgb in video:
            res_video.append(rgb.copy())
        vector_colors = np.zeros((T, N, 3))

        if self.mode == "optical_flow":
            import flow_vis

            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    color = self.color_map(norm(tracks[query_frame, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)

            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(query_frame + 1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])

        # draw points
        for t in range(query_frame, T):
            img = Image.fromarray(np.uint8(res_video[t]))
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                visible = True
                if visibility is not None:
                    visible = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        img = draw_circle(
                            img,
                            coord=coord,
                            radius=int(self.linewidth * 2),
                            color=vector_colors[t, i].astype(int),
                            visible=visible,
                        )
            res_video[t] = np.array(img)

        # construct the final rgb sequence
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x N x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        T, N, _ = tracks.shape
        rgb = Image.fromarray(np.uint8(rgb))
        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if coord_y[0] != 0 and coord_y[1] != 0:
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].astype(int),
                        self.linewidth,
                    )
            if self.tracks_leave_trace > 0:
                rgb = Image.fromarray(
                    np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
                )
        rgb = np.array(rgb)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        gt_tracks: np.ndarray,  # T x N x 2
    ):
        T, N, _ = gt_tracks.shape
        color = np.array((211, 0, 0))
        rgb = Image.fromarray(np.uint8(rgb))
        for t in range(T):
            for i in range(N):
                # use a local copy so the gt_tracks array itself is not overwritten
                gt_track = gt_tracks[t][i]
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
        rgb = np.array(rgb)
        return rgb

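For reference, a minimal usage sketch of the class above. It assumes this file is importable as cotracker.utils.visualizer and uses an assumed video path plus random placeholder tracks in place of a real tracking model's output; it is illustrative only, not part of the repository.

# Minimal usage sketch (illustrative; paths and placeholder tracks are assumptions).
import torch
from cotracker.utils.visualizer import Visualizer, read_video_from_path

video = read_video_from_path("./assets/example.mp4")  # assumed path, (T, H, W, 3) uint8
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()  # (B, T, C, H, W)

# `tracks` and `visibility` would normally come from a tracking model;
# random placeholders are used here just to exercise the drawing code.
B, T, N = 1, video.shape[1], 16
tracks = torch.rand(B, T, N, 2) * torch.tensor([video.shape[4], video.shape[3]])
visibility = torch.ones(B, T, N, 1, dtype=torch.bool)

vis = Visualizer(save_dir="./videos", linewidth=2, mode="rainbow")
vis.visualize(video=video, tracks=tracks, visibility=visibility, filename="demo")
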
@@ -1,8 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


__version__ = "2.0.0"