Compare commits
19 Commits
9c9a97d158
...
main
Author | SHA1 | Date | |
---|---|---|---|
be0891967b | |||
40e628ac73 | |||
f208a962b9 | |||
15cdb3027c | |||
6e7bcd2d26 | |||
|
36d1566750 | ||
|
9ed8669a50 | ||
|
eeda4d3c98 | ||
|
9ed05317b7 | ||
|
19767a9d65 | ||
|
e29e938311 | ||
|
0f9d32869a | ||
|
9460eefecc | ||
|
9921cf0895 | ||
|
941c24fd40 | ||
|
fac27989b3 | ||
|
f084a93f28 | ||
|
3716e36249 | ||
|
721fcc237b |
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
__pycache__/
|
||||
.vscode/
|
||||
cotracker/__pycache__/
|
21
README.md
21
README.md
@@ -4,7 +4,7 @@
|
||||
|
||||
[Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
|
||||
|
||||
[[`Paper`](https://arxiv.org/abs/2307.07635)] [[`Project`](https://co-tracker.github.io/)] [[`BibTeX`](#citing-cotracker)]
|
||||
### [Project Page](https://co-tracker.github.io/) | [Paper](https://arxiv.org/abs/2307.07635) | [X Thread](https://twitter.com/n_karaev/status/1742638906355470772) | [BibTeX](#citing-cotracker)
|
||||
|
||||
<a target="_blank" href="https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb">
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
@@ -26,6 +26,7 @@ CoTracker can track:
|
||||
Try these tracking modes for yourself with our [Colab demo](https://colab.research.google.com/github/facebookresearch/co-tracker/blob/master/notebooks/demo.ipynb) or in the [Hugging Face Space 🤗](https://huggingface.co/spaces/facebook/cotracker).
|
||||
|
||||
**Updates:**
|
||||
- [June 14, 2024] 📣 We have released the code for [VGGSfM](https://github.com/facebookresearch/vggsfm), a model for recovering camera poses and 3D structure from any image sequences based on point tracking! VGGSfM is the first fully differentiable SfM framework that unlocks scalability and outperforms conventional SfM methods on standard benchmarks.
|
||||
|
||||
- [December 27, 2023] 📣 CoTracker2 is now available! It can now track many more (up to **265*265**!) points jointly and it has a cleaner and more memory-efficient implementation. It also supports online processing. See the [updated paper](https://arxiv.org/abs/2307.07635) for more details. The old version remains available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
|
||||
|
||||
@@ -39,7 +40,7 @@ The easiest way to use CoTracker is to load a pretrained model from `torch.hub`:
|
||||
```python
|
||||
import torch
|
||||
# Download the video
|
||||
url = 'https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4'
|
||||
url = 'https://github.com/facebookresearch/co-tracker/raw/main/assets/apple.mp4'
|
||||
|
||||
import imageio.v3 as iio
|
||||
frames = iio.imread(url, plugin="FFMPEG") # plugin="pyav"
|
||||
@@ -119,7 +120,7 @@ We strongly recommend installing both PyTorch and TorchVision with CUDA support,
|
||||
git clone https://github.com/facebookresearch/co-tracker
|
||||
cd co-tracker
|
||||
pip install -e .
|
||||
pip install matplotlib flow_vis tqdm tensorboard
|
||||
pip install matplotlib flow_vis tqdm tensorboard imageio[ffmpeg]
|
||||
```
|
||||
|
||||
You can manually download the CoTracker2 checkpoint from the links below and place it in the `checkpoints` folder as follows:
|
||||
@@ -132,6 +133,11 @@ cd ..
|
||||
```
|
||||
For old checkpoints, see [this section](#previous-version).
|
||||
|
||||
After installation, this is how you could run the model on `./assets/apple.mp4` (results will be saved to `./saved_videos/apple.mp4`):
|
||||
```bash
|
||||
python demo.py --checkpoint checkpoints/cotracker2.pth
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
To reproduce the results presented in the paper, download the following datasets:
|
||||
@@ -203,6 +209,15 @@ make -C docs html
|
||||
|
||||
|
||||
## Previous version
|
||||
You can use CoTracker v1 directly via pytorch hub:
|
||||
```python
|
||||
import torch
|
||||
import einops
|
||||
import timm
|
||||
import tqdm
|
||||
|
||||
cotracker = torch.hub.load("facebookresearch/co-tracker:v1.0", "cotracker_w8")
|
||||
```
|
||||
The old version of the code is available [here](https://github.com/facebookresearch/co-tracker/tree/8d364031971f6b3efec945dd15c468a183e58212).
|
||||
You can also download the corresponding checkpoints:
|
||||
```bash
|
||||
|
BIN
assets/F1_shorts.mp4
Normal file
BIN
assets/F1_shorts.mp4
Normal file
Binary file not shown.
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
BIN
cotracker/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/__pycache__/predictor.cpython-38.pyc
Normal file
BIN
cotracker/__pycache__/predictor.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/__pycache__/predictor.cpython-39.pyc
Normal file
BIN
cotracker/__pycache__/predictor.cpython-39.pyc
Normal file
Binary file not shown.
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
@@ -1,166 +1,166 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import json
|
||||
import dataclasses
|
||||
import numpy as np
|
||||
from dataclasses import Field, MISSING
|
||||
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
|
||||
|
||||
_X = TypeVar("_X")
|
||||
|
||||
|
||||
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
|
||||
"""
|
||||
Loads to a @dataclass or collection hierarchy including dataclasses
|
||||
from a json recursively.
|
||||
Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
|
||||
raises KeyError if json has keys not mapping to the dataclass fields.
|
||||
|
||||
Args:
|
||||
f: Either a path to a file, or a file opened for writing.
|
||||
cls: The class of the loaded dataclass.
|
||||
binary: Set to True if `f` is a file handle, else False.
|
||||
"""
|
||||
if binary:
|
||||
asdict = json.loads(f.read().decode("utf8"))
|
||||
else:
|
||||
asdict = json.load(f)
|
||||
|
||||
# in the list case, run a faster "vectorized" version
|
||||
cls = get_args(cls)[0]
|
||||
res = list(_dataclass_list_from_dict_list(asdict, cls))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
||||
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
||||
if get_origin(type_) is Union:
|
||||
args = get_args(type_)
|
||||
if len(args) == 2 and args[1] == type(None): # noqa E721
|
||||
return True, args[0]
|
||||
if type_ is Any:
|
||||
return True, Any
|
||||
|
||||
return False, type_
|
||||
|
||||
|
||||
def _unwrap_type(tp):
|
||||
# strips Optional wrapper, if any
|
||||
if get_origin(tp) is Union:
|
||||
args = get_args(tp)
|
||||
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
|
||||
# this is typing.Optional
|
||||
return args[0] if args[1] is type(None) else args[1] # noqa: E721
|
||||
return tp
|
||||
|
||||
|
||||
def _get_dataclass_field_default(field: Field) -> Any:
|
||||
if field.default_factory is not MISSING:
|
||||
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
|
||||
# dataclasses._DefaultFactory[typing.Any]]` is not a function.
|
||||
return field.default_factory()
|
||||
elif field.default is not MISSING:
|
||||
return field.default
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
"""
|
||||
Vectorised version of `_dataclass_from_dict`.
|
||||
The output should be equivalent to
|
||||
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
|
||||
|
||||
Args:
|
||||
dlist: list of objects to convert.
|
||||
typeannot: type of each of those objects.
|
||||
Returns:
|
||||
iterator or list over converted objects of the same length as `dlist`.
|
||||
|
||||
Raises:
|
||||
ValueError: it assumes the objects have None's in consistent places across
|
||||
objects, otherwise it would ignore some values. This generally holds for
|
||||
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
|
||||
"""
|
||||
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
|
||||
if typeannot is Any:
|
||||
return dlist
|
||||
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
||||
return dlist
|
||||
if any(obj is None for obj in dlist):
|
||||
# filter out Nones and recurse on the resulting list
|
||||
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
||||
idx, notnone = zip(*idx_notnone)
|
||||
converted = _dataclass_list_from_dict_list(notnone, typeannot)
|
||||
res = [None] * len(dlist)
|
||||
for i, obj in zip(idx, converted):
|
||||
res[i] = obj
|
||||
return res
|
||||
|
||||
is_optional, contained_type = _resolve_optional(typeannot)
|
||||
if is_optional:
|
||||
return _dataclass_list_from_dict_list(dlist, contained_type)
|
||||
|
||||
# otherwise, we dispatch by the type of the provided annotation to convert to
|
||||
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
# For namedtuple, call the function recursively on the lists of corresponding keys
|
||||
types = cls.__annotations__.values()
|
||||
dlist_T = zip(*dlist)
|
||||
res_T = [
|
||||
_dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
|
||||
]
|
||||
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
|
||||
elif issubclass(cls, (list, tuple)):
|
||||
# For list/tuple, call the function recursively on the lists of corresponding positions
|
||||
types = get_args(typeannot)
|
||||
if len(types) == 1: # probably List; replicate for all items
|
||||
types = types * len(dlist[0])
|
||||
dlist_T = zip(*dlist)
|
||||
res_T = (
|
||||
_dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
|
||||
)
|
||||
if issubclass(cls, tuple):
|
||||
return list(zip(*res_T))
|
||||
else:
|
||||
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
|
||||
elif issubclass(cls, dict):
|
||||
# For the dictionary, call the function recursively on concatenated keys and vertices
|
||||
key_t, val_t = get_args(typeannot)
|
||||
all_keys_res = _dataclass_list_from_dict_list(
|
||||
[k for obj in dlist for k in obj.keys()], key_t
|
||||
)
|
||||
all_vals_res = _dataclass_list_from_dict_list(
|
||||
[k for obj in dlist for k in obj.values()], val_t
|
||||
)
|
||||
indices = np.cumsum([len(obj) for obj in dlist])
|
||||
assert indices[-1] == len(all_keys_res)
|
||||
|
||||
keys = np.split(list(all_keys_res), indices[:-1])
|
||||
all_vals_res_iter = iter(all_vals_res)
|
||||
return [cls(zip(k, all_vals_res_iter)) for k in keys]
|
||||
elif not dataclasses.is_dataclass(typeannot):
|
||||
return dlist
|
||||
|
||||
# dataclass node: 2nd recursion base; call the function recursively on the lists
|
||||
# of the corresponding fields
|
||||
assert dataclasses.is_dataclass(cls)
|
||||
fieldtypes = {
|
||||
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
|
||||
for f in dataclasses.fields(typeannot)
|
||||
}
|
||||
|
||||
# NOTE the default object is shared here
|
||||
key_lists = (
|
||||
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
|
||||
for k, (type_, default) in fieldtypes.items()
|
||||
)
|
||||
transposed = zip(*key_lists)
|
||||
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import json
|
||||
import dataclasses
|
||||
import numpy as np
|
||||
from dataclasses import Field, MISSING
|
||||
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
|
||||
|
||||
_X = TypeVar("_X")
|
||||
|
||||
|
||||
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
|
||||
"""
|
||||
Loads to a @dataclass or collection hierarchy including dataclasses
|
||||
from a json recursively.
|
||||
Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
|
||||
raises KeyError if json has keys not mapping to the dataclass fields.
|
||||
|
||||
Args:
|
||||
f: Either a path to a file, or a file opened for writing.
|
||||
cls: The class of the loaded dataclass.
|
||||
binary: Set to True if `f` is a file handle, else False.
|
||||
"""
|
||||
if binary:
|
||||
asdict = json.loads(f.read().decode("utf8"))
|
||||
else:
|
||||
asdict = json.load(f)
|
||||
|
||||
# in the list case, run a faster "vectorized" version
|
||||
cls = get_args(cls)[0]
|
||||
res = list(_dataclass_list_from_dict_list(asdict, cls))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
||||
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
||||
if get_origin(type_) is Union:
|
||||
args = get_args(type_)
|
||||
if len(args) == 2 and args[1] == type(None): # noqa E721
|
||||
return True, args[0]
|
||||
if type_ is Any:
|
||||
return True, Any
|
||||
|
||||
return False, type_
|
||||
|
||||
|
||||
def _unwrap_type(tp):
|
||||
# strips Optional wrapper, if any
|
||||
if get_origin(tp) is Union:
|
||||
args = get_args(tp)
|
||||
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
|
||||
# this is typing.Optional
|
||||
return args[0] if args[1] is type(None) else args[1] # noqa: E721
|
||||
return tp
|
||||
|
||||
|
||||
def _get_dataclass_field_default(field: Field) -> Any:
|
||||
if field.default_factory is not MISSING:
|
||||
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
|
||||
# dataclasses._DefaultFactory[typing.Any]]` is not a function.
|
||||
return field.default_factory()
|
||||
elif field.default is not MISSING:
|
||||
return field.default
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
"""
|
||||
Vectorised version of `_dataclass_from_dict`.
|
||||
The output should be equivalent to
|
||||
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
|
||||
|
||||
Args:
|
||||
dlist: list of objects to convert.
|
||||
typeannot: type of each of those objects.
|
||||
Returns:
|
||||
iterator or list over converted objects of the same length as `dlist`.
|
||||
|
||||
Raises:
|
||||
ValueError: it assumes the objects have None's in consistent places across
|
||||
objects, otherwise it would ignore some values. This generally holds for
|
||||
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
|
||||
"""
|
||||
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
|
||||
if typeannot is Any:
|
||||
return dlist
|
||||
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
||||
return dlist
|
||||
if any(obj is None for obj in dlist):
|
||||
# filter out Nones and recurse on the resulting list
|
||||
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
||||
idx, notnone = zip(*idx_notnone)
|
||||
converted = _dataclass_list_from_dict_list(notnone, typeannot)
|
||||
res = [None] * len(dlist)
|
||||
for i, obj in zip(idx, converted):
|
||||
res[i] = obj
|
||||
return res
|
||||
|
||||
is_optional, contained_type = _resolve_optional(typeannot)
|
||||
if is_optional:
|
||||
return _dataclass_list_from_dict_list(dlist, contained_type)
|
||||
|
||||
# otherwise, we dispatch by the type of the provided annotation to convert to
|
||||
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
# For namedtuple, call the function recursively on the lists of corresponding keys
|
||||
types = cls.__annotations__.values()
|
||||
dlist_T = zip(*dlist)
|
||||
res_T = [
|
||||
_dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
|
||||
]
|
||||
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
|
||||
elif issubclass(cls, (list, tuple)):
|
||||
# For list/tuple, call the function recursively on the lists of corresponding positions
|
||||
types = get_args(typeannot)
|
||||
if len(types) == 1: # probably List; replicate for all items
|
||||
types = types * len(dlist[0])
|
||||
dlist_T = zip(*dlist)
|
||||
res_T = (
|
||||
_dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
|
||||
)
|
||||
if issubclass(cls, tuple):
|
||||
return list(zip(*res_T))
|
||||
else:
|
||||
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
|
||||
elif issubclass(cls, dict):
|
||||
# For the dictionary, call the function recursively on concatenated keys and vertices
|
||||
key_t, val_t = get_args(typeannot)
|
||||
all_keys_res = _dataclass_list_from_dict_list(
|
||||
[k for obj in dlist for k in obj.keys()], key_t
|
||||
)
|
||||
all_vals_res = _dataclass_list_from_dict_list(
|
||||
[k for obj in dlist for k in obj.values()], val_t
|
||||
)
|
||||
indices = np.cumsum([len(obj) for obj in dlist])
|
||||
assert indices[-1] == len(all_keys_res)
|
||||
|
||||
keys = np.split(list(all_keys_res), indices[:-1])
|
||||
all_vals_res_iter = iter(all_vals_res)
|
||||
return [cls(zip(k, all_vals_res_iter)) for k in keys]
|
||||
elif not dataclasses.is_dataclass(typeannot):
|
||||
return dlist
|
||||
|
||||
# dataclass node: 2nd recursion base; call the function recursively on the lists
|
||||
# of the corresponding fields
|
||||
assert dataclasses.is_dataclass(cls)
|
||||
fieldtypes = {
|
||||
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
|
||||
for f in dataclasses.fields(typeannot)
|
||||
}
|
||||
|
||||
# NOTE the default object is shared here
|
||||
key_lists = (
|
||||
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
|
||||
for k, (type_, default) in fieldtypes.items()
|
||||
)
|
||||
transposed = zip(*key_lists)
|
||||
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
|
||||
|
@@ -1,161 +1,161 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import os
|
||||
import gzip
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.utils.data as data
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Any, Dict, Tuple
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
from cotracker.datasets.dataclass_utils import load_dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageAnnotation:
|
||||
# path to jpg file, relative w.r.t. dataset_root
|
||||
path: str
|
||||
# H x W
|
||||
size: Tuple[int, int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DynamicReplicaFrameAnnotation:
|
||||
"""A dataclass used to load annotations from json."""
|
||||
|
||||
# can be used to join with `SequenceAnnotation`
|
||||
sequence_name: str
|
||||
# 0-based, continuous frame number within sequence
|
||||
frame_number: int
|
||||
# timestamp in seconds from the video start
|
||||
frame_timestamp: float
|
||||
|
||||
image: ImageAnnotation
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
|
||||
camera_name: Optional[str] = None
|
||||
trajectories: Optional[str] = None
|
||||
|
||||
|
||||
class DynamicReplicaDataset(data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
split="valid",
|
||||
traj_per_sample=256,
|
||||
crop_size=None,
|
||||
sample_len=-1,
|
||||
only_first_n_samples=-1,
|
||||
rgbd_input=False,
|
||||
):
|
||||
super(DynamicReplicaDataset, self).__init__()
|
||||
self.root = root
|
||||
self.sample_len = sample_len
|
||||
self.split = split
|
||||
self.traj_per_sample = traj_per_sample
|
||||
self.rgbd_input = rgbd_input
|
||||
self.crop_size = crop_size
|
||||
frame_annotations_file = f"frame_annotations_{split}.jgz"
|
||||
self.sample_list = []
|
||||
with gzip.open(
|
||||
os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
|
||||
) as zipfile:
|
||||
frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
|
||||
seq_annot = defaultdict(list)
|
||||
for frame_annot in frame_annots_list:
|
||||
if frame_annot.camera_name == "left":
|
||||
seq_annot[frame_annot.sequence_name].append(frame_annot)
|
||||
|
||||
for seq_name in seq_annot.keys():
|
||||
seq_len = len(seq_annot[seq_name])
|
||||
|
||||
step = self.sample_len if self.sample_len > 0 else seq_len
|
||||
counter = 0
|
||||
|
||||
for ref_idx in range(0, seq_len, step):
|
||||
sample = seq_annot[seq_name][ref_idx : ref_idx + step]
|
||||
self.sample_list.append(sample)
|
||||
counter += 1
|
||||
if only_first_n_samples > 0 and counter >= only_first_n_samples:
|
||||
break
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_list)
|
||||
|
||||
def crop(self, rgbs, trajs):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
# simple random crop
|
||||
y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
|
||||
x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
|
||||
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
|
||||
|
||||
trajs[:, :, 0] -= x0
|
||||
trajs[:, :, 1] -= y0
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.sample_list[index]
|
||||
T = len(sample)
|
||||
rgbs, visibilities, traj_2d = [], [], []
|
||||
|
||||
H, W = sample[0].image.size
|
||||
image_size = (H, W)
|
||||
|
||||
for i in range(T):
|
||||
traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
|
||||
traj = torch.load(traj_path)
|
||||
|
||||
visibilities.append(traj["verts_inds_vis"].numpy())
|
||||
|
||||
rgbs.append(traj["img"].numpy())
|
||||
traj_2d.append(traj["traj_2d"].numpy()[..., :2])
|
||||
|
||||
traj_2d = np.stack(traj_2d)
|
||||
visibility = np.stack(visibilities)
|
||||
T, N, D = traj_2d.shape
|
||||
# subsample trajectories for augmentations
|
||||
visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
|
||||
|
||||
traj_2d = traj_2d[:, visible_inds_sampled]
|
||||
visibility = visibility[:, visible_inds_sampled]
|
||||
|
||||
if self.crop_size is not None:
|
||||
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
||||
H, W, _ = rgbs[0].shape
|
||||
image_size = self.crop_size
|
||||
|
||||
visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
|
||||
visibility[traj_2d[:, :, 0] < 0] = False
|
||||
visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
|
||||
visibility[traj_2d[:, :, 1] < 0] = False
|
||||
|
||||
# filter out points that're visible for less than 10 frames
|
||||
visible_inds_resampled = visibility.sum(0) > 10
|
||||
traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
|
||||
visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
|
||||
|
||||
rgbs = np.stack(rgbs, 0)
|
||||
video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
|
||||
return CoTrackerData(
|
||||
video=video,
|
||||
trajectory=traj_2d,
|
||||
visibility=visibility,
|
||||
valid=torch.ones(T, N),
|
||||
seq_name=sample[0].sequence_name,
|
||||
)
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import os
|
||||
import gzip
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.utils.data as data
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Any, Dict, Tuple
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
from cotracker.datasets.dataclass_utils import load_dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageAnnotation:
|
||||
# path to jpg file, relative w.r.t. dataset_root
|
||||
path: str
|
||||
# H x W
|
||||
size: Tuple[int, int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DynamicReplicaFrameAnnotation:
|
||||
"""A dataclass used to load annotations from json."""
|
||||
|
||||
# can be used to join with `SequenceAnnotation`
|
||||
sequence_name: str
|
||||
# 0-based, continuous frame number within sequence
|
||||
frame_number: int
|
||||
# timestamp in seconds from the video start
|
||||
frame_timestamp: float
|
||||
|
||||
image: ImageAnnotation
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
|
||||
camera_name: Optional[str] = None
|
||||
trajectories: Optional[str] = None
|
||||
|
||||
|
||||
class DynamicReplicaDataset(data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
split="valid",
|
||||
traj_per_sample=256,
|
||||
crop_size=None,
|
||||
sample_len=-1,
|
||||
only_first_n_samples=-1,
|
||||
rgbd_input=False,
|
||||
):
|
||||
super(DynamicReplicaDataset, self).__init__()
|
||||
self.root = root
|
||||
self.sample_len = sample_len
|
||||
self.split = split
|
||||
self.traj_per_sample = traj_per_sample
|
||||
self.rgbd_input = rgbd_input
|
||||
self.crop_size = crop_size
|
||||
frame_annotations_file = f"frame_annotations_{split}.jgz"
|
||||
self.sample_list = []
|
||||
with gzip.open(
|
||||
os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
|
||||
) as zipfile:
|
||||
frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
|
||||
seq_annot = defaultdict(list)
|
||||
for frame_annot in frame_annots_list:
|
||||
if frame_annot.camera_name == "left":
|
||||
seq_annot[frame_annot.sequence_name].append(frame_annot)
|
||||
|
||||
for seq_name in seq_annot.keys():
|
||||
seq_len = len(seq_annot[seq_name])
|
||||
|
||||
step = self.sample_len if self.sample_len > 0 else seq_len
|
||||
counter = 0
|
||||
|
||||
for ref_idx in range(0, seq_len, step):
|
||||
sample = seq_annot[seq_name][ref_idx : ref_idx + step]
|
||||
self.sample_list.append(sample)
|
||||
counter += 1
|
||||
if only_first_n_samples > 0 and counter >= only_first_n_samples:
|
||||
break
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_list)
|
||||
|
||||
def crop(self, rgbs, trajs):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
# simple random crop
|
||||
y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
|
||||
x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
|
||||
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
|
||||
|
||||
trajs[:, :, 0] -= x0
|
||||
trajs[:, :, 1] -= y0
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.sample_list[index]
|
||||
T = len(sample)
|
||||
rgbs, visibilities, traj_2d = [], [], []
|
||||
|
||||
H, W = sample[0].image.size
|
||||
image_size = (H, W)
|
||||
|
||||
for i in range(T):
|
||||
traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
|
||||
traj = torch.load(traj_path)
|
||||
|
||||
visibilities.append(traj["verts_inds_vis"].numpy())
|
||||
|
||||
rgbs.append(traj["img"].numpy())
|
||||
traj_2d.append(traj["traj_2d"].numpy()[..., :2])
|
||||
|
||||
traj_2d = np.stack(traj_2d)
|
||||
visibility = np.stack(visibilities)
|
||||
T, N, D = traj_2d.shape
|
||||
# subsample trajectories for augmentations
|
||||
visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
|
||||
|
||||
traj_2d = traj_2d[:, visible_inds_sampled]
|
||||
visibility = visibility[:, visible_inds_sampled]
|
||||
|
||||
if self.crop_size is not None:
|
||||
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
||||
H, W, _ = rgbs[0].shape
|
||||
image_size = self.crop_size
|
||||
|
||||
visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
|
||||
visibility[traj_2d[:, :, 0] < 0] = False
|
||||
visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
|
||||
visibility[traj_2d[:, :, 1] < 0] = False
|
||||
|
||||
# filter out points that're visible for less than 10 frames
|
||||
visible_inds_resampled = visibility.sum(0) > 10
|
||||
traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
|
||||
visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
|
||||
|
||||
rgbs = np.stack(rgbs, 0)
|
||||
video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
|
||||
return CoTrackerData(
|
||||
video=video,
|
||||
trajectory=traj_2d,
|
||||
visibility=visibility,
|
||||
valid=torch.ones(T, N),
|
||||
seq_name=sample[0].sequence_name,
|
||||
)
|
||||
|
@@ -1,441 +1,441 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
from torchvision.transforms import ColorJitter, GaussianBlur
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
crop_size=(384, 512),
|
||||
seq_len=24,
|
||||
traj_per_sample=768,
|
||||
sample_vis_1st_frame=False,
|
||||
use_augs=False,
|
||||
):
|
||||
super(CoTrackerDataset, self).__init__()
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
self.data_root = data_root
|
||||
self.seq_len = seq_len
|
||||
self.traj_per_sample = traj_per_sample
|
||||
self.sample_vis_1st_frame = sample_vis_1st_frame
|
||||
self.use_augs = use_augs
|
||||
self.crop_size = crop_size
|
||||
|
||||
# photometric augmentation
|
||||
self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
|
||||
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
|
||||
|
||||
self.blur_aug_prob = 0.25
|
||||
self.color_aug_prob = 0.25
|
||||
|
||||
# occlusion augmentation
|
||||
self.eraser_aug_prob = 0.5
|
||||
self.eraser_bounds = [2, 100]
|
||||
self.eraser_max = 10
|
||||
|
||||
# occlusion augmentation
|
||||
self.replace_aug_prob = 0.5
|
||||
self.replace_bounds = [2, 100]
|
||||
self.replace_max = 10
|
||||
|
||||
# spatial augmentations
|
||||
self.pad_bounds = [0, 100]
|
||||
self.crop_size = crop_size
|
||||
self.resize_lim = [0.25, 2.0] # sample resizes from here
|
||||
self.resize_delta = 0.2
|
||||
self.max_crop_offset = 50
|
||||
|
||||
self.do_flip = True
|
||||
self.h_flip_prob = 0.5
|
||||
self.v_flip_prob = 0.5
|
||||
|
||||
def getitem_helper(self, index):
|
||||
return NotImplementedError
|
||||
|
||||
def __getitem__(self, index):
|
||||
gotit = False
|
||||
|
||||
sample, gotit = self.getitem_helper(index)
|
||||
if not gotit:
|
||||
print("warning: sampling failed")
|
||||
# fake sample, so we can still collate
|
||||
sample = CoTrackerData(
|
||||
video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
|
||||
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
|
||||
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
)
|
||||
|
||||
return sample, gotit
|
||||
|
||||
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
if eraser:
|
||||
############ eraser transform (per image after the first) ############
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
for i in range(1, S):
|
||||
if np.random.rand() < self.eraser_aug_prob:
|
||||
for _ in range(
|
||||
np.random.randint(1, self.eraser_max + 1)
|
||||
): # number of times to occlude
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
|
||||
dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
|
||||
mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
|
||||
rgbs[i][y0:y1, x0:x1, :] = mean_color
|
||||
|
||||
occ_inds = np.logical_and(
|
||||
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
||||
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
||||
)
|
||||
visibles[i, occ_inds] = 0
|
||||
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
||||
|
||||
if replace:
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
|
||||
]
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
|
||||
]
|
||||
|
||||
############ replace transform (per image after the first) ############
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
|
||||
for i in range(1, S):
|
||||
if np.random.rand() < self.replace_aug_prob:
|
||||
for _ in range(
|
||||
np.random.randint(1, self.replace_max + 1)
|
||||
): # number of times to occlude
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
|
||||
dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
|
||||
wid = x1 - x0
|
||||
hei = y1 - y0
|
||||
y00 = np.random.randint(0, H - hei)
|
||||
x00 = np.random.randint(0, W - wid)
|
||||
fr = np.random.randint(0, S)
|
||||
rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
|
||||
rgbs[i][y0:y1, x0:x1, :] = rep
|
||||
|
||||
occ_inds = np.logical_and(
|
||||
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
||||
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
||||
)
|
||||
visibles[i, occ_inds] = 0
|
||||
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
||||
|
||||
############ photometric augmentation ############
|
||||
if np.random.rand() < self.color_aug_prob:
|
||||
# random per-frame amount of aug
|
||||
rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
|
||||
|
||||
if np.random.rand() < self.blur_aug_prob:
|
||||
# random per-frame amount of blur
|
||||
rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
|
||||
|
||||
return rgbs, trajs, visibles
|
||||
|
||||
def add_spatial_augs(self, rgbs, trajs, visibles):
|
||||
T, N, __ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
|
||||
############ spatial transform ############
|
||||
|
||||
# padding
|
||||
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
|
||||
rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
|
||||
trajs[:, :, 0] += pad_x0
|
||||
trajs[:, :, 1] += pad_y0
|
||||
H, W = rgbs[0].shape[:2]
|
||||
|
||||
# scaling + stretching
|
||||
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
|
||||
scale_x = scale
|
||||
scale_y = scale
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
scale_delta_x = 0.0
|
||||
scale_delta_y = 0.0
|
||||
|
||||
rgbs_scaled = []
|
||||
for s in range(S):
|
||||
if s == 1:
|
||||
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||
elif s > 1:
|
||||
scale_delta_x = (
|
||||
scale_delta_x * 0.8
|
||||
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||
)
|
||||
scale_delta_y = (
|
||||
scale_delta_y * 0.8
|
||||
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||
)
|
||||
scale_x = scale_x + scale_delta_x
|
||||
scale_y = scale_y + scale_delta_y
|
||||
|
||||
# bring h/w closer
|
||||
scale_xy = (scale_x + scale_y) * 0.5
|
||||
scale_x = scale_x * 0.5 + scale_xy * 0.5
|
||||
scale_y = scale_y * 0.5 + scale_xy * 0.5
|
||||
|
||||
# don't get too crazy
|
||||
scale_x = np.clip(scale_x, 0.2, 2.0)
|
||||
scale_y = np.clip(scale_y, 0.2, 2.0)
|
||||
|
||||
H_new = int(H * scale_y)
|
||||
W_new = int(W * scale_x)
|
||||
|
||||
# make it at least slightly bigger than the crop area,
|
||||
# so that the random cropping can add diversity
|
||||
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
|
||||
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
|
||||
# recompute scale in case we clipped
|
||||
scale_x = (W_new - 1) / float(W - 1)
|
||||
scale_y = (H_new - 1) / float(H - 1)
|
||||
rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
|
||||
trajs[s, :, 0] *= scale_x
|
||||
trajs[s, :, 1] *= scale_y
|
||||
rgbs = rgbs_scaled
|
||||
|
||||
ok_inds = visibles[0, :] > 0
|
||||
vis_trajs = trajs[:, ok_inds] # S,?,2
|
||||
|
||||
if vis_trajs.shape[1] > 0:
|
||||
mid_x = np.mean(vis_trajs[0, :, 0])
|
||||
mid_y = np.mean(vis_trajs[0, :, 1])
|
||||
else:
|
||||
mid_y = self.crop_size[0]
|
||||
mid_x = self.crop_size[1]
|
||||
|
||||
x0 = int(mid_x - self.crop_size[1] // 2)
|
||||
y0 = int(mid_y - self.crop_size[0] // 2)
|
||||
|
||||
offset_x = 0
|
||||
offset_y = 0
|
||||
|
||||
for s in range(S):
|
||||
# on each frame, shift a bit more
|
||||
if s == 1:
|
||||
offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
|
||||
offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
|
||||
elif s > 1:
|
||||
offset_x = int(
|
||||
offset_x * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
|
||||
)
|
||||
offset_y = int(
|
||||
offset_y * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
|
||||
)
|
||||
x0 = x0 + offset_x
|
||||
y0 = y0 + offset_y
|
||||
|
||||
H_new, W_new = rgbs[s].shape[:2]
|
||||
if H_new == self.crop_size[0]:
|
||||
y0 = 0
|
||||
else:
|
||||
y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
|
||||
|
||||
if W_new == self.crop_size[1]:
|
||||
x0 = 0
|
||||
else:
|
||||
x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
|
||||
|
||||
rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
||||
trajs[s, :, 0] -= x0
|
||||
trajs[s, :, 1] -= y0
|
||||
|
||||
H_new = self.crop_size[0]
|
||||
W_new = self.crop_size[1]
|
||||
|
||||
# flip
|
||||
h_flipped = False
|
||||
v_flipped = False
|
||||
if self.do_flip:
|
||||
# h flip
|
||||
if np.random.rand() < self.h_flip_prob:
|
||||
h_flipped = True
|
||||
rgbs = [rgb[:, ::-1] for rgb in rgbs]
|
||||
# v flip
|
||||
if np.random.rand() < self.v_flip_prob:
|
||||
v_flipped = True
|
||||
rgbs = [rgb[::-1] for rgb in rgbs]
|
||||
if h_flipped:
|
||||
trajs[:, :, 0] = W_new - trajs[:, :, 0]
|
||||
if v_flipped:
|
||||
trajs[:, :, 1] = H_new - trajs[:, :, 1]
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
def crop(self, rgbs, trajs):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
############ spatial transform ############
|
||||
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
# simple random crop
|
||||
y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
|
||||
x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
|
||||
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
|
||||
|
||||
trajs[:, :, 0] -= x0
|
||||
trajs[:, :, 1] -= y0
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
|
||||
class KubricMovifDataset(CoTrackerDataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
crop_size=(384, 512),
|
||||
seq_len=24,
|
||||
traj_per_sample=768,
|
||||
sample_vis_1st_frame=False,
|
||||
use_augs=False,
|
||||
):
|
||||
super(KubricMovifDataset, self).__init__(
|
||||
data_root=data_root,
|
||||
crop_size=crop_size,
|
||||
seq_len=seq_len,
|
||||
traj_per_sample=traj_per_sample,
|
||||
sample_vis_1st_frame=sample_vis_1st_frame,
|
||||
use_augs=use_augs,
|
||||
)
|
||||
|
||||
self.pad_bounds = [0, 25]
|
||||
self.resize_lim = [0.75, 1.25] # sample resizes from here
|
||||
self.resize_delta = 0.05
|
||||
self.max_crop_offset = 15
|
||||
self.seq_names = [
|
||||
fname
|
||||
for fname in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, fname))
|
||||
]
|
||||
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
|
||||
|
||||
def getitem_helper(self, index):
|
||||
gotit = True
|
||||
seq_name = self.seq_names[index]
|
||||
|
||||
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
|
||||
rgb_path = os.path.join(self.data_root, seq_name, "frames")
|
||||
|
||||
img_paths = sorted(os.listdir(rgb_path))
|
||||
rgbs = []
|
||||
for i, img_path in enumerate(img_paths):
|
||||
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
|
||||
|
||||
rgbs = np.stack(rgbs)
|
||||
annot_dict = np.load(npy_path, allow_pickle=True).item()
|
||||
traj_2d = annot_dict["coords"]
|
||||
visibility = annot_dict["visibility"]
|
||||
|
||||
# random crop
|
||||
assert self.seq_len <= len(rgbs)
|
||||
if self.seq_len < len(rgbs):
|
||||
start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
|
||||
|
||||
rgbs = rgbs[start_ind : start_ind + self.seq_len]
|
||||
traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
|
||||
visibility = visibility[:, start_ind : start_ind + self.seq_len]
|
||||
|
||||
traj_2d = np.transpose(traj_2d, (1, 0, 2))
|
||||
visibility = np.transpose(np.logical_not(visibility), (1, 0))
|
||||
if self.use_augs:
|
||||
rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
|
||||
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
|
||||
else:
|
||||
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
||||
|
||||
visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
|
||||
visibility[traj_2d[:, :, 0] < 0] = False
|
||||
visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
|
||||
visibility[traj_2d[:, :, 1] < 0] = False
|
||||
|
||||
visibility = torch.from_numpy(visibility)
|
||||
traj_2d = torch.from_numpy(traj_2d)
|
||||
|
||||
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
|
||||
|
||||
if self.sample_vis_1st_frame:
|
||||
visibile_pts_inds = visibile_pts_first_frame_inds
|
||||
else:
|
||||
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
|
||||
:, 0
|
||||
]
|
||||
visibile_pts_inds = torch.cat(
|
||||
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
|
||||
)
|
||||
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
|
||||
if len(point_inds) < self.traj_per_sample:
|
||||
gotit = False
|
||||
|
||||
visible_inds_sampled = visibile_pts_inds[point_inds]
|
||||
|
||||
trajs = traj_2d[:, visible_inds_sampled].float()
|
||||
visibles = visibility[:, visible_inds_sampled]
|
||||
valids = torch.ones((self.seq_len, self.traj_per_sample))
|
||||
|
||||
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
|
||||
sample = CoTrackerData(
|
||||
video=rgbs,
|
||||
trajectory=trajs,
|
||||
visibility=visibles,
|
||||
valid=valids,
|
||||
seq_name=seq_name,
|
||||
)
|
||||
return sample, gotit
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
from torchvision.transforms import ColorJitter, GaussianBlur
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
crop_size=(384, 512),
|
||||
seq_len=24,
|
||||
traj_per_sample=768,
|
||||
sample_vis_1st_frame=False,
|
||||
use_augs=False,
|
||||
):
|
||||
super(CoTrackerDataset, self).__init__()
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
self.data_root = data_root
|
||||
self.seq_len = seq_len
|
||||
self.traj_per_sample = traj_per_sample
|
||||
self.sample_vis_1st_frame = sample_vis_1st_frame
|
||||
self.use_augs = use_augs
|
||||
self.crop_size = crop_size
|
||||
|
||||
# photometric augmentation
|
||||
self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
|
||||
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
|
||||
|
||||
self.blur_aug_prob = 0.25
|
||||
self.color_aug_prob = 0.25
|
||||
|
||||
# occlusion augmentation
|
||||
self.eraser_aug_prob = 0.5
|
||||
self.eraser_bounds = [2, 100]
|
||||
self.eraser_max = 10
|
||||
|
||||
# occlusion augmentation
|
||||
self.replace_aug_prob = 0.5
|
||||
self.replace_bounds = [2, 100]
|
||||
self.replace_max = 10
|
||||
|
||||
# spatial augmentations
|
||||
self.pad_bounds = [0, 100]
|
||||
self.crop_size = crop_size
|
||||
self.resize_lim = [0.25, 2.0] # sample resizes from here
|
||||
self.resize_delta = 0.2
|
||||
self.max_crop_offset = 50
|
||||
|
||||
self.do_flip = True
|
||||
self.h_flip_prob = 0.5
|
||||
self.v_flip_prob = 0.5
|
||||
|
||||
def getitem_helper(self, index):
|
||||
return NotImplementedError
|
||||
|
||||
def __getitem__(self, index):
|
||||
gotit = False
|
||||
|
||||
sample, gotit = self.getitem_helper(index)
|
||||
if not gotit:
|
||||
print("warning: sampling failed")
|
||||
# fake sample, so we can still collate
|
||||
sample = CoTrackerData(
|
||||
video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
|
||||
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
|
||||
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
)
|
||||
|
||||
return sample, gotit
|
||||
|
||||
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
if eraser:
|
||||
############ eraser transform (per image after the first) ############
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
for i in range(1, S):
|
||||
if np.random.rand() < self.eraser_aug_prob:
|
||||
for _ in range(
|
||||
np.random.randint(1, self.eraser_max + 1)
|
||||
): # number of times to occlude
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
|
||||
dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
|
||||
mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
|
||||
rgbs[i][y0:y1, x0:x1, :] = mean_color
|
||||
|
||||
occ_inds = np.logical_and(
|
||||
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
||||
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
||||
)
|
||||
visibles[i, occ_inds] = 0
|
||||
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
||||
|
||||
if replace:
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
|
||||
]
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
|
||||
]
|
||||
|
||||
############ replace transform (per image after the first) ############
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
|
||||
for i in range(1, S):
|
||||
if np.random.rand() < self.replace_aug_prob:
|
||||
for _ in range(
|
||||
np.random.randint(1, self.replace_max + 1)
|
||||
): # number of times to occlude
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
|
||||
dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
|
||||
wid = x1 - x0
|
||||
hei = y1 - y0
|
||||
y00 = np.random.randint(0, H - hei)
|
||||
x00 = np.random.randint(0, W - wid)
|
||||
fr = np.random.randint(0, S)
|
||||
rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
|
||||
rgbs[i][y0:y1, x0:x1, :] = rep
|
||||
|
||||
occ_inds = np.logical_and(
|
||||
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
||||
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
||||
)
|
||||
visibles[i, occ_inds] = 0
|
||||
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
||||
|
||||
############ photometric augmentation ############
|
||||
if np.random.rand() < self.color_aug_prob:
|
||||
# random per-frame amount of aug
|
||||
rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
|
||||
|
||||
if np.random.rand() < self.blur_aug_prob:
|
||||
# random per-frame amount of blur
|
||||
rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
|
||||
|
||||
return rgbs, trajs, visibles
|
||||
|
||||
def add_spatial_augs(self, rgbs, trajs, visibles):
|
||||
T, N, __ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
|
||||
############ spatial transform ############
|
||||
|
||||
# padding
|
||||
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
|
||||
rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
|
||||
trajs[:, :, 0] += pad_x0
|
||||
trajs[:, :, 1] += pad_y0
|
||||
H, W = rgbs[0].shape[:2]
|
||||
|
||||
# scaling + stretching
|
||||
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
|
||||
scale_x = scale
|
||||
scale_y = scale
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
scale_delta_x = 0.0
|
||||
scale_delta_y = 0.0
|
||||
|
||||
rgbs_scaled = []
|
||||
for s in range(S):
|
||||
if s == 1:
|
||||
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||
elif s > 1:
|
||||
scale_delta_x = (
|
||||
scale_delta_x * 0.8
|
||||
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||
)
|
||||
scale_delta_y = (
|
||||
scale_delta_y * 0.8
|
||||
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||
)
|
||||
scale_x = scale_x + scale_delta_x
|
||||
scale_y = scale_y + scale_delta_y
|
||||
|
||||
# bring h/w closer
|
||||
scale_xy = (scale_x + scale_y) * 0.5
|
||||
scale_x = scale_x * 0.5 + scale_xy * 0.5
|
||||
scale_y = scale_y * 0.5 + scale_xy * 0.5
|
||||
|
||||
# don't get too crazy
|
||||
scale_x = np.clip(scale_x, 0.2, 2.0)
|
||||
scale_y = np.clip(scale_y, 0.2, 2.0)
|
||||
|
||||
H_new = int(H * scale_y)
|
||||
W_new = int(W * scale_x)
|
||||
|
||||
# make it at least slightly bigger than the crop area,
|
||||
# so that the random cropping can add diversity
|
||||
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
|
||||
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
|
||||
# recompute scale in case we clipped
|
||||
scale_x = (W_new - 1) / float(W - 1)
|
||||
scale_y = (H_new - 1) / float(H - 1)
|
||||
rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
|
||||
trajs[s, :, 0] *= scale_x
|
||||
trajs[s, :, 1] *= scale_y
|
||||
rgbs = rgbs_scaled
|
||||
|
||||
ok_inds = visibles[0, :] > 0
|
||||
vis_trajs = trajs[:, ok_inds] # S,?,2
|
||||
|
||||
if vis_trajs.shape[1] > 0:
|
||||
mid_x = np.mean(vis_trajs[0, :, 0])
|
||||
mid_y = np.mean(vis_trajs[0, :, 1])
|
||||
else:
|
||||
mid_y = self.crop_size[0]
|
||||
mid_x = self.crop_size[1]
|
||||
|
||||
x0 = int(mid_x - self.crop_size[1] // 2)
|
||||
y0 = int(mid_y - self.crop_size[0] // 2)
|
||||
|
||||
offset_x = 0
|
||||
offset_y = 0
|
||||
|
||||
for s in range(S):
|
||||
# on each frame, shift a bit more
|
||||
if s == 1:
|
||||
offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
|
||||
offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
|
||||
elif s > 1:
|
||||
offset_x = int(
|
||||
offset_x * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
|
||||
)
|
||||
offset_y = int(
|
||||
offset_y * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
|
||||
)
|
||||
x0 = x0 + offset_x
|
||||
y0 = y0 + offset_y
|
||||
|
||||
H_new, W_new = rgbs[s].shape[:2]
|
||||
if H_new == self.crop_size[0]:
|
||||
y0 = 0
|
||||
else:
|
||||
y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
|
||||
|
||||
if W_new == self.crop_size[1]:
|
||||
x0 = 0
|
||||
else:
|
||||
x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
|
||||
|
||||
rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
||||
trajs[s, :, 0] -= x0
|
||||
trajs[s, :, 1] -= y0
|
||||
|
||||
H_new = self.crop_size[0]
|
||||
W_new = self.crop_size[1]
|
||||
|
||||
# flip
|
||||
h_flipped = False
|
||||
v_flipped = False
|
||||
if self.do_flip:
|
||||
# h flip
|
||||
if np.random.rand() < self.h_flip_prob:
|
||||
h_flipped = True
|
||||
rgbs = [rgb[:, ::-1] for rgb in rgbs]
|
||||
# v flip
|
||||
if np.random.rand() < self.v_flip_prob:
|
||||
v_flipped = True
|
||||
rgbs = [rgb[::-1] for rgb in rgbs]
|
||||
if h_flipped:
|
||||
trajs[:, :, 0] = W_new - trajs[:, :, 0]
|
||||
if v_flipped:
|
||||
trajs[:, :, 1] = H_new - trajs[:, :, 1]
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
def crop(self, rgbs, trajs):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
############ spatial transform ############
|
||||
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
# simple random crop
|
||||
y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
|
||||
x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
|
||||
rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
|
||||
|
||||
trajs[:, :, 0] -= x0
|
||||
trajs[:, :, 1] -= y0
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
|
||||
class KubricMovifDataset(CoTrackerDataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
crop_size=(384, 512),
|
||||
seq_len=24,
|
||||
traj_per_sample=768,
|
||||
sample_vis_1st_frame=False,
|
||||
use_augs=False,
|
||||
):
|
||||
super(KubricMovifDataset, self).__init__(
|
||||
data_root=data_root,
|
||||
crop_size=crop_size,
|
||||
seq_len=seq_len,
|
||||
traj_per_sample=traj_per_sample,
|
||||
sample_vis_1st_frame=sample_vis_1st_frame,
|
||||
use_augs=use_augs,
|
||||
)
|
||||
|
||||
self.pad_bounds = [0, 25]
|
||||
self.resize_lim = [0.75, 1.25] # sample resizes from here
|
||||
self.resize_delta = 0.05
|
||||
self.max_crop_offset = 15
|
||||
self.seq_names = [
|
||||
fname
|
||||
for fname in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, fname))
|
||||
]
|
||||
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
|
||||
|
||||
def getitem_helper(self, index):
|
||||
gotit = True
|
||||
seq_name = self.seq_names[index]
|
||||
|
||||
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
|
||||
rgb_path = os.path.join(self.data_root, seq_name, "frames")
|
||||
|
||||
img_paths = sorted(os.listdir(rgb_path))
|
||||
rgbs = []
|
||||
for i, img_path in enumerate(img_paths):
|
||||
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
|
||||
|
||||
rgbs = np.stack(rgbs)
|
||||
annot_dict = np.load(npy_path, allow_pickle=True).item()
|
||||
traj_2d = annot_dict["coords"]
|
||||
visibility = annot_dict["visibility"]
|
||||
|
||||
# random crop
|
||||
assert self.seq_len <= len(rgbs)
|
||||
if self.seq_len < len(rgbs):
|
||||
start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
|
||||
|
||||
rgbs = rgbs[start_ind : start_ind + self.seq_len]
|
||||
traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
|
||||
visibility = visibility[:, start_ind : start_ind + self.seq_len]
|
||||
|
||||
traj_2d = np.transpose(traj_2d, (1, 0, 2))
|
||||
visibility = np.transpose(np.logical_not(visibility), (1, 0))
|
||||
if self.use_augs:
|
||||
rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
|
||||
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
|
||||
else:
|
||||
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
||||
|
||||
visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
|
||||
visibility[traj_2d[:, :, 0] < 0] = False
|
||||
visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
|
||||
visibility[traj_2d[:, :, 1] < 0] = False
|
||||
|
||||
visibility = torch.from_numpy(visibility)
|
||||
traj_2d = torch.from_numpy(traj_2d)
|
||||
|
||||
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
|
||||
|
||||
if self.sample_vis_1st_frame:
|
||||
visibile_pts_inds = visibile_pts_first_frame_inds
|
||||
else:
|
||||
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
|
||||
:, 0
|
||||
]
|
||||
visibile_pts_inds = torch.cat(
|
||||
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
|
||||
)
|
||||
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
|
||||
if len(point_inds) < self.traj_per_sample:
|
||||
gotit = False
|
||||
|
||||
visible_inds_sampled = visibile_pts_inds[point_inds]
|
||||
|
||||
trajs = traj_2d[:, visible_inds_sampled].float()
|
||||
visibles = visibility[:, visible_inds_sampled]
|
||||
valids = torch.ones((self.seq_len, self.traj_per_sample))
|
||||
|
||||
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
|
||||
sample = CoTrackerData(
|
||||
video=rgbs,
|
||||
trajectory=trajs,
|
||||
visibility=visibles,
|
||||
valid=valids,
|
||||
seq_name=seq_name,
|
||||
)
|
||||
return sample, gotit
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
||||
|
@@ -1,209 +1,209 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import io
|
||||
import glob
|
||||
import torch
|
||||
import pickle
|
||||
import numpy as np
|
||||
import mediapy as media
|
||||
|
||||
from PIL import Image
|
||||
from typing import Mapping, Tuple, Union
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
|
||||
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
|
||||
|
||||
|
||||
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
"""Resize a video to output_size."""
|
||||
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
|
||||
# such as a jitted jax.image.resize. It will make things faster.
|
||||
return media.resize_video(video, output_size)
|
||||
|
||||
|
||||
def sample_queries_first(
|
||||
target_occluded: np.ndarray,
|
||||
target_points: np.ndarray,
|
||||
frames: np.ndarray,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||
Given a set of frames and tracks with no query points, use the first
|
||||
visible point in each track as the query.
|
||||
Args:
|
||||
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||
where True indicates occluded.
|
||||
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||
is [x,y] scaled between 0 and 1.
|
||||
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||
-1 and 1.
|
||||
Returns:
|
||||
A dict with the keys:
|
||||
video: Video tensor of shape [1, n_frames, height, width, 3]
|
||||
query_points: Query points of shape [1, n_queries, 3] where
|
||||
each point is [t, y, x] scaled to the range [-1, 1]
|
||||
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||
each point is [x, y] scaled to the range [-1, 1]
|
||||
"""
|
||||
valid = np.sum(~target_occluded, axis=1) > 0
|
||||
target_points = target_points[valid, :]
|
||||
target_occluded = target_occluded[valid, :]
|
||||
|
||||
query_points = []
|
||||
for i in range(target_points.shape[0]):
|
||||
index = np.where(target_occluded[i] == 0)[0][0]
|
||||
x, y = target_points[i, index, 0], target_points[i, index, 1]
|
||||
query_points.append(np.array([index, y, x])) # [t, y, x]
|
||||
query_points = np.stack(query_points, axis=0)
|
||||
|
||||
return {
|
||||
"video": frames[np.newaxis, ...],
|
||||
"query_points": query_points[np.newaxis, ...],
|
||||
"target_points": target_points[np.newaxis, ...],
|
||||
"occluded": target_occluded[np.newaxis, ...],
|
||||
}
|
||||
|
||||
|
||||
def sample_queries_strided(
|
||||
target_occluded: np.ndarray,
|
||||
target_points: np.ndarray,
|
||||
frames: np.ndarray,
|
||||
query_stride: int = 5,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||
|
||||
Given a set of frames and tracks with no query points, sample queries
|
||||
strided every query_stride frames, ignoring points that are not visible
|
||||
at the selected frames.
|
||||
|
||||
Args:
|
||||
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||
where True indicates occluded.
|
||||
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||
is [x,y] scaled between 0 and 1.
|
||||
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||
-1 and 1.
|
||||
query_stride: When sampling query points, search for un-occluded points
|
||||
every query_stride frames and convert each one into a query.
|
||||
|
||||
Returns:
|
||||
A dict with the keys:
|
||||
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
|
||||
has floats scaled to the range [-1, 1].
|
||||
query_points: Query points of shape [1, n_queries, 3] where
|
||||
each point is [t, y, x] scaled to the range [-1, 1].
|
||||
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||
each point is [x, y] scaled to the range [-1, 1].
|
||||
trackgroup: Index of the original track that each query point was
|
||||
sampled from. This is useful for visualization.
|
||||
"""
|
||||
tracks = []
|
||||
occs = []
|
||||
queries = []
|
||||
trackgroups = []
|
||||
total = 0
|
||||
trackgroup = np.arange(target_occluded.shape[0])
|
||||
for i in range(0, target_occluded.shape[1], query_stride):
|
||||
mask = target_occluded[:, i] == 0
|
||||
query = np.stack(
|
||||
[
|
||||
i * np.ones(target_occluded.shape[0:1]),
|
||||
target_points[:, i, 1],
|
||||
target_points[:, i, 0],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
queries.append(query[mask])
|
||||
tracks.append(target_points[mask])
|
||||
occs.append(target_occluded[mask])
|
||||
trackgroups.append(trackgroup[mask])
|
||||
total += np.array(np.sum(target_occluded[:, i] == 0))
|
||||
|
||||
return {
|
||||
"video": frames[np.newaxis, ...],
|
||||
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
|
||||
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
|
||||
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
|
||||
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
|
||||
}
|
||||
|
||||
|
||||
class TapVidDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
dataset_type="davis",
|
||||
resize_to_256=True,
|
||||
queried_first=True,
|
||||
):
|
||||
self.dataset_type = dataset_type
|
||||
self.resize_to_256 = resize_to_256
|
||||
self.queried_first = queried_first
|
||||
if self.dataset_type == "kinetics":
|
||||
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
|
||||
points_dataset = []
|
||||
for pickle_path in all_paths:
|
||||
with open(pickle_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
points_dataset = points_dataset + data
|
||||
self.points_dataset = points_dataset
|
||||
else:
|
||||
with open(data_root, "rb") as f:
|
||||
self.points_dataset = pickle.load(f)
|
||||
if self.dataset_type == "davis":
|
||||
self.video_names = list(self.points_dataset.keys())
|
||||
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.dataset_type == "davis":
|
||||
video_name = self.video_names[index]
|
||||
else:
|
||||
video_name = index
|
||||
video = self.points_dataset[video_name]
|
||||
frames = video["video"]
|
||||
|
||||
if isinstance(frames[0], bytes):
|
||||
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
|
||||
def decode(frame):
|
||||
byteio = io.BytesIO(frame)
|
||||
img = Image.open(byteio)
|
||||
return np.array(img)
|
||||
|
||||
frames = np.array([decode(frame) for frame in frames])
|
||||
|
||||
target_points = self.points_dataset[video_name]["points"]
|
||||
if self.resize_to_256:
|
||||
frames = resize_video(frames, [256, 256])
|
||||
target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
|
||||
else:
|
||||
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
|
||||
|
||||
target_occ = self.points_dataset[video_name]["occluded"]
|
||||
if self.queried_first:
|
||||
converted = sample_queries_first(target_occ, target_points, frames)
|
||||
else:
|
||||
converted = sample_queries_strided(target_occ, target_points, frames)
|
||||
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
|
||||
|
||||
trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
|
||||
|
||||
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
|
||||
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
|
||||
1, 0
|
||||
) # T, N
|
||||
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
|
||||
return CoTrackerData(
|
||||
rgbs,
|
||||
trajs,
|
||||
visibles,
|
||||
seq_name=str(video_name),
|
||||
query_points=query_points,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.points_dataset)
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import io
|
||||
import glob
|
||||
import torch
|
||||
import pickle
|
||||
import numpy as np
|
||||
import mediapy as media
|
||||
|
||||
from PIL import Image
|
||||
from typing import Mapping, Tuple, Union
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
|
||||
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
|
||||
|
||||
|
||||
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
"""Resize a video to output_size."""
|
||||
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
|
||||
# such as a jitted jax.image.resize. It will make things faster.
|
||||
return media.resize_video(video, output_size)
|
||||
|
||||
|
||||
def sample_queries_first(
|
||||
target_occluded: np.ndarray,
|
||||
target_points: np.ndarray,
|
||||
frames: np.ndarray,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||
Given a set of frames and tracks with no query points, use the first
|
||||
visible point in each track as the query.
|
||||
Args:
|
||||
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||
where True indicates occluded.
|
||||
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||
is [x,y] scaled between 0 and 1.
|
||||
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||
-1 and 1.
|
||||
Returns:
|
||||
A dict with the keys:
|
||||
video: Video tensor of shape [1, n_frames, height, width, 3]
|
||||
query_points: Query points of shape [1, n_queries, 3] where
|
||||
each point is [t, y, x] scaled to the range [-1, 1]
|
||||
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||
each point is [x, y] scaled to the range [-1, 1]
|
||||
"""
|
||||
valid = np.sum(~target_occluded, axis=1) > 0
|
||||
target_points = target_points[valid, :]
|
||||
target_occluded = target_occluded[valid, :]
|
||||
|
||||
query_points = []
|
||||
for i in range(target_points.shape[0]):
|
||||
index = np.where(target_occluded[i] == 0)[0][0]
|
||||
x, y = target_points[i, index, 0], target_points[i, index, 1]
|
||||
query_points.append(np.array([index, y, x])) # [t, y, x]
|
||||
query_points = np.stack(query_points, axis=0)
|
||||
|
||||
return {
|
||||
"video": frames[np.newaxis, ...],
|
||||
"query_points": query_points[np.newaxis, ...],
|
||||
"target_points": target_points[np.newaxis, ...],
|
||||
"occluded": target_occluded[np.newaxis, ...],
|
||||
}
|
||||
|
||||
|
||||
def sample_queries_strided(
|
||||
target_occluded: np.ndarray,
|
||||
target_points: np.ndarray,
|
||||
frames: np.ndarray,
|
||||
query_stride: int = 5,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||
|
||||
Given a set of frames and tracks with no query points, sample queries
|
||||
strided every query_stride frames, ignoring points that are not visible
|
||||
at the selected frames.
|
||||
|
||||
Args:
|
||||
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||
where True indicates occluded.
|
||||
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||
is [x,y] scaled between 0 and 1.
|
||||
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||
-1 and 1.
|
||||
query_stride: When sampling query points, search for un-occluded points
|
||||
every query_stride frames and convert each one into a query.
|
||||
|
||||
Returns:
|
||||
A dict with the keys:
|
||||
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
|
||||
has floats scaled to the range [-1, 1].
|
||||
query_points: Query points of shape [1, n_queries, 3] where
|
||||
each point is [t, y, x] scaled to the range [-1, 1].
|
||||
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||
each point is [x, y] scaled to the range [-1, 1].
|
||||
trackgroup: Index of the original track that each query point was
|
||||
sampled from. This is useful for visualization.
|
||||
"""
|
||||
tracks = []
|
||||
occs = []
|
||||
queries = []
|
||||
trackgroups = []
|
||||
total = 0
|
||||
trackgroup = np.arange(target_occluded.shape[0])
|
||||
for i in range(0, target_occluded.shape[1], query_stride):
|
||||
mask = target_occluded[:, i] == 0
|
||||
query = np.stack(
|
||||
[
|
||||
i * np.ones(target_occluded.shape[0:1]),
|
||||
target_points[:, i, 1],
|
||||
target_points[:, i, 0],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
queries.append(query[mask])
|
||||
tracks.append(target_points[mask])
|
||||
occs.append(target_occluded[mask])
|
||||
trackgroups.append(trackgroup[mask])
|
||||
total += np.array(np.sum(target_occluded[:, i] == 0))
|
||||
|
||||
return {
|
||||
"video": frames[np.newaxis, ...],
|
||||
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
|
||||
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
|
||||
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
|
||||
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
|
||||
}
|
||||
|
||||
|
||||
class TapVidDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
dataset_type="davis",
|
||||
resize_to_256=True,
|
||||
queried_first=True,
|
||||
):
|
||||
self.dataset_type = dataset_type
|
||||
self.resize_to_256 = resize_to_256
|
||||
self.queried_first = queried_first
|
||||
if self.dataset_type == "kinetics":
|
||||
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
|
||||
points_dataset = []
|
||||
for pickle_path in all_paths:
|
||||
with open(pickle_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
points_dataset = points_dataset + data
|
||||
self.points_dataset = points_dataset
|
||||
else:
|
||||
with open(data_root, "rb") as f:
|
||||
self.points_dataset = pickle.load(f)
|
||||
if self.dataset_type == "davis":
|
||||
self.video_names = list(self.points_dataset.keys())
|
||||
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.dataset_type == "davis":
|
||||
video_name = self.video_names[index]
|
||||
else:
|
||||
video_name = index
|
||||
video = self.points_dataset[video_name]
|
||||
frames = video["video"]
|
||||
|
||||
if isinstance(frames[0], bytes):
|
||||
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
|
||||
def decode(frame):
|
||||
byteio = io.BytesIO(frame)
|
||||
img = Image.open(byteio)
|
||||
return np.array(img)
|
||||
|
||||
frames = np.array([decode(frame) for frame in frames])
|
||||
|
||||
target_points = self.points_dataset[video_name]["points"]
|
||||
if self.resize_to_256:
|
||||
frames = resize_video(frames, [256, 256])
|
||||
target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
|
||||
else:
|
||||
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
|
||||
|
||||
target_occ = self.points_dataset[video_name]["occluded"]
|
||||
if self.queried_first:
|
||||
converted = sample_queries_first(target_occ, target_points, frames)
|
||||
else:
|
||||
converted = sample_queries_strided(target_occ, target_points, frames)
|
||||
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
|
||||
|
||||
trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
|
||||
|
||||
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
|
||||
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
|
||||
1, 0
|
||||
) # T, N
|
||||
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
|
||||
return CoTrackerData(
|
||||
rgbs,
|
||||
trajs,
|
||||
visibles,
|
||||
seq_name=str(video_name),
|
||||
query_points=query_points,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.points_dataset)
|
||||
|
@@ -1,106 +1,106 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
import dataclasses
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class CoTrackerData:
|
||||
"""
|
||||
Dataclass for storing video tracks data.
|
||||
"""
|
||||
|
||||
video: torch.Tensor # B, S, C, H, W
|
||||
trajectory: torch.Tensor # B, S, N, 2
|
||||
visibility: torch.Tensor # B, S, N
|
||||
# optional data
|
||||
valid: Optional[torch.Tensor] = None # B, S, N
|
||||
segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
|
||||
seq_name: Optional[str] = None
|
||||
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collate function for video tracks data.
|
||||
"""
|
||||
video = torch.stack([b.video for b in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b in batch], dim=0)
|
||||
query_points = segmentation = None
|
||||
if batch[0].query_points is not None:
|
||||
query_points = torch.stack([b.query_points for b in batch], dim=0)
|
||||
if batch[0].segmentation is not None:
|
||||
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
|
||||
seq_name = [b.seq_name for b in batch]
|
||||
|
||||
return CoTrackerData(
|
||||
video=video,
|
||||
trajectory=trajectory,
|
||||
visibility=visibility,
|
||||
segmentation=segmentation,
|
||||
seq_name=seq_name,
|
||||
query_points=query_points,
|
||||
)
|
||||
|
||||
|
||||
def collate_fn_train(batch):
|
||||
"""
|
||||
Collate function for video tracks data during training.
|
||||
"""
|
||||
gotit = [gotit for _, gotit in batch]
|
||||
video = torch.stack([b.video for b, _ in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
|
||||
valid = torch.stack([b.valid for b, _ in batch], dim=0)
|
||||
seq_name = [b.seq_name for b, _ in batch]
|
||||
return (
|
||||
CoTrackerData(
|
||||
video=video,
|
||||
trajectory=trajectory,
|
||||
visibility=visibility,
|
||||
valid=valid,
|
||||
seq_name=seq_name,
|
||||
),
|
||||
gotit,
|
||||
)
|
||||
|
||||
|
||||
def try_to_cuda(t: Any) -> Any:
|
||||
"""
|
||||
Try to move the input variable `t` to a cuda device.
|
||||
|
||||
Args:
|
||||
t: Input.
|
||||
|
||||
Returns:
|
||||
t_cuda: `t` moved to a cuda device, if supported.
|
||||
"""
|
||||
try:
|
||||
t = t.float().cuda()
|
||||
except AttributeError:
|
||||
pass
|
||||
return t
|
||||
|
||||
|
||||
def dataclass_to_cuda_(obj):
|
||||
"""
|
||||
Move all contents of a dataclass to cuda inplace if supported.
|
||||
|
||||
Args:
|
||||
batch: Input dataclass.
|
||||
|
||||
Returns:
|
||||
batch_cuda: `batch` moved to a cuda device, if supported.
|
||||
"""
|
||||
for f in dataclasses.fields(obj):
|
||||
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
|
||||
return obj
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
import dataclasses
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class CoTrackerData:
|
||||
"""
|
||||
Dataclass for storing video tracks data.
|
||||
"""
|
||||
|
||||
video: torch.Tensor # B, S, C, H, W
|
||||
trajectory: torch.Tensor # B, S, N, 2
|
||||
visibility: torch.Tensor # B, S, N
|
||||
# optional data
|
||||
valid: Optional[torch.Tensor] = None # B, S, N
|
||||
segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
|
||||
seq_name: Optional[str] = None
|
||||
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collate function for video tracks data.
|
||||
"""
|
||||
video = torch.stack([b.video for b in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b in batch], dim=0)
|
||||
query_points = segmentation = None
|
||||
if batch[0].query_points is not None:
|
||||
query_points = torch.stack([b.query_points for b in batch], dim=0)
|
||||
if batch[0].segmentation is not None:
|
||||
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
|
||||
seq_name = [b.seq_name for b in batch]
|
||||
|
||||
return CoTrackerData(
|
||||
video=video,
|
||||
trajectory=trajectory,
|
||||
visibility=visibility,
|
||||
segmentation=segmentation,
|
||||
seq_name=seq_name,
|
||||
query_points=query_points,
|
||||
)
|
||||
|
||||
|
||||
def collate_fn_train(batch):
|
||||
"""
|
||||
Collate function for video tracks data during training.
|
||||
"""
|
||||
gotit = [gotit for _, gotit in batch]
|
||||
video = torch.stack([b.video for b, _ in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
|
||||
valid = torch.stack([b.valid for b, _ in batch], dim=0)
|
||||
seq_name = [b.seq_name for b, _ in batch]
|
||||
return (
|
||||
CoTrackerData(
|
||||
video=video,
|
||||
trajectory=trajectory,
|
||||
visibility=visibility,
|
||||
valid=valid,
|
||||
seq_name=seq_name,
|
||||
),
|
||||
gotit,
|
||||
)
|
||||
|
||||
|
||||
def try_to_cuda(t: Any) -> Any:
|
||||
"""
|
||||
Try to move the input variable `t` to a cuda device.
|
||||
|
||||
Args:
|
||||
t: Input.
|
||||
|
||||
Returns:
|
||||
t_cuda: `t` moved to a cuda device, if supported.
|
||||
"""
|
||||
try:
|
||||
t = t.float().cuda()
|
||||
except AttributeError:
|
||||
pass
|
||||
return t
|
||||
|
||||
|
||||
def dataclass_to_cuda_(obj):
|
||||
"""
|
||||
Move all contents of a dataclass to cuda inplace if supported.
|
||||
|
||||
Args:
|
||||
batch: Input dataclass.
|
||||
|
||||
Returns:
|
||||
batch_cuda: `batch` moved to a cuda device, if supported.
|
||||
"""
|
||||
for f in dataclasses.fields(obj):
|
||||
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
|
||||
return obj
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
@@ -1,6 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: dynamic_replica
|
||||
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: dynamic_replica
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_davis_first
|
||||
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_davis_first
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_davis_strided
|
||||
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_davis_strided
|
||||
|
||||
|
@@ -1,6 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_kinetics_first
|
||||
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_kinetics_first
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
@@ -1,138 +1,138 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from typing import Iterable, Mapping, Tuple, Union
|
||||
|
||||
|
||||
def compute_tapvid_metrics(
|
||||
query_points: np.ndarray,
|
||||
gt_occluded: np.ndarray,
|
||||
gt_tracks: np.ndarray,
|
||||
pred_occluded: np.ndarray,
|
||||
pred_tracks: np.ndarray,
|
||||
query_mode: str,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
|
||||
See the TAP-Vid paper for details on the metric computation. All inputs are
|
||||
given in raster coordinates. The first three arguments should be the direct
|
||||
outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
|
||||
The paper metrics assume these are scaled relative to 256x256 images.
|
||||
pred_occluded and pred_tracks are your algorithm's predictions.
|
||||
This function takes a batch of inputs, and computes metrics separately for
|
||||
each video. The metrics for the full benchmark are a simple mean of the
|
||||
metrics across the full set of videos. These numbers are between 0 and 1,
|
||||
but the paper multiplies them by 100 to ease reading.
|
||||
Args:
|
||||
query_points: The query points, an in the format [t, y, x]. Its size is
|
||||
[b, n, 3], where b is the batch size and n is the number of queries
|
||||
gt_occluded: A boolean array of shape [b, n, t], where t is the number
|
||||
of frames. True indicates that the point is occluded.
|
||||
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
|
||||
in the format [x, y]
|
||||
pred_occluded: A boolean array of predicted occlusions, in the same
|
||||
format as gt_occluded.
|
||||
pred_tracks: An array of track predictions from your algorithm, in the
|
||||
same format as gt_tracks.
|
||||
query_mode: Either 'first' or 'strided', depending on how queries are
|
||||
sampled. If 'first', we assume the prior knowledge that all points
|
||||
before the query point are occluded, and these are removed from the
|
||||
evaluation.
|
||||
Returns:
|
||||
A dict with the following keys:
|
||||
occlusion_accuracy: Accuracy at predicting occlusion.
|
||||
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
|
||||
predicted to be within the given pixel threshold, ignoring occlusion
|
||||
prediction.
|
||||
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
|
||||
threshold
|
||||
average_pts_within_thresh: average across pts_within_{x}
|
||||
average_jaccard: average across jaccard_{x}
|
||||
"""
|
||||
|
||||
metrics = {}
|
||||
# Fixed bug is described in:
|
||||
# https://github.com/facebookresearch/co-tracker/issues/20
|
||||
eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
|
||||
|
||||
if query_mode == "first":
|
||||
# evaluate frames after the query frame
|
||||
query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
|
||||
elif query_mode == "strided":
|
||||
# evaluate all frames except the query frame
|
||||
query_frame_to_eval_frames = 1 - eye
|
||||
else:
|
||||
raise ValueError("Unknown query mode " + query_mode)
|
||||
|
||||
query_frame = query_points[..., 0]
|
||||
query_frame = np.round(query_frame).astype(np.int32)
|
||||
evaluation_points = query_frame_to_eval_frames[query_frame] > 0
|
||||
|
||||
# Occlusion accuracy is simply how often the predicted occlusion equals the
|
||||
# ground truth.
|
||||
occ_acc = np.sum(
|
||||
np.equal(pred_occluded, gt_occluded) & evaluation_points,
|
||||
axis=(1, 2),
|
||||
) / np.sum(evaluation_points)
|
||||
metrics["occlusion_accuracy"] = occ_acc
|
||||
|
||||
# Next, convert the predictions and ground truth positions into pixel
|
||||
# coordinates.
|
||||
visible = np.logical_not(gt_occluded)
|
||||
pred_visible = np.logical_not(pred_occluded)
|
||||
all_frac_within = []
|
||||
all_jaccard = []
|
||||
for thresh in [1, 2, 4, 8, 16]:
|
||||
# True positives are points that are within the threshold and where both
|
||||
# the prediction and the ground truth are listed as visible.
|
||||
within_dist = np.sum(
|
||||
np.square(pred_tracks - gt_tracks),
|
||||
axis=-1,
|
||||
) < np.square(thresh)
|
||||
is_correct = np.logical_and(within_dist, visible)
|
||||
|
||||
# Compute the frac_within_threshold, which is the fraction of points
|
||||
# within the threshold among points that are visible in the ground truth,
|
||||
# ignoring whether they're predicted to be visible.
|
||||
count_correct = np.sum(
|
||||
is_correct & evaluation_points,
|
||||
axis=(1, 2),
|
||||
)
|
||||
count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
|
||||
frac_correct = count_correct / count_visible_points
|
||||
metrics["pts_within_" + str(thresh)] = frac_correct
|
||||
all_frac_within.append(frac_correct)
|
||||
|
||||
true_positives = np.sum(
|
||||
is_correct & pred_visible & evaluation_points, axis=(1, 2)
|
||||
)
|
||||
|
||||
# The denominator of the jaccard metric is the true positives plus
|
||||
# false positives plus false negatives. However, note that true positives
|
||||
# plus false negatives is simply the number of points in the ground truth
|
||||
# which is easier to compute than trying to compute all three quantities.
|
||||
# Thus we just add the number of points in the ground truth to the number
|
||||
# of false positives.
|
||||
#
|
||||
# False positives are simply points that are predicted to be visible,
|
||||
# but the ground truth is not visible or too far from the prediction.
|
||||
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
|
||||
false_positives = (~visible) & pred_visible
|
||||
false_positives = false_positives | ((~within_dist) & pred_visible)
|
||||
false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
|
||||
jaccard = true_positives / (gt_positives + false_positives)
|
||||
metrics["jaccard_" + str(thresh)] = jaccard
|
||||
all_jaccard.append(jaccard)
|
||||
metrics["average_jaccard"] = np.mean(
|
||||
np.stack(all_jaccard, axis=1),
|
||||
axis=1,
|
||||
)
|
||||
metrics["average_pts_within_thresh"] = np.mean(
|
||||
np.stack(all_frac_within, axis=1),
|
||||
axis=1,
|
||||
)
|
||||
return metrics
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from typing import Iterable, Mapping, Tuple, Union
|
||||
|
||||
|
||||
def compute_tapvid_metrics(
|
||||
query_points: np.ndarray,
|
||||
gt_occluded: np.ndarray,
|
||||
gt_tracks: np.ndarray,
|
||||
pred_occluded: np.ndarray,
|
||||
pred_tracks: np.ndarray,
|
||||
query_mode: str,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
|
||||
See the TAP-Vid paper for details on the metric computation. All inputs are
|
||||
given in raster coordinates. The first three arguments should be the direct
|
||||
outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
|
||||
The paper metrics assume these are scaled relative to 256x256 images.
|
||||
pred_occluded and pred_tracks are your algorithm's predictions.
|
||||
This function takes a batch of inputs, and computes metrics separately for
|
||||
each video. The metrics for the full benchmark are a simple mean of the
|
||||
metrics across the full set of videos. These numbers are between 0 and 1,
|
||||
but the paper multiplies them by 100 to ease reading.
|
||||
Args:
|
||||
query_points: The query points, an in the format [t, y, x]. Its size is
|
||||
[b, n, 3], where b is the batch size and n is the number of queries
|
||||
gt_occluded: A boolean array of shape [b, n, t], where t is the number
|
||||
of frames. True indicates that the point is occluded.
|
||||
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
|
||||
in the format [x, y]
|
||||
pred_occluded: A boolean array of predicted occlusions, in the same
|
||||
format as gt_occluded.
|
||||
pred_tracks: An array of track predictions from your algorithm, in the
|
||||
same format as gt_tracks.
|
||||
query_mode: Either 'first' or 'strided', depending on how queries are
|
||||
sampled. If 'first', we assume the prior knowledge that all points
|
||||
before the query point are occluded, and these are removed from the
|
||||
evaluation.
|
||||
Returns:
|
||||
A dict with the following keys:
|
||||
occlusion_accuracy: Accuracy at predicting occlusion.
|
||||
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
|
||||
predicted to be within the given pixel threshold, ignoring occlusion
|
||||
prediction.
|
||||
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
|
||||
threshold
|
||||
average_pts_within_thresh: average across pts_within_{x}
|
||||
average_jaccard: average across jaccard_{x}
|
||||
"""
|
||||
|
||||
metrics = {}
|
||||
# Fixed bug is described in:
|
||||
# https://github.com/facebookresearch/co-tracker/issues/20
|
||||
eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
|
||||
|
||||
if query_mode == "first":
|
||||
# evaluate frames after the query frame
|
||||
query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
|
||||
elif query_mode == "strided":
|
||||
# evaluate all frames except the query frame
|
||||
query_frame_to_eval_frames = 1 - eye
|
||||
else:
|
||||
raise ValueError("Unknown query mode " + query_mode)
|
||||
|
||||
query_frame = query_points[..., 0]
|
||||
query_frame = np.round(query_frame).astype(np.int32)
|
||||
evaluation_points = query_frame_to_eval_frames[query_frame] > 0
|
||||
|
||||
# Occlusion accuracy is simply how often the predicted occlusion equals the
|
||||
# ground truth.
|
||||
occ_acc = np.sum(
|
||||
np.equal(pred_occluded, gt_occluded) & evaluation_points,
|
||||
axis=(1, 2),
|
||||
) / np.sum(evaluation_points)
|
||||
metrics["occlusion_accuracy"] = occ_acc
|
||||
|
||||
# Next, convert the predictions and ground truth positions into pixel
|
||||
# coordinates.
|
||||
visible = np.logical_not(gt_occluded)
|
||||
pred_visible = np.logical_not(pred_occluded)
|
||||
all_frac_within = []
|
||||
all_jaccard = []
|
||||
for thresh in [1, 2, 4, 8, 16]:
|
||||
# True positives are points that are within the threshold and where both
|
||||
# the prediction and the ground truth are listed as visible.
|
||||
within_dist = np.sum(
|
||||
np.square(pred_tracks - gt_tracks),
|
||||
axis=-1,
|
||||
) < np.square(thresh)
|
||||
is_correct = np.logical_and(within_dist, visible)
|
||||
|
||||
# Compute the frac_within_threshold, which is the fraction of points
|
||||
# within the threshold among points that are visible in the ground truth,
|
||||
# ignoring whether they're predicted to be visible.
|
||||
count_correct = np.sum(
|
||||
is_correct & evaluation_points,
|
||||
axis=(1, 2),
|
||||
)
|
||||
count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
|
||||
frac_correct = count_correct / count_visible_points
|
||||
metrics["pts_within_" + str(thresh)] = frac_correct
|
||||
all_frac_within.append(frac_correct)
|
||||
|
||||
true_positives = np.sum(
|
||||
is_correct & pred_visible & evaluation_points, axis=(1, 2)
|
||||
)
|
||||
|
||||
# The denominator of the jaccard metric is the true positives plus
|
||||
# false positives plus false negatives. However, note that true positives
|
||||
# plus false negatives is simply the number of points in the ground truth
|
||||
# which is easier to compute than trying to compute all three quantities.
|
||||
# Thus we just add the number of points in the ground truth to the number
|
||||
# of false positives.
|
||||
#
|
||||
# False positives are simply points that are predicted to be visible,
|
||||
# but the ground truth is not visible or too far from the prediction.
|
||||
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
|
||||
false_positives = (~visible) & pred_visible
|
||||
false_positives = false_positives | ((~within_dist) & pred_visible)
|
||||
false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
|
||||
jaccard = true_positives / (gt_positives + false_positives)
|
||||
metrics["jaccard_" + str(thresh)] = jaccard
|
||||
all_jaccard.append(jaccard)
|
||||
metrics["average_jaccard"] = np.mean(
|
||||
np.stack(all_jaccard, axis=1),
|
||||
axis=1,
|
||||
)
|
||||
metrics["average_pts_within_thresh"] = np.mean(
|
||||
np.stack(all_frac_within, axis=1),
|
||||
axis=1,
|
||||
)
|
||||
return metrics
|
||||
|
@@ -1,253 +1,253 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from cotracker.datasets.utils import dataclass_to_cuda_
|
||||
from cotracker.utils.visualizer import Visualizer
|
||||
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""
|
||||
A class defining the CoTracker evaluator.
|
||||
"""
|
||||
|
||||
def __init__(self, exp_dir) -> None:
|
||||
# Visualization
|
||||
self.exp_dir = exp_dir
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
|
||||
self.visualize_dir = os.path.join(exp_dir, "visualisations")
|
||||
|
||||
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
|
||||
if isinstance(pred_trajectory, tuple):
|
||||
pred_trajectory, pred_visibility = pred_trajectory
|
||||
else:
|
||||
pred_visibility = None
|
||||
if "tapvid" in dataset_name:
|
||||
B, T, N, D = sample.trajectory.shape
|
||||
traj = sample.trajectory.clone()
|
||||
thr = 0.9
|
||||
|
||||
if pred_visibility is None:
|
||||
logging.warning("visibility is NONE")
|
||||
pred_visibility = torch.zeros_like(sample.visibility)
|
||||
|
||||
if not pred_visibility.dtype == torch.bool:
|
||||
pred_visibility = pred_visibility > thr
|
||||
|
||||
query_points = sample.query_points.clone().cpu().numpy()
|
||||
|
||||
pred_visibility = pred_visibility[:, :, :N]
|
||||
pred_trajectory = pred_trajectory[:, :, :N]
|
||||
|
||||
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
|
||||
gt_occluded = (
|
||||
torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
|
||||
)
|
||||
|
||||
pred_occluded = (
|
||||
torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
|
||||
)
|
||||
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
|
||||
|
||||
out_metrics = compute_tapvid_metrics(
|
||||
query_points,
|
||||
gt_occluded,
|
||||
gt_tracks,
|
||||
pred_occluded,
|
||||
pred_tracks,
|
||||
query_mode="strided" if "strided" in dataset_name else "first",
|
||||
)
|
||||
|
||||
metrics[sample.seq_name[0]] = out_metrics
|
||||
for metric_name in out_metrics.keys():
|
||||
if "avg" not in metrics:
|
||||
metrics["avg"] = {}
|
||||
metrics["avg"][metric_name] = np.mean(
|
||||
[v[metric_name] for k, v in metrics.items() if k != "avg"]
|
||||
)
|
||||
|
||||
logging.info(f"Metrics: {out_metrics}")
|
||||
logging.info(f"avg: {metrics['avg']}")
|
||||
print("metrics", out_metrics)
|
||||
print("avg", metrics["avg"])
|
||||
elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
|
||||
*_, N, _ = sample.trajectory.shape
|
||||
B, T, N = sample.visibility.shape
|
||||
H, W = sample.video.shape[-2:]
|
||||
device = sample.video.device
|
||||
|
||||
out_metrics = {}
|
||||
|
||||
d_vis_sum = d_occ_sum = d_sum_all = 0.0
|
||||
thrs = [1, 2, 4, 8, 16]
|
||||
sx_ = (W - 1) / 255.0
|
||||
sy_ = (H - 1) / 255.0
|
||||
sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
|
||||
sc_pt = torch.from_numpy(sc_py).float().to(device)
|
||||
__, first_visible_inds = torch.max(sample.visibility, dim=1)
|
||||
|
||||
frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
|
||||
start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))
|
||||
|
||||
for thr in thrs:
|
||||
d_ = (
|
||||
torch.norm(
|
||||
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
|
||||
dim=-1,
|
||||
)
|
||||
< thr
|
||||
).float() # B,S-1,N
|
||||
d_occ = (
|
||||
reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
|
||||
* 100.0
|
||||
)
|
||||
d_occ_sum += d_occ
|
||||
out_metrics[f"accuracy_occ_{thr}"] = d_occ
|
||||
|
||||
d_vis = (
|
||||
reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
|
||||
)
|
||||
d_vis_sum += d_vis
|
||||
out_metrics[f"accuracy_vis_{thr}"] = d_vis
|
||||
|
||||
d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
|
||||
d_sum_all += d_all
|
||||
out_metrics[f"accuracy_{thr}"] = d_all
|
||||
|
||||
d_occ_avg = d_occ_sum / len(thrs)
|
||||
d_vis_avg = d_vis_sum / len(thrs)
|
||||
d_all_avg = d_sum_all / len(thrs)
|
||||
|
||||
sur_thr = 50
|
||||
dists = torch.norm(
|
||||
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
|
||||
dim=-1,
|
||||
) # B,S,N
|
||||
dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N
|
||||
survival = torch.cumprod(dist_ok, dim=1) # B,S,N
|
||||
out_metrics["survival"] = torch.mean(survival).item() * 100.0
|
||||
|
||||
out_metrics["accuracy_occ"] = d_occ_avg
|
||||
out_metrics["accuracy_vis"] = d_vis_avg
|
||||
out_metrics["accuracy"] = d_all_avg
|
||||
|
||||
metrics[sample.seq_name[0]] = out_metrics
|
||||
for metric_name in out_metrics.keys():
|
||||
if "avg" not in metrics:
|
||||
metrics["avg"] = {}
|
||||
metrics["avg"][metric_name] = float(
|
||||
np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
|
||||
)
|
||||
|
||||
logging.info(f"Metrics: {out_metrics}")
|
||||
logging.info(f"avg: {metrics['avg']}")
|
||||
print("metrics", out_metrics)
|
||||
print("avg", metrics["avg"])
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_sequence(
|
||||
self,
|
||||
model,
|
||||
test_dataloader: torch.utils.data.DataLoader,
|
||||
dataset_name: str,
|
||||
train_mode=False,
|
||||
visualize_every: int = 1,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
step: Optional[int] = 0,
|
||||
):
|
||||
metrics = {}
|
||||
|
||||
vis = Visualizer(
|
||||
save_dir=self.exp_dir,
|
||||
fps=7,
|
||||
)
|
||||
|
||||
for ind, sample in enumerate(tqdm(test_dataloader)):
|
||||
if isinstance(sample, tuple):
|
||||
sample, gotit = sample
|
||||
if not all(gotit):
|
||||
print("batch is None")
|
||||
continue
|
||||
if torch.cuda.is_available():
|
||||
dataclass_to_cuda_(sample)
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
if (
|
||||
not train_mode
|
||||
and hasattr(model, "sequence_len")
|
||||
and (sample.visibility[:, : model.sequence_len].sum() == 0)
|
||||
):
|
||||
print(f"skipping batch {ind}")
|
||||
continue
|
||||
|
||||
if "tapvid" in dataset_name:
|
||||
queries = sample.query_points.clone().float()
|
||||
|
||||
queries = torch.stack(
|
||||
[
|
||||
queries[:, :, 0],
|
||||
queries[:, :, 2],
|
||||
queries[:, :, 1],
|
||||
],
|
||||
dim=2,
|
||||
).to(device)
|
||||
else:
|
||||
queries = torch.cat(
|
||||
[
|
||||
torch.zeros_like(sample.trajectory[:, 0, :, :1]),
|
||||
sample.trajectory[:, 0],
|
||||
],
|
||||
dim=2,
|
||||
).to(device)
|
||||
|
||||
pred_tracks = model(sample.video, queries)
|
||||
if "strided" in dataset_name:
|
||||
inv_video = sample.video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
pred_trj, pred_vsb = pred_tracks
|
||||
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
|
||||
|
||||
inv_pred_trj = inv_pred_trj.flip(1)
|
||||
inv_pred_vsb = inv_pred_vsb.flip(1)
|
||||
|
||||
mask = pred_trj == 0
|
||||
|
||||
pred_trj[mask] = inv_pred_trj[mask]
|
||||
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
|
||||
|
||||
pred_tracks = pred_trj, pred_vsb
|
||||
|
||||
if dataset_name == "badja" or dataset_name == "fastcapture":
|
||||
seq_name = sample.seq_name[0]
|
||||
else:
|
||||
seq_name = str(ind)
|
||||
if ind % visualize_every == 0:
|
||||
vis.visualize(
|
||||
sample.video,
|
||||
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
|
||||
filename=dataset_name + "_" + seq_name,
|
||||
writer=writer,
|
||||
step=step,
|
||||
)
|
||||
|
||||
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
|
||||
return metrics
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from cotracker.datasets.utils import dataclass_to_cuda_
|
||||
from cotracker.utils.visualizer import Visualizer
|
||||
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""
|
||||
A class defining the CoTracker evaluator.
|
||||
"""
|
||||
|
||||
def __init__(self, exp_dir) -> None:
|
||||
# Visualization
|
||||
self.exp_dir = exp_dir
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
|
||||
self.visualize_dir = os.path.join(exp_dir, "visualisations")
|
||||
|
||||
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
|
||||
if isinstance(pred_trajectory, tuple):
|
||||
pred_trajectory, pred_visibility = pred_trajectory
|
||||
else:
|
||||
pred_visibility = None
|
||||
if "tapvid" in dataset_name:
|
||||
B, T, N, D = sample.trajectory.shape
|
||||
traj = sample.trajectory.clone()
|
||||
thr = 0.9
|
||||
|
||||
if pred_visibility is None:
|
||||
logging.warning("visibility is NONE")
|
||||
pred_visibility = torch.zeros_like(sample.visibility)
|
||||
|
||||
if not pred_visibility.dtype == torch.bool:
|
||||
pred_visibility = pred_visibility > thr
|
||||
|
||||
query_points = sample.query_points.clone().cpu().numpy()
|
||||
|
||||
pred_visibility = pred_visibility[:, :, :N]
|
||||
pred_trajectory = pred_trajectory[:, :, :N]
|
||||
|
||||
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
|
||||
gt_occluded = (
|
||||
torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
|
||||
)
|
||||
|
||||
pred_occluded = (
|
||||
torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
|
||||
)
|
||||
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
|
||||
|
||||
out_metrics = compute_tapvid_metrics(
|
||||
query_points,
|
||||
gt_occluded,
|
||||
gt_tracks,
|
||||
pred_occluded,
|
||||
pred_tracks,
|
||||
query_mode="strided" if "strided" in dataset_name else "first",
|
||||
)
|
||||
|
||||
metrics[sample.seq_name[0]] = out_metrics
|
||||
for metric_name in out_metrics.keys():
|
||||
if "avg" not in metrics:
|
||||
metrics["avg"] = {}
|
||||
metrics["avg"][metric_name] = np.mean(
|
||||
[v[metric_name] for k, v in metrics.items() if k != "avg"]
|
||||
)
|
||||
|
||||
logging.info(f"Metrics: {out_metrics}")
|
||||
logging.info(f"avg: {metrics['avg']}")
|
||||
print("metrics", out_metrics)
|
||||
print("avg", metrics["avg"])
|
||||
elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
|
||||
*_, N, _ = sample.trajectory.shape
|
||||
B, T, N = sample.visibility.shape
|
||||
H, W = sample.video.shape[-2:]
|
||||
device = sample.video.device
|
||||
|
||||
out_metrics = {}
|
||||
|
||||
d_vis_sum = d_occ_sum = d_sum_all = 0.0
|
||||
thrs = [1, 2, 4, 8, 16]
|
||||
sx_ = (W - 1) / 255.0
|
||||
sy_ = (H - 1) / 255.0
|
||||
sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
|
||||
sc_pt = torch.from_numpy(sc_py).float().to(device)
|
||||
__, first_visible_inds = torch.max(sample.visibility, dim=1)
|
||||
|
||||
frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
|
||||
start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))
|
||||
|
||||
for thr in thrs:
|
||||
d_ = (
|
||||
torch.norm(
|
||||
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
|
||||
dim=-1,
|
||||
)
|
||||
< thr
|
||||
).float() # B,S-1,N
|
||||
d_occ = (
|
||||
reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
|
||||
* 100.0
|
||||
)
|
||||
d_occ_sum += d_occ
|
||||
out_metrics[f"accuracy_occ_{thr}"] = d_occ
|
||||
|
||||
d_vis = (
|
||||
reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
|
||||
)
|
||||
d_vis_sum += d_vis
|
||||
out_metrics[f"accuracy_vis_{thr}"] = d_vis
|
||||
|
||||
d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
|
||||
d_sum_all += d_all
|
||||
out_metrics[f"accuracy_{thr}"] = d_all
|
||||
|
||||
d_occ_avg = d_occ_sum / len(thrs)
|
||||
d_vis_avg = d_vis_sum / len(thrs)
|
||||
d_all_avg = d_sum_all / len(thrs)
|
||||
|
||||
sur_thr = 50
|
||||
dists = torch.norm(
|
||||
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
|
||||
dim=-1,
|
||||
) # B,S,N
|
||||
dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N
|
||||
survival = torch.cumprod(dist_ok, dim=1) # B,S,N
|
||||
out_metrics["survival"] = torch.mean(survival).item() * 100.0
|
||||
|
||||
out_metrics["accuracy_occ"] = d_occ_avg
|
||||
out_metrics["accuracy_vis"] = d_vis_avg
|
||||
out_metrics["accuracy"] = d_all_avg
|
||||
|
||||
metrics[sample.seq_name[0]] = out_metrics
|
||||
for metric_name in out_metrics.keys():
|
||||
if "avg" not in metrics:
|
||||
metrics["avg"] = {}
|
||||
metrics["avg"][metric_name] = float(
|
||||
np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
|
||||
)
|
||||
|
||||
logging.info(f"Metrics: {out_metrics}")
|
||||
logging.info(f"avg: {metrics['avg']}")
|
||||
print("metrics", out_metrics)
|
||||
print("avg", metrics["avg"])
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_sequence(
|
||||
self,
|
||||
model,
|
||||
test_dataloader: torch.utils.data.DataLoader,
|
||||
dataset_name: str,
|
||||
train_mode=False,
|
||||
visualize_every: int = 1,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
step: Optional[int] = 0,
|
||||
):
|
||||
metrics = {}
|
||||
|
||||
vis = Visualizer(
|
||||
save_dir=self.exp_dir,
|
||||
fps=7,
|
||||
)
|
||||
|
||||
for ind, sample in enumerate(tqdm(test_dataloader)):
|
||||
if isinstance(sample, tuple):
|
||||
sample, gotit = sample
|
||||
if not all(gotit):
|
||||
print("batch is None")
|
||||
continue
|
||||
if torch.cuda.is_available():
|
||||
dataclass_to_cuda_(sample)
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
if (
|
||||
not train_mode
|
||||
and hasattr(model, "sequence_len")
|
||||
and (sample.visibility[:, : model.sequence_len].sum() == 0)
|
||||
):
|
||||
print(f"skipping batch {ind}")
|
||||
continue
|
||||
|
||||
if "tapvid" in dataset_name:
|
||||
queries = sample.query_points.clone().float()
|
||||
|
||||
queries = torch.stack(
|
||||
[
|
||||
queries[:, :, 0],
|
||||
queries[:, :, 2],
|
||||
queries[:, :, 1],
|
||||
],
|
||||
dim=2,
|
||||
).to(device)
|
||||
else:
|
||||
queries = torch.cat(
|
||||
[
|
||||
torch.zeros_like(sample.trajectory[:, 0, :, :1]),
|
||||
sample.trajectory[:, 0],
|
||||
],
|
||||
dim=2,
|
||||
).to(device)
|
||||
|
||||
pred_tracks = model(sample.video, queries)
|
||||
if "strided" in dataset_name:
|
||||
inv_video = sample.video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
pred_trj, pred_vsb = pred_tracks
|
||||
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
|
||||
|
||||
inv_pred_trj = inv_pred_trj.flip(1)
|
||||
inv_pred_vsb = inv_pred_vsb.flip(1)
|
||||
|
||||
mask = pred_trj == 0
|
||||
|
||||
pred_trj[mask] = inv_pred_trj[mask]
|
||||
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
|
||||
|
||||
pred_tracks = pred_trj, pred_vsb
|
||||
|
||||
if dataset_name == "badja" or dataset_name == "fastcapture":
|
||||
seq_name = sample.seq_name[0]
|
||||
else:
|
||||
seq_name = str(ind)
|
||||
if ind % visualize_every == 0:
|
||||
vis.visualize(
|
||||
sample.video,
|
||||
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
|
||||
filename=dataset_name + "_" + seq_name,
|
||||
writer=writer,
|
||||
step=step,
|
||||
)
|
||||
|
||||
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
|
||||
return metrics
|
||||
|
@@ -1,169 +1,169 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
||||
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
|
||||
from cotracker.datasets.utils import collate_fn
|
||||
|
||||
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
||||
|
||||
from cotracker.evaluation.core.evaluator import Evaluator
|
||||
from cotracker.models.build_cotracker import (
|
||||
build_cotracker,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class DefaultConfig:
|
||||
# Directory where all outputs of the experiment will be saved.
|
||||
exp_dir: str = "./outputs"
|
||||
|
||||
# Name of the dataset to be used for the evaluation.
|
||||
dataset_name: str = "tapvid_davis_first"
|
||||
# The root directory of the dataset.
|
||||
dataset_root: str = "./"
|
||||
|
||||
# Path to the pre-trained model checkpoint to be used for the evaluation.
|
||||
# The default value is the path to a specific CoTracker model checkpoint.
|
||||
checkpoint: str = "./checkpoints/cotracker2.pth"
|
||||
|
||||
# EvaluationPredictor parameters
|
||||
# The size (N) of the support grid used in the predictor.
|
||||
# The total number of points is (N*N).
|
||||
grid_size: int = 5
|
||||
# The size (N) of the local support grid.
|
||||
local_grid_size: int = 8
|
||||
# A flag indicating whether to evaluate one ground truth point at a time.
|
||||
single_point: bool = True
|
||||
# The number of iterative updates for each sliding window.
|
||||
n_iters: int = 6
|
||||
|
||||
seed: int = 0
|
||||
gpu_idx: int = 0
|
||||
|
||||
# Override hydra's working directory to current working dir,
|
||||
# also disable storing the .hydra logs:
|
||||
hydra: dict = field(
|
||||
default_factory=lambda: {
|
||||
"run": {"dir": "."},
|
||||
"output_subdir": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def run_eval(cfg: DefaultConfig):
|
||||
"""
|
||||
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
|
||||
|
||||
Args:
|
||||
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
|
||||
- exp_dir (str): The directory path for the experiment.
|
||||
- dataset_name (str): The name of the dataset to be used.
|
||||
- dataset_root (str): The root directory of the dataset.
|
||||
- checkpoint (str): The path to the CoTracker model's checkpoint.
|
||||
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
|
||||
- n_iters (int): The number of iterative updates for each sliding window.
|
||||
- seed (int): The seed for setting the random state for reproducibility.
|
||||
- gpu_idx (int): The index of the GPU to be used.
|
||||
"""
|
||||
# Creating the experiment directory if it doesn't exist
|
||||
os.makedirs(cfg.exp_dir, exist_ok=True)
|
||||
|
||||
# Saving the experiment configuration to a .yaml file in the experiment directory
|
||||
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
||||
with open(cfg_file, "w") as f:
|
||||
OmegaConf.save(config=cfg, f=f)
|
||||
|
||||
evaluator = Evaluator(cfg.exp_dir)
|
||||
cotracker_model = build_cotracker(cfg.checkpoint)
|
||||
|
||||
# Creating the EvaluationPredictor object
|
||||
predictor = EvaluationPredictor(
|
||||
cotracker_model,
|
||||
grid_size=cfg.grid_size,
|
||||
local_grid_size=cfg.local_grid_size,
|
||||
single_point=cfg.single_point,
|
||||
n_iters=cfg.n_iters,
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
predictor.model = predictor.model.cuda()
|
||||
|
||||
# Setting the random seeds
|
||||
torch.manual_seed(cfg.seed)
|
||||
np.random.seed(cfg.seed)
|
||||
|
||||
# Constructing the specified dataset
|
||||
curr_collate_fn = collate_fn
|
||||
if "tapvid" in cfg.dataset_name:
|
||||
dataset_type = cfg.dataset_name.split("_")[1]
|
||||
if dataset_type == "davis":
|
||||
data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
|
||||
elif dataset_type == "kinetics":
|
||||
data_root = os.path.join(
|
||||
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
|
||||
)
|
||||
test_dataset = TapVidDataset(
|
||||
dataset_type=dataset_type,
|
||||
data_root=data_root,
|
||||
queried_first=not "strided" in cfg.dataset_name,
|
||||
)
|
||||
elif cfg.dataset_name == "dynamic_replica":
|
||||
test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
|
||||
|
||||
# Creating the DataLoader object
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=14,
|
||||
collate_fn=curr_collate_fn,
|
||||
)
|
||||
|
||||
# Timing and conducting the evaluation
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
evaluate_result = evaluator.evaluate_sequence(
|
||||
predictor,
|
||||
test_dataloader,
|
||||
dataset_name=cfg.dataset_name,
|
||||
)
|
||||
end = time.time()
|
||||
print(end - start)
|
||||
|
||||
# Saving the evaluation results to a .json file
|
||||
evaluate_result = evaluate_result["avg"]
|
||||
print("evaluate_result", evaluate_result)
|
||||
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
|
||||
evaluate_result["time"] = end - start
|
||||
print(f"Dumping eval results to {result_file}.")
|
||||
with open(result_file, "w") as f:
|
||||
json.dump(evaluate_result, f)
|
||||
|
||||
|
||||
cs = hydra.core.config_store.ConfigStore.instance()
|
||||
cs.store(name="default_config_eval", node=DefaultConfig)
|
||||
|
||||
|
||||
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
||||
def evaluate(cfg: DefaultConfig) -> None:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
||||
run_eval(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluate()
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
||||
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
|
||||
from cotracker.datasets.utils import collate_fn
|
||||
|
||||
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
||||
|
||||
from cotracker.evaluation.core.evaluator import Evaluator
|
||||
from cotracker.models.build_cotracker import (
|
||||
build_cotracker,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class DefaultConfig:
|
||||
# Directory where all outputs of the experiment will be saved.
|
||||
exp_dir: str = "./outputs"
|
||||
|
||||
# Name of the dataset to be used for the evaluation.
|
||||
dataset_name: str = "tapvid_davis_first"
|
||||
# The root directory of the dataset.
|
||||
dataset_root: str = "./"
|
||||
|
||||
# Path to the pre-trained model checkpoint to be used for the evaluation.
|
||||
# The default value is the path to a specific CoTracker model checkpoint.
|
||||
checkpoint: str = "./checkpoints/cotracker2.pth"
|
||||
|
||||
# EvaluationPredictor parameters
|
||||
# The size (N) of the support grid used in the predictor.
|
||||
# The total number of points is (N*N).
|
||||
grid_size: int = 5
|
||||
# The size (N) of the local support grid.
|
||||
local_grid_size: int = 8
|
||||
# A flag indicating whether to evaluate one ground truth point at a time.
|
||||
single_point: bool = True
|
||||
# The number of iterative updates for each sliding window.
|
||||
n_iters: int = 6
|
||||
|
||||
seed: int = 0
|
||||
gpu_idx: int = 0
|
||||
|
||||
# Override hydra's working directory to current working dir,
|
||||
# also disable storing the .hydra logs:
|
||||
hydra: dict = field(
|
||||
default_factory=lambda: {
|
||||
"run": {"dir": "."},
|
||||
"output_subdir": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def run_eval(cfg: DefaultConfig):
|
||||
"""
|
||||
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
|
||||
|
||||
Args:
|
||||
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
|
||||
- exp_dir (str): The directory path for the experiment.
|
||||
- dataset_name (str): The name of the dataset to be used.
|
||||
- dataset_root (str): The root directory of the dataset.
|
||||
- checkpoint (str): The path to the CoTracker model's checkpoint.
|
||||
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
|
||||
- n_iters (int): The number of iterative updates for each sliding window.
|
||||
- seed (int): The seed for setting the random state for reproducibility.
|
||||
- gpu_idx (int): The index of the GPU to be used.
|
||||
"""
|
||||
# Creating the experiment directory if it doesn't exist
|
||||
os.makedirs(cfg.exp_dir, exist_ok=True)
|
||||
|
||||
# Saving the experiment configuration to a .yaml file in the experiment directory
|
||||
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
||||
with open(cfg_file, "w") as f:
|
||||
OmegaConf.save(config=cfg, f=f)
|
||||
|
||||
evaluator = Evaluator(cfg.exp_dir)
|
||||
cotracker_model = build_cotracker(cfg.checkpoint)
|
||||
|
||||
# Creating the EvaluationPredictor object
|
||||
predictor = EvaluationPredictor(
|
||||
cotracker_model,
|
||||
grid_size=cfg.grid_size,
|
||||
local_grid_size=cfg.local_grid_size,
|
||||
single_point=cfg.single_point,
|
||||
n_iters=cfg.n_iters,
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
predictor.model = predictor.model.cuda()
|
||||
|
||||
# Setting the random seeds
|
||||
torch.manual_seed(cfg.seed)
|
||||
np.random.seed(cfg.seed)
|
||||
|
||||
# Constructing the specified dataset
|
||||
curr_collate_fn = collate_fn
|
||||
if "tapvid" in cfg.dataset_name:
|
||||
dataset_type = cfg.dataset_name.split("_")[1]
|
||||
if dataset_type == "davis":
|
||||
data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
|
||||
elif dataset_type == "kinetics":
|
||||
data_root = os.path.join(
|
||||
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
|
||||
)
|
||||
test_dataset = TapVidDataset(
|
||||
dataset_type=dataset_type,
|
||||
data_root=data_root,
|
||||
queried_first=not "strided" in cfg.dataset_name,
|
||||
)
|
||||
elif cfg.dataset_name == "dynamic_replica":
|
||||
test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
|
||||
|
||||
# Creating the DataLoader object
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=14,
|
||||
collate_fn=curr_collate_fn,
|
||||
)
|
||||
|
||||
# Timing and conducting the evaluation
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
evaluate_result = evaluator.evaluate_sequence(
|
||||
predictor,
|
||||
test_dataloader,
|
||||
dataset_name=cfg.dataset_name,
|
||||
)
|
||||
end = time.time()
|
||||
print(end - start)
|
||||
|
||||
# Saving the evaluation results to a .json file
|
||||
evaluate_result = evaluate_result["avg"]
|
||||
print("evaluate_result", evaluate_result)
|
||||
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
|
||||
evaluate_result["time"] = end - start
|
||||
print(f"Dumping eval results to {result_file}.")
|
||||
with open(result_file, "w") as f:
|
||||
json.dump(evaluate_result, f)
|
||||
|
||||
|
||||
cs = hydra.core.config_store.ConfigStore.instance()
|
||||
cs.store(name="default_config_eval", node=DefaultConfig)
|
||||
|
||||
|
||||
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
||||
def evaluate(cfg: DefaultConfig) -> None:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
||||
run_eval(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluate()
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
BIN
cotracker/models/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/models/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/models/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/build_cotracker.cpython-38.pyc
Normal file
BIN
cotracker/models/__pycache__/build_cotracker.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/__pycache__/build_cotracker.cpython-39.pyc
Normal file
BIN
cotracker/models/__pycache__/build_cotracker.cpython-39.pyc
Normal file
Binary file not shown.
@@ -1,33 +1,33 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
|
||||
|
||||
def build_cotracker(
|
||||
checkpoint: str,
|
||||
):
|
||||
if checkpoint is None:
|
||||
return build_cotracker()
|
||||
model_name = checkpoint.split("/")[-1].split(".")[0]
|
||||
if model_name == "cotracker":
|
||||
return build_cotracker(checkpoint=checkpoint)
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {model_name}")
|
||||
|
||||
|
||||
def build_cotracker(checkpoint=None):
|
||||
cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
|
||||
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")
|
||||
if "model" in state_dict:
|
||||
state_dict = state_dict["model"]
|
||||
cotracker.load_state_dict(state_dict)
|
||||
return cotracker
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
|
||||
|
||||
def build_cotracker(
|
||||
checkpoint: str,
|
||||
):
|
||||
if checkpoint is None:
|
||||
return build_cotracker()
|
||||
model_name = checkpoint.split("/")[-1].split(".")[0]
|
||||
if model_name == "cotracker":
|
||||
return build_cotracker(checkpoint=checkpoint)
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {model_name}")
|
||||
|
||||
|
||||
def build_cotracker(checkpoint=None):
|
||||
cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
|
||||
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")
|
||||
if "model" in state_dict:
|
||||
state_dict = state_dict["model"]
|
||||
cotracker.load_state_dict(state_dict)
|
||||
return cotracker
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
BIN
cotracker/models/core/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/models/core/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/models/core/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/embeddings.cpython-38.pyc
Normal file
BIN
cotracker/models/core/__pycache__/embeddings.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/embeddings.cpython-39.pyc
Normal file
BIN
cotracker/models/core/__pycache__/embeddings.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/model_utils.cpython-38.pyc
Normal file
BIN
cotracker/models/core/__pycache__/model_utils.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/models/core/__pycache__/model_utils.cpython-39.pyc
Normal file
BIN
cotracker/models/core/__pycache__/model_utils.cpython-39.pyc
Normal file
Binary file not shown.
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,367 +1,368 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
import collections
|
||||
from torch import Tensor
|
||||
from itertools import repeat
|
||||
|
||||
from cotracker.models.core.model_utils import bilinear_sampler
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.0,
|
||||
use_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
bias = to_2tuple(bias)
|
||||
drop_probs = to_2tuple(drop)
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||
|
||||
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=stride,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
num_groups = planes // 8
|
||||
|
||||
if norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
|
||||
elif norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2d(planes)
|
||||
self.norm2 = nn.BatchNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.BatchNorm2d(planes)
|
||||
|
||||
elif norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2d(planes)
|
||||
self.norm2 = nn.InstanceNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.InstanceNorm2d(planes)
|
||||
|
||||
elif norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm2 = nn.Sequential()
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.Sequential()
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.norm1(self.conv1(y)))
|
||||
y = self.relu(self.norm2(self.conv2(y)))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
|
||||
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
||||
super(BasicEncoder, self).__init__()
|
||||
self.stride = stride
|
||||
self.norm_fn = "instance"
|
||||
self.in_planes = output_dim // 2
|
||||
|
||||
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
||||
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
input_dim,
|
||||
self.in_planes,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
||||
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
||||
self.layer3 = self._make_layer(output_dim, stride=2)
|
||||
self.layer4 = self._make_layer(output_dim, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
output_dim * 3 + output_dim // 4,
|
||||
output_dim * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, (nn.InstanceNorm2d)):
|
||||
if m.weight is not None:
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, dim, stride=1):
|
||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
d = self.layer4(c)
|
||||
|
||||
def _bilinear_intepolate(x):
|
||||
return F.interpolate(
|
||||
x,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
a = _bilinear_intepolate(a)
|
||||
b = _bilinear_intepolate(b)
|
||||
c = _bilinear_intepolate(c)
|
||||
d = _bilinear_intepolate(d)
|
||||
|
||||
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
||||
x = self.norm2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.conv3(x)
|
||||
return x
|
||||
|
||||
|
||||
class CorrBlock:
|
||||
def __init__(
|
||||
self,
|
||||
fmaps,
|
||||
num_levels=4,
|
||||
radius=4,
|
||||
multiple_track_feats=False,
|
||||
padding_mode="zeros",
|
||||
):
|
||||
B, S, C, H, W = fmaps.shape
|
||||
self.S, self.C, self.H, self.W = S, C, H, W
|
||||
self.padding_mode = padding_mode
|
||||
self.num_levels = num_levels
|
||||
self.radius = radius
|
||||
self.fmaps_pyramid = []
|
||||
self.multiple_track_feats = multiple_track_feats
|
||||
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
for i in range(self.num_levels - 1):
|
||||
fmaps_ = fmaps.reshape(B * S, C, H, W)
|
||||
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
||||
_, _, H, W = fmaps_.shape
|
||||
fmaps = fmaps_.reshape(B, S, C, H, W)
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
|
||||
def sample(self, coords):
|
||||
r = self.radius
|
||||
B, S, N, D = coords.shape
|
||||
assert D == 2
|
||||
|
||||
H, W = self.H, self.W
|
||||
out_pyramid = []
|
||||
for i in range(self.num_levels):
|
||||
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
||||
*_, H, W = corrs.shape
|
||||
|
||||
dx = torch.linspace(-r, r, 2 * r + 1)
|
||||
dy = torch.linspace(-r, r, 2 * r + 1)
|
||||
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
|
||||
|
||||
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
|
||||
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
||||
coords_lvl = centroid_lvl + delta_lvl
|
||||
|
||||
corrs = bilinear_sampler(
|
||||
corrs.reshape(B * S * N, 1, H, W),
|
||||
coords_lvl,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
corrs = corrs.view(B, S, N, -1)
|
||||
out_pyramid.append(corrs)
|
||||
|
||||
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
||||
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
|
||||
return out
|
||||
|
||||
def corr(self, targets):
|
||||
B, S, N, C = targets.shape
|
||||
if self.multiple_track_feats:
|
||||
targets_split = targets.split(C // self.num_levels, dim=-1)
|
||||
B, S, N, C = targets_split[0].shape
|
||||
|
||||
assert C == self.C
|
||||
assert S == self.S
|
||||
|
||||
fmap1 = targets
|
||||
|
||||
self.corrs_pyramid = []
|
||||
for i, fmaps in enumerate(self.fmaps_pyramid):
|
||||
*_, H, W = fmaps.shape
|
||||
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
|
||||
if self.multiple_track_feats:
|
||||
fmap1 = targets_split[i]
|
||||
corrs = torch.matmul(fmap1, fmap2s)
|
||||
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
|
||||
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
||||
self.corrs_pyramid.append(corrs)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * num_heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = num_heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
||||
self.to_out = nn.Linear(inner_dim, query_dim)
|
||||
|
||||
def forward(self, x, context=None, attn_bias=None):
|
||||
B, N1, C = x.shape
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
|
||||
context = default(context, x)
|
||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||
|
||||
N2 = context.shape[1]
|
||||
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
||||
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
||||
|
||||
sim = (q @ k.transpose(-2, -1)) * self.scale
|
||||
|
||||
if attn_bias is not None:
|
||||
sim = sim + attn_bias
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
|
||||
return self.to_out(x)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
attn_class: Callable[..., nn.Module] = Attention,
|
||||
mlp_ratio=4.0,
|
||||
**block_kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=approx_gelu,
|
||||
drop=0,
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
attn_bias = mask
|
||||
if mask is not None:
|
||||
mask = (
|
||||
(mask[:, None] * mask[:, :, None])
|
||||
.unsqueeze(1)
|
||||
.expand(-1, self.attn.num_heads, -1, -1)
|
||||
)
|
||||
max_neg_value = -torch.finfo(x.dtype).max
|
||||
attn_bias = (~mask) * max_neg_value
|
||||
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
import collections
|
||||
from torch import Tensor
|
||||
from itertools import repeat
|
||||
|
||||
from cotracker.models.core.model_utils import bilinear_sampler
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
||||
return tuple(x)
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.0,
|
||||
use_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
bias = to_2tuple(bias)
|
||||
drop_probs = to_2tuple(drop)
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||
|
||||
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=stride,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
num_groups = planes // 8
|
||||
|
||||
if norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
|
||||
elif norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2d(planes)
|
||||
self.norm2 = nn.BatchNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.BatchNorm2d(planes)
|
||||
|
||||
elif norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2d(planes)
|
||||
self.norm2 = nn.InstanceNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.InstanceNorm2d(planes)
|
||||
|
||||
elif norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm2 = nn.Sequential()
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.Sequential()
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.norm1(self.conv1(y)))
|
||||
y = self.relu(self.norm2(self.conv2(y)))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
|
||||
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
||||
super(BasicEncoder, self).__init__()
|
||||
self.stride = stride
|
||||
self.norm_fn = "instance"
|
||||
self.in_planes = output_dim // 2
|
||||
|
||||
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
||||
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
input_dim,
|
||||
self.in_planes,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
||||
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
||||
self.layer3 = self._make_layer(output_dim, stride=2)
|
||||
self.layer4 = self._make_layer(output_dim, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
output_dim * 3 + output_dim // 4,
|
||||
output_dim * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, (nn.InstanceNorm2d)):
|
||||
if m.weight is not None:
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, dim, stride=1):
|
||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
# 四层残差块
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
d = self.layer4(c)
|
||||
|
||||
def _bilinear_intepolate(x):
|
||||
return F.interpolate(
|
||||
x,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
a = _bilinear_intepolate(a)
|
||||
b = _bilinear_intepolate(b)
|
||||
c = _bilinear_intepolate(c)
|
||||
d = _bilinear_intepolate(d)
|
||||
|
||||
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
||||
x = self.norm2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.conv3(x)
|
||||
return x
|
||||
|
||||
|
||||
class CorrBlock:
|
||||
def __init__(
|
||||
self,
|
||||
fmaps,
|
||||
num_levels=4,
|
||||
radius=4,
|
||||
multiple_track_feats=False,
|
||||
padding_mode="zeros",
|
||||
):
|
||||
B, S, C, H, W = fmaps.shape
|
||||
self.S, self.C, self.H, self.W = S, C, H, W
|
||||
self.padding_mode = padding_mode
|
||||
self.num_levels = num_levels
|
||||
self.radius = radius
|
||||
self.fmaps_pyramid = []
|
||||
self.multiple_track_feats = multiple_track_feats
|
||||
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
for i in range(self.num_levels - 1):
|
||||
fmaps_ = fmaps.reshape(B * S, C, H, W)
|
||||
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
||||
_, _, H, W = fmaps_.shape
|
||||
fmaps = fmaps_.reshape(B, S, C, H, W)
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
|
||||
def sample(self, coords):
|
||||
r = self.radius
|
||||
B, S, N, D = coords.shape
|
||||
assert D == 2
|
||||
|
||||
H, W = self.H, self.W
|
||||
out_pyramid = []
|
||||
for i in range(self.num_levels):
|
||||
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
||||
*_, H, W = corrs.shape
|
||||
|
||||
dx = torch.linspace(-r, r, 2 * r + 1)
|
||||
dy = torch.linspace(-r, r, 2 * r + 1)
|
||||
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
|
||||
|
||||
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
|
||||
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
||||
coords_lvl = centroid_lvl + delta_lvl
|
||||
|
||||
corrs = bilinear_sampler(
|
||||
corrs.reshape(B * S * N, 1, H, W),
|
||||
coords_lvl,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
corrs = corrs.view(B, S, N, -1)
|
||||
out_pyramid.append(corrs)
|
||||
|
||||
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
||||
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
|
||||
return out
|
||||
|
||||
def corr(self, targets):
|
||||
B, S, N, C = targets.shape
|
||||
if self.multiple_track_feats:
|
||||
targets_split = targets.split(C // self.num_levels, dim=-1)
|
||||
B, S, N, C = targets_split[0].shape
|
||||
|
||||
assert C == self.C
|
||||
assert S == self.S
|
||||
|
||||
fmap1 = targets
|
||||
|
||||
self.corrs_pyramid = []
|
||||
for i, fmaps in enumerate(self.fmaps_pyramid):
|
||||
*_, H, W = fmaps.shape
|
||||
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
|
||||
if self.multiple_track_feats:
|
||||
fmap1 = targets_split[i]
|
||||
corrs = torch.matmul(fmap1, fmap2s)
|
||||
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
|
||||
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
||||
self.corrs_pyramid.append(corrs)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * num_heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = num_heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
||||
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
||||
self.to_out = nn.Linear(inner_dim, query_dim)
|
||||
|
||||
def forward(self, x, context=None, attn_bias=None):
|
||||
B, N1, C = x.shape
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
|
||||
context = default(context, x)
|
||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||
|
||||
N2 = context.shape[1]
|
||||
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
||||
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
||||
|
||||
sim = (q @ k.transpose(-2, -1)) * self.scale
|
||||
|
||||
if attn_bias is not None:
|
||||
sim = sim + attn_bias
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
|
||||
return self.to_out(x)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
attn_class: Callable[..., nn.Module] = Attention,
|
||||
mlp_ratio=4.0,
|
||||
**block_kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=approx_gelu,
|
||||
drop=0,
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
attn_bias = mask
|
||||
if mask is not None:
|
||||
mask = (
|
||||
(mask[:, None] * mask[:, :, None])
|
||||
.unsqueeze(1)
|
||||
.expand(-1, self.attn.num_heads, -1, -1)
|
||||
)
|
||||
max_neg_value = -torch.finfo(x.dtype).max
|
||||
attn_bias = (~mask) * max_neg_value
|
||||
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -1,61 +1,61 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def balanced_ce_loss(pred, gt, valid=None):
|
||||
total_balanced_loss = 0.0
|
||||
for j in range(len(gt)):
|
||||
B, S, N = gt[j].shape
|
||||
# pred and gt are the same shape
|
||||
for (a, b) in zip(pred[j].size(), gt[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
# if valid is not None:
|
||||
for (a, b) in zip(pred[j].size(), valid[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
|
||||
pos = (gt[j] > 0.95).float()
|
||||
neg = (gt[j] < 0.05).float()
|
||||
|
||||
label = pos * 2.0 - 1.0
|
||||
a = -label * pred[j]
|
||||
b = F.relu(a)
|
||||
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
|
||||
|
||||
pos_loss = reduce_masked_mean(loss, pos * valid[j])
|
||||
neg_loss = reduce_masked_mean(loss, neg * valid[j])
|
||||
|
||||
balanced_loss = pos_loss + neg_loss
|
||||
total_balanced_loss += balanced_loss / float(N)
|
||||
return total_balanced_loss
|
||||
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
|
||||
"""Loss function defined over sequence of flow predictions"""
|
||||
total_flow_loss = 0.0
|
||||
for j in range(len(flow_gt)):
|
||||
B, S, N, D = flow_gt[j].shape
|
||||
assert D == 2
|
||||
B, S1, N = vis[j].shape
|
||||
B, S2, N = valids[j].shape
|
||||
assert S == S1
|
||||
assert S == S2
|
||||
n_predictions = len(flow_preds[j])
|
||||
flow_loss = 0.0
|
||||
for i in range(n_predictions):
|
||||
i_weight = gamma ** (n_predictions - i - 1)
|
||||
flow_pred = flow_preds[j][i]
|
||||
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
|
||||
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
||||
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
|
||||
flow_loss = flow_loss / n_predictions
|
||||
total_flow_loss += flow_loss / float(N)
|
||||
return total_flow_loss
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def balanced_ce_loss(pred, gt, valid=None):
|
||||
total_balanced_loss = 0.0
|
||||
for j in range(len(gt)):
|
||||
B, S, N = gt[j].shape
|
||||
# pred and gt are the same shape
|
||||
for (a, b) in zip(pred[j].size(), gt[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
# if valid is not None:
|
||||
for (a, b) in zip(pred[j].size(), valid[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
|
||||
pos = (gt[j] > 0.95).float()
|
||||
neg = (gt[j] < 0.05).float()
|
||||
|
||||
label = pos * 2.0 - 1.0
|
||||
a = -label * pred[j]
|
||||
b = F.relu(a)
|
||||
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
|
||||
|
||||
pos_loss = reduce_masked_mean(loss, pos * valid[j])
|
||||
neg_loss = reduce_masked_mean(loss, neg * valid[j])
|
||||
|
||||
balanced_loss = pos_loss + neg_loss
|
||||
total_balanced_loss += balanced_loss / float(N)
|
||||
return total_balanced_loss
|
||||
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
|
||||
"""Loss function defined over sequence of flow predictions"""
|
||||
total_flow_loss = 0.0
|
||||
for j in range(len(flow_gt)):
|
||||
B, S, N, D = flow_gt[j].shape
|
||||
assert D == 2
|
||||
B, S1, N = vis[j].shape
|
||||
B, S2, N = valids[j].shape
|
||||
assert S == S1
|
||||
assert S == S2
|
||||
n_predictions = len(flow_preds[j])
|
||||
flow_loss = 0.0
|
||||
for i in range(n_predictions):
|
||||
i_weight = gamma ** (n_predictions - i - 1)
|
||||
flow_pred = flow_preds[j][i]
|
||||
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
|
||||
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
||||
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
|
||||
flow_loss = flow_loss / n_predictions
|
||||
total_flow_loss += flow_loss / float(N)
|
||||
return total_flow_loss
|
||||
|
@@ -1,120 +1,120 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple, Union
|
||||
import torch
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
||||
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- grid_size: The grid size.
|
||||
Returns:
|
||||
- pos_embed: The generated 2D positional embedding.
|
||||
"""
|
||||
if isinstance(grid_size, tuple):
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
else:
|
||||
grid_size_h = grid_size_w = grid_size
|
||||
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
||||
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
||||
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, grid: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- grid: The grid to generate the embedding from.
|
||||
|
||||
Returns:
|
||||
- emb: The generated 2D positional embedding.
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, pos: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- pos: The position to generate the embedding from.
|
||||
|
||||
Returns:
|
||||
- emb: The generated 1D positional embedding.
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = torch.sin(out) # (M, D/2)
|
||||
emb_cos = torch.cos(out) # (M, D/2)
|
||||
|
||||
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||
return emb[None].float()
|
||||
|
||||
|
||||
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- xy: The coordinates to generate the embedding from.
|
||||
- C: The size of the embedding.
|
||||
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
||||
|
||||
Returns:
|
||||
- pe: The generated 2D positional embedding.
|
||||
"""
|
||||
B, N, D = xy.shape
|
||||
assert D == 2
|
||||
|
||||
x = xy[:, :, 0:1]
|
||||
y = xy[:, :, 1:2]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
||||
if cat_coords:
|
||||
pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
|
||||
return pe
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple, Union
|
||||
import torch
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
||||
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- grid_size: The grid size.
|
||||
Returns:
|
||||
- pos_embed: The generated 2D positional embedding.
|
||||
"""
|
||||
if isinstance(grid_size, tuple):
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
else:
|
||||
grid_size_h = grid_size_w = grid_size
|
||||
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
||||
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
||||
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, grid: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- grid: The grid to generate the embedding from.
|
||||
|
||||
Returns:
|
||||
- emb: The generated 2D positional embedding.
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(
|
||||
embed_dim: int, pos: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- embed_dim: The embedding dimension.
|
||||
- pos: The position to generate the embedding from.
|
||||
|
||||
Returns:
|
||||
- emb: The generated 1D positional embedding.
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = torch.sin(out) # (M, D/2)
|
||||
emb_cos = torch.cos(out) # (M, D/2)
|
||||
|
||||
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||
return emb[None].float()
|
||||
|
||||
|
||||
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
||||
|
||||
Args:
|
||||
- xy: The coordinates to generate the embedding from.
|
||||
- C: The size of the embedding.
|
||||
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
||||
|
||||
Returns:
|
||||
- pe: The generated 2D positional embedding.
|
||||
"""
|
||||
B, N, D = xy.shape
|
||||
assert D == 2
|
||||
|
||||
x = xy[:, :, 0:1]
|
||||
y = xy[:, :, 1:2]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
|
||||
if cat_coords:
|
||||
pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
|
||||
return pe
|
||||
|
@@ -1,256 +1,256 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def smart_cat(tensor1, tensor2, dim):
|
||||
if tensor1 is None:
|
||||
return tensor2
|
||||
return torch.cat([tensor1, tensor2], dim=dim)
|
||||
|
||||
|
||||
def get_points_on_a_grid(
|
||||
size: int,
|
||||
extent: Tuple[float, ...],
|
||||
center: Optional[Tuple[float, ...]] = None,
|
||||
device: Optional[torch.device] = torch.device("cpu"),
|
||||
):
|
||||
r"""Get a grid of points covering a rectangular region
|
||||
|
||||
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
|
||||
:attr:`size` grid fo points distributed to cover a rectangular area
|
||||
specified by `extent`.
|
||||
|
||||
The `extent` is a pair of integer :math:`(H,W)` specifying the height
|
||||
and width of the rectangle.
|
||||
|
||||
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
|
||||
specifying the vertical and horizontal center coordinates. The center
|
||||
defaults to the middle of the extent.
|
||||
|
||||
Points are distributed uniformly within the rectangle leaving a margin
|
||||
:math:`m=W/64` from the border.
|
||||
|
||||
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
|
||||
points :math:`P_{ij}=(x_i, y_i)` where
|
||||
|
||||
.. math::
|
||||
P_{ij} = \left(
|
||||
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
|
||||
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
|
||||
\right)
|
||||
|
||||
Points are returned in row-major order.
|
||||
|
||||
Args:
|
||||
size (int): grid size.
|
||||
extent (tuple): height and with of the grid extent.
|
||||
center (tuple, optional): grid center.
|
||||
device (str, optional): Defaults to `"cpu"`.
|
||||
|
||||
Returns:
|
||||
Tensor: grid.
|
||||
"""
|
||||
if size == 1:
|
||||
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
|
||||
|
||||
if center is None:
|
||||
center = [extent[0] / 2, extent[1] / 2]
|
||||
|
||||
margin = extent[1] / 64
|
||||
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
|
||||
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
|
||||
grid_y, grid_x = torch.meshgrid(
|
||||
torch.linspace(*range_y, size, device=device),
|
||||
torch.linspace(*range_x, size, device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
|
||||
|
||||
|
||||
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
|
||||
r"""Masked mean
|
||||
|
||||
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
|
||||
over a mask :attr:`mask`, returning
|
||||
|
||||
.. math::
|
||||
\text{output} =
|
||||
\frac
|
||||
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
|
||||
{\epsilon + \sum_{i=1}^N \text{mask}_i}
|
||||
|
||||
where :math:`N` is the number of elements in :attr:`input` and
|
||||
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
|
||||
division by zero.
|
||||
|
||||
`reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
|
||||
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
|
||||
Optionally, the dimension can be kept in the output by setting
|
||||
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
|
||||
the same dimension as :attr:`input`.
|
||||
|
||||
The interface is similar to `torch.mean()`.
|
||||
|
||||
Args:
|
||||
inout (Tensor): input tensor.
|
||||
mask (Tensor): mask.
|
||||
dim (int, optional): Dimension to sum over. Defaults to None.
|
||||
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tensor: mean tensor.
|
||||
"""
|
||||
|
||||
mask = mask.expand_as(input)
|
||||
|
||||
prod = input * mask
|
||||
|
||||
if dim is None:
|
||||
numer = torch.sum(prod)
|
||||
denom = torch.sum(mask)
|
||||
else:
|
||||
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
||||
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
|
||||
|
||||
mean = numer / (EPS + denom)
|
||||
return mean
|
||||
|
||||
|
||||
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
||||
r"""Sample a tensor using bilinear interpolation
|
||||
|
||||
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
||||
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
||||
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
||||
convention.
|
||||
|
||||
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
||||
:math:`B` is the batch size, :math:`C` is the number of channels,
|
||||
:math:`H` is the height of the image, and :math:`W` is the width of the
|
||||
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
||||
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
||||
|
||||
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
||||
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
||||
that in this case the order of the components is slightly different
|
||||
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
||||
|
||||
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
||||
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
||||
left-most image pixel :math:`W-1` to the center of the right-most
|
||||
pixel.
|
||||
|
||||
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
||||
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
||||
the left-most pixel :math:`W` to the right edge of the right-most
|
||||
pixel.
|
||||
|
||||
Similar conventions apply to the :math:`y` for the range
|
||||
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
||||
:math:`[0,T-1]` and :math:`[0,T]`.
|
||||
|
||||
Args:
|
||||
input (Tensor): batch of input images.
|
||||
coords (Tensor): batch of coordinates.
|
||||
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
||||
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled points.
|
||||
"""
|
||||
|
||||
sizes = input.shape[2:]
|
||||
|
||||
assert len(sizes) in [2, 3]
|
||||
|
||||
if len(sizes) == 3:
|
||||
# t x y -> x y t to match dimensions T H W in grid_sample
|
||||
coords = coords[..., [1, 2, 0]]
|
||||
|
||||
if align_corners:
|
||||
coords = coords * torch.tensor(
|
||||
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
||||
)
|
||||
else:
|
||||
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
|
||||
|
||||
coords -= 1
|
||||
|
||||
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
||||
|
||||
|
||||
def sample_features4d(input, coords):
|
||||
r"""Sample spatial features
|
||||
|
||||
`sample_features4d(input, coords)` samples the spatial features
|
||||
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
||||
|
||||
The field is sampled at coordinates :attr:`coords` using bilinear
|
||||
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
||||
3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
||||
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
||||
|
||||
The output tensor has one feature per point, and has shape :math:`(B,
|
||||
R, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatial features.
|
||||
coords (Tensor): points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, _, _, _ = input.shape
|
||||
|
||||
# B R 2 -> B R 1 2
|
||||
coords = coords.unsqueeze(2)
|
||||
|
||||
# B C R 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 1, 3).view(
|
||||
B, -1, feats.shape[1] * feats.shape[3]
|
||||
) # B C R 1 -> B R C
|
||||
|
||||
|
||||
def sample_features5d(input, coords):
|
||||
r"""Sample spatio-temporal features
|
||||
|
||||
`sample_features5d(input, coords)` works in the same way as
|
||||
:func:`sample_features4d` but for spatio-temporal features and points:
|
||||
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
||||
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
|
||||
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatio-temporal features.
|
||||
coords (Tensor): spatio-temporal points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, T, _, _, _ = input.shape
|
||||
|
||||
# B T C H W -> B C T H W
|
||||
input = input.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# B R1 R2 3 -> B R1 R2 1 3
|
||||
coords = coords.unsqueeze(3)
|
||||
|
||||
# B C R1 R2 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 3, 1, 4).view(
|
||||
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
||||
) # B C R1 R2 1 -> B R1 R2 C
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def smart_cat(tensor1, tensor2, dim):
|
||||
if tensor1 is None:
|
||||
return tensor2
|
||||
return torch.cat([tensor1, tensor2], dim=dim)
|
||||
|
||||
|
||||
def get_points_on_a_grid(
|
||||
size: int,
|
||||
extent: Tuple[float, ...],
|
||||
center: Optional[Tuple[float, ...]] = None,
|
||||
device: Optional[torch.device] = torch.device("cpu"),
|
||||
):
|
||||
r"""Get a grid of points covering a rectangular region
|
||||
|
||||
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
|
||||
:attr:`size` grid fo points distributed to cover a rectangular area
|
||||
specified by `extent`.
|
||||
|
||||
The `extent` is a pair of integer :math:`(H,W)` specifying the height
|
||||
and width of the rectangle.
|
||||
|
||||
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
|
||||
specifying the vertical and horizontal center coordinates. The center
|
||||
defaults to the middle of the extent.
|
||||
|
||||
Points are distributed uniformly within the rectangle leaving a margin
|
||||
:math:`m=W/64` from the border.
|
||||
|
||||
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
|
||||
points :math:`P_{ij}=(x_i, y_i)` where
|
||||
|
||||
.. math::
|
||||
P_{ij} = \left(
|
||||
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
|
||||
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
|
||||
\right)
|
||||
|
||||
Points are returned in row-major order.
|
||||
|
||||
Args:
|
||||
size (int): grid size.
|
||||
extent (tuple): height and with of the grid extent.
|
||||
center (tuple, optional): grid center.
|
||||
device (str, optional): Defaults to `"cpu"`.
|
||||
|
||||
Returns:
|
||||
Tensor: grid.
|
||||
"""
|
||||
if size == 1:
|
||||
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
|
||||
|
||||
if center is None:
|
||||
center = [extent[0] / 2, extent[1] / 2]
|
||||
|
||||
margin = extent[1] / 64
|
||||
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
|
||||
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
|
||||
grid_y, grid_x = torch.meshgrid(
|
||||
torch.linspace(*range_y, size, device=device),
|
||||
torch.linspace(*range_x, size, device=device),
|
||||
indexing="ij",
|
||||
)
|
||||
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
|
||||
|
||||
|
||||
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
|
||||
r"""Masked mean
|
||||
|
||||
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
|
||||
over a mask :attr:`mask`, returning
|
||||
|
||||
.. math::
|
||||
\text{output} =
|
||||
\frac
|
||||
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
|
||||
{\epsilon + \sum_{i=1}^N \text{mask}_i}
|
||||
|
||||
where :math:`N` is the number of elements in :attr:`input` and
|
||||
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
|
||||
division by zero.
|
||||
|
||||
`reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
|
||||
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
|
||||
Optionally, the dimension can be kept in the output by setting
|
||||
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
|
||||
the same dimension as :attr:`input`.
|
||||
|
||||
The interface is similar to `torch.mean()`.
|
||||
|
||||
Args:
|
||||
inout (Tensor): input tensor.
|
||||
mask (Tensor): mask.
|
||||
dim (int, optional): Dimension to sum over. Defaults to None.
|
||||
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tensor: mean tensor.
|
||||
"""
|
||||
|
||||
mask = mask.expand_as(input)
|
||||
|
||||
prod = input * mask
|
||||
|
||||
if dim is None:
|
||||
numer = torch.sum(prod)
|
||||
denom = torch.sum(mask)
|
||||
else:
|
||||
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
||||
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
|
||||
|
||||
mean = numer / (EPS + denom)
|
||||
return mean
|
||||
|
||||
|
||||
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
||||
r"""Sample a tensor using bilinear interpolation
|
||||
|
||||
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
||||
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
||||
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
||||
convention.
|
||||
|
||||
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
||||
:math:`B` is the batch size, :math:`C` is the number of channels,
|
||||
:math:`H` is the height of the image, and :math:`W` is the width of the
|
||||
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
||||
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
||||
|
||||
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
||||
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
||||
that in this case the order of the components is slightly different
|
||||
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
||||
|
||||
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
||||
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
||||
left-most image pixel :math:`W-1` to the center of the right-most
|
||||
pixel.
|
||||
|
||||
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
||||
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
||||
the left-most pixel :math:`W` to the right edge of the right-most
|
||||
pixel.
|
||||
|
||||
Similar conventions apply to the :math:`y` for the range
|
||||
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
||||
:math:`[0,T-1]` and :math:`[0,T]`.
|
||||
|
||||
Args:
|
||||
input (Tensor): batch of input images.
|
||||
coords (Tensor): batch of coordinates.
|
||||
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
||||
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled points.
|
||||
"""
|
||||
|
||||
sizes = input.shape[2:]
|
||||
|
||||
assert len(sizes) in [2, 3]
|
||||
|
||||
if len(sizes) == 3:
|
||||
# t x y -> x y t to match dimensions T H W in grid_sample
|
||||
coords = coords[..., [1, 2, 0]]
|
||||
|
||||
if align_corners:
|
||||
coords = coords * torch.tensor(
|
||||
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
||||
)
|
||||
else:
|
||||
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
|
||||
|
||||
coords -= 1
|
||||
|
||||
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
|
||||
|
||||
|
||||
def sample_features4d(input, coords):
|
||||
r"""Sample spatial features
|
||||
|
||||
`sample_features4d(input, coords)` samples the spatial features
|
||||
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
||||
|
||||
The field is sampled at coordinates :attr:`coords` using bilinear
|
||||
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
||||
3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
||||
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
||||
|
||||
The output tensor has one feature per point, and has shape :math:`(B,
|
||||
R, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatial features.
|
||||
coords (Tensor): points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, _, _, _ = input.shape
|
||||
|
||||
# B R 2 -> B R 1 2
|
||||
coords = coords.unsqueeze(2)
|
||||
|
||||
# B C R 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 1, 3).view(
|
||||
B, -1, feats.shape[1] * feats.shape[3]
|
||||
) # B C R 1 -> B R C
|
||||
|
||||
|
||||
def sample_features5d(input, coords):
|
||||
r"""Sample spatio-temporal features
|
||||
|
||||
`sample_features5d(input, coords)` works in the same way as
|
||||
:func:`sample_features4d` but for spatio-temporal features and points:
|
||||
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
|
||||
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
|
||||
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
|
||||
|
||||
Args:
|
||||
input (Tensor): spatio-temporal features.
|
||||
coords (Tensor): spatio-temporal points.
|
||||
|
||||
Returns:
|
||||
Tensor: sampled features.
|
||||
"""
|
||||
|
||||
B, T, _, _, _ = input.shape
|
||||
|
||||
# B T C H W -> B C T H W
|
||||
input = input.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# B R1 R2 3 -> B R1 R2 1 3
|
||||
coords = coords.unsqueeze(3)
|
||||
|
||||
# B C R1 R2 1
|
||||
feats = bilinear_sampler(input, coords)
|
||||
|
||||
return feats.permute(0, 2, 3, 1, 4).view(
|
||||
B, feats.shape[2], feats.shape[3], feats.shape[1]
|
||||
) # B C R1 R2 1 -> B R1 R2 C
|
||||
|
@@ -1,105 +1,104 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
from cotracker.models.core.model_utils import get_points_on_a_grid
|
||||
|
||||
|
||||
class EvaluationPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cotracker_model: CoTracker2,
|
||||
interp_shape: Tuple[int, int] = (384, 512),
|
||||
grid_size: int = 5,
|
||||
local_grid_size: int = 8,
|
||||
single_point: bool = True,
|
||||
n_iters: int = 6,
|
||||
) -> None:
|
||||
super(EvaluationPredictor, self).__init__()
|
||||
self.grid_size = grid_size
|
||||
self.local_grid_size = local_grid_size
|
||||
self.single_point = single_point
|
||||
self.interp_shape = interp_shape
|
||||
self.n_iters = n_iters
|
||||
|
||||
self.model = cotracker_model
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, video, queries):
|
||||
queries = queries.clone()
|
||||
B, T, C, H, W = video.shape
|
||||
B, N, D = queries.shape
|
||||
|
||||
assert D == 3
|
||||
assert B == 1
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
device = video.device
|
||||
|
||||
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
|
||||
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
|
||||
|
||||
if self.single_point:
|
||||
traj_e = torch.zeros((B, T, N, 2), device=device)
|
||||
vis_e = torch.zeros((B, T, N), device=device)
|
||||
for pind in range((N)):
|
||||
query = queries[:, pind : pind + 1]
|
||||
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
traj_e_pind, vis_e_pind = self._process_one_point(video, query)
|
||||
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
|
||||
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
||||
else:
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
queries = torch.cat([queries, xy], dim=1) #
|
||||
|
||||
traj_e, vis_e, __ = self.model(
|
||||
video=video,
|
||||
queries=queries,
|
||||
iters=self.n_iters,
|
||||
)
|
||||
|
||||
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
|
||||
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
|
||||
return traj_e, vis_e
|
||||
|
||||
def _process_one_point(self, video, query):
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
device = query.device
|
||||
if self.local_grid_size > 0:
|
||||
xy_target = get_points_on_a_grid(
|
||||
self.local_grid_size,
|
||||
(50, 50),
|
||||
[query[0, 0, 2].item(), query[0, 0, 1].item()],
|
||||
)
|
||||
|
||||
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
|
||||
device
|
||||
) #
|
||||
query = torch.cat([query, xy_target], dim=1) #
|
||||
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
query = torch.cat([query, xy], dim=1) #
|
||||
# crop the video to start from the queried frame
|
||||
query[0, 0, 0] = 0
|
||||
traj_e_pind, vis_e_pind, __ = self.model(
|
||||
video=video[:, t:], queries=query, iters=self.n_iters
|
||||
)
|
||||
|
||||
return traj_e_pind, vis_e_pind
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker2
|
||||
from cotracker.models.core.model_utils import get_points_on_a_grid
|
||||
|
||||
|
||||
class EvaluationPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cotracker_model: CoTracker2,
|
||||
interp_shape: Tuple[int, int] = (384, 512),
|
||||
grid_size: int = 5,
|
||||
local_grid_size: int = 8,
|
||||
single_point: bool = True,
|
||||
n_iters: int = 6,
|
||||
) -> None:
|
||||
super(EvaluationPredictor, self).__init__()
|
||||
self.grid_size = grid_size
|
||||
self.local_grid_size = local_grid_size
|
||||
self.single_point = single_point
|
||||
self.interp_shape = interp_shape
|
||||
self.n_iters = n_iters
|
||||
|
||||
self.model = cotracker_model
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, video, queries):
|
||||
queries = queries.clone()
|
||||
B, T, C, H, W = video.shape
|
||||
B, N, D = queries.shape
|
||||
|
||||
assert D == 3
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
device = video.device
|
||||
|
||||
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
|
||||
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
|
||||
|
||||
if self.single_point:
|
||||
traj_e = torch.zeros((B, T, N, 2), device=device)
|
||||
vis_e = torch.zeros((B, T, N), device=device)
|
||||
for pind in range((N)):
|
||||
query = queries[:, pind : pind + 1]
|
||||
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
traj_e_pind, vis_e_pind = self._process_one_point(video, query)
|
||||
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
|
||||
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
||||
else:
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
queries = torch.cat([queries, xy], dim=1) #
|
||||
|
||||
traj_e, vis_e, __ = self.model(
|
||||
video=video,
|
||||
queries=queries,
|
||||
iters=self.n_iters,
|
||||
)
|
||||
|
||||
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
|
||||
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
|
||||
return traj_e, vis_e
|
||||
|
||||
def _process_one_point(self, video, query):
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
device = query.device
|
||||
if self.local_grid_size > 0:
|
||||
xy_target = get_points_on_a_grid(
|
||||
self.local_grid_size,
|
||||
(50, 50),
|
||||
[query[0, 0, 2].item(), query[0, 0, 1].item()],
|
||||
)
|
||||
|
||||
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
|
||||
device
|
||||
) #
|
||||
query = torch.cat([query, xy_target], dim=1) #
|
||||
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
|
||||
query = torch.cat([query, xy], dim=1) #
|
||||
# crop the video to start from the queried frame
|
||||
query[0, 0, 0] = 0
|
||||
traj_e_pind, vis_e_pind, __ = self.model(
|
||||
video=video[:, t:], queries=query, iters=self.n_iters
|
||||
)
|
||||
|
||||
return traj_e_pind, vis_e_pind
|
||||
|
@@ -1,257 +1,282 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
|
||||
from cotracker.models.build_cotracker import build_cotracker
|
||||
|
||||
|
||||
class CoTrackerPredictor(torch.nn.Module):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video, # (1, T, 3, H, W)
|
||||
# input prompt types:
|
||||
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
|
||||
# *backward_tracking=True* will compute tracks in both directions.
|
||||
# - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
|
||||
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
|
||||
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
|
||||
queries: torch.Tensor = None,
|
||||
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
|
||||
grid_size: int = 0,
|
||||
grid_query_frame: int = 0, # only for dense and regular grid tracks
|
||||
backward_tracking: bool = False,
|
||||
):
|
||||
if queries is None and grid_size == 0:
|
||||
tracks, visibilities = self._compute_dense_tracks(
|
||||
video,
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
else:
|
||||
tracks, visibilities = self._compute_sparse_tracks(
|
||||
video,
|
||||
queries,
|
||||
segm_mask,
|
||||
grid_size,
|
||||
add_support_grid=(grid_size == 0 or segm_mask is not None),
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
|
||||
*_, H, W = video.shape
|
||||
grid_step = W // grid_size
|
||||
grid_width = W // grid_step
|
||||
grid_height = H // grid_step
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
for offset in range(grid_step * grid_step):
|
||||
print(f"step {offset} / {grid_step * grid_step}")
|
||||
ox = offset % grid_step
|
||||
oy = offset // grid_step
|
||||
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
||||
grid_pts[0, :, 2] = (
|
||||
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
|
||||
)
|
||||
tracks_step, visibilities_step = self._compute_sparse_tracks(
|
||||
video=video,
|
||||
queries=grid_pts,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
tracks = smart_cat(tracks, tracks_step, dim=2)
|
||||
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_sparse_tracks(
|
||||
self,
|
||||
video,
|
||||
queries,
|
||||
segm_mask=None,
|
||||
grid_size=0,
|
||||
add_support_grid=False,
|
||||
grid_query_frame=0,
|
||||
backward_tracking=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
assert B == 1
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape
|
||||
assert D == 3
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
|
||||
if segm_mask is not None:
|
||||
segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
|
||||
point_mask = segm_mask[0, 0][
|
||||
(grid_pts[0, :, 1]).round().long().cpu(),
|
||||
(grid_pts[0, :, 0]).round().long().cpu(),
|
||||
].bool()
|
||||
grid_pts = grid_pts[:, point_mask]
|
||||
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
|
||||
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
|
||||
|
||||
if backward_tracking:
|
||||
tracks, visibilities = self._compute_backward_tracks(
|
||||
video, queries, tracks, visibilities
|
||||
)
|
||||
if add_support_grid:
|
||||
queries[:, -self.support_grid_size**2 :, 0] = T - 1
|
||||
if add_support_grid:
|
||||
tracks = tracks[:, :, : -self.support_grid_size**2]
|
||||
visibilities = visibilities[:, :, : -self.support_grid_size**2]
|
||||
thr = 0.9
|
||||
visibilities = visibilities > thr
|
||||
|
||||
# correct query-point predictions
|
||||
# see https://github.com/facebookresearch/co-tracker/issues/28
|
||||
|
||||
# TODO: batchify
|
||||
for i in range(len(queries)):
|
||||
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
|
||||
arange = torch.arange(0, len(queries_t))
|
||||
|
||||
# overwrite the predictions with the query points
|
||||
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
|
||||
|
||||
# correct visibilities, the query points should be visible
|
||||
visibilities[i, queries_t, arange] = True
|
||||
|
||||
tracks *= tracks.new_tensor(
|
||||
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
|
||||
)
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
|
||||
inv_video = video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
|
||||
|
||||
inv_tracks = inv_tracks.flip(1)
|
||||
inv_visibilities = inv_visibilities.flip(1)
|
||||
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
|
||||
|
||||
mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
tracks[mask] = inv_tracks[mask]
|
||||
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
||||
return tracks, visibilities
|
||||
|
||||
|
||||
class CoTrackerOnlinePredictor(torch.nn.Module):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
self.step = model.window_len // 2
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video_chunk,
|
||||
is_first_step: bool = False,
|
||||
queries: torch.Tensor = None,
|
||||
grid_size: int = 10,
|
||||
grid_query_frame: int = 0,
|
||||
add_support_grid=False,
|
||||
):
|
||||
# Initialize online video processing and save queried points
|
||||
# This needs to be done before processing *each new video*
|
||||
if is_first_step:
|
||||
self.model.init_video_online_processing()
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape
|
||||
assert D == 3
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
)
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
self.queries = queries
|
||||
return (None, None)
|
||||
B, T, C, H, W = video_chunk.shape
|
||||
video_chunk = video_chunk.reshape(B * T, C, H, W)
|
||||
video_chunk = F.interpolate(
|
||||
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
|
||||
)
|
||||
video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
tracks, visibilities, __ = self.model(
|
||||
video=video_chunk,
|
||||
queries=self.queries,
|
||||
iters=6,
|
||||
is_online=True,
|
||||
)
|
||||
thr = 0.9
|
||||
return (
|
||||
tracks
|
||||
* tracks.new_tensor(
|
||||
[
|
||||
(W - 1) / (self.interp_shape[1] - 1),
|
||||
(H - 1) / (self.interp_shape[0] - 1),
|
||||
]
|
||||
),
|
||||
visibilities > thr,
|
||||
)
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
|
||||
from cotracker.models.build_cotracker import build_cotracker
|
||||
|
||||
|
||||
class CoTrackerPredictor(torch.nn.Module):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
print(self.interp_shape)
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
|
||||
self.model.to(self.device)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video, # (B, T, 3, H, W) Batch_size, time, rgb, height, width
|
||||
# input prompt types:
|
||||
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
|
||||
# *backward_tracking=True* will compute tracks in both directions.
|
||||
# - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
|
||||
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
|
||||
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
|
||||
queries: torch.Tensor = None,
|
||||
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
|
||||
grid_size: int = 0,
|
||||
grid_query_frame: int = 0, # only for dense and regular grid tracks
|
||||
backward_tracking: bool = False,
|
||||
):
|
||||
if queries is None and grid_size == 0:
|
||||
tracks, visibilities = self._compute_dense_tracks(
|
||||
video,
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
else:
|
||||
tracks, visibilities = self._compute_sparse_tracks(
|
||||
video,
|
||||
queries,
|
||||
segm_mask,
|
||||
grid_size,
|
||||
add_support_grid=(grid_size == 0 or segm_mask is not None),
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
# gpu dense inference time
|
||||
# raft gpu comparison
|
||||
# vision effects
|
||||
# raft integrated
|
||||
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
|
||||
*_, H, W = video.shape
|
||||
grid_step = W // grid_size
|
||||
grid_width = W // grid_step
|
||||
grid_height = H // grid_step # set the whole video to grid_size number of grids
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
|
||||
# (batch_size, grid_number, t,x,y)
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
# iterate every grid
|
||||
for offset in range(grid_step * grid_step):
|
||||
print(f"step {offset} / {grid_step * grid_step}")
|
||||
ox = offset % grid_step
|
||||
oy = offset // grid_step
|
||||
# initialize
|
||||
# for example
|
||||
# grid width = 4, grid height = 4, grid step = 10, ox = 1
|
||||
# torch.arange(grid_width) = [0,1,2,3]
|
||||
# torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
|
||||
# torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
|
||||
# get the location in the image
|
||||
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
||||
grid_pts[0, :, 2] = (
|
||||
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
|
||||
)
|
||||
tracks_step, visibilities_step = self._compute_sparse_tracks(
|
||||
video=video,
|
||||
queries=grid_pts,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
tracks = smart_cat(tracks, tracks_step, dim=2)
|
||||
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_sparse_tracks(
|
||||
self,
|
||||
video,
|
||||
queries,
|
||||
segm_mask=None,
|
||||
grid_size=0,
|
||||
add_support_grid=False,
|
||||
grid_query_frame=0,
|
||||
backward_tracking=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
# ? what is interpolate?
|
||||
# 将video插值成interp_shape?
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
|
||||
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape # batch_size, number of points, (t,x,y)
|
||||
assert D == 3
|
||||
# query 缩放到( interp_shape - 1 ) / (W - 1)
|
||||
# 插完值之后缩放
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
# 生成grid
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
|
||||
if segm_mask is not None:
|
||||
segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
|
||||
point_mask = segm_mask[0, 0][
|
||||
(grid_pts[0, :, 1]).round().long().cpu(),
|
||||
(grid_pts[0, :, 0]).round().long().cpu(),
|
||||
].bool()
|
||||
grid_pts = grid_pts[:, point_mask]
|
||||
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
).repeat(B, 1, 1)
|
||||
|
||||
# 添加支持点
|
||||
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
grid_pts = grid_pts.repeat(B, 1, 1)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
|
||||
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
|
||||
|
||||
if backward_tracking:
|
||||
tracks, visibilities = self._compute_backward_tracks(
|
||||
video, queries, tracks, visibilities
|
||||
)
|
||||
if add_support_grid:
|
||||
queries[:, -self.support_grid_size**2 :, 0] = T - 1
|
||||
if add_support_grid:
|
||||
tracks = tracks[:, :, : -self.support_grid_size**2]
|
||||
visibilities = visibilities[:, :, : -self.support_grid_size**2]
|
||||
thr = 0.9
|
||||
visibilities = visibilities > thr
|
||||
|
||||
# correct query-point predictions
|
||||
# see https://github.com/facebookresearch/co-tracker/issues/28
|
||||
|
||||
# TODO: batchify
|
||||
for i in range(len(queries)):
|
||||
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
|
||||
arange = torch.arange(0, len(queries_t))
|
||||
|
||||
# overwrite the predictions with the query points
|
||||
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
|
||||
|
||||
# correct visibilities, the query points should be visible
|
||||
visibilities[i, queries_t, arange] = True
|
||||
|
||||
tracks *= tracks.new_tensor(
|
||||
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
|
||||
)
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
|
||||
inv_video = video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
|
||||
|
||||
inv_tracks = inv_tracks.flip(1)
|
||||
inv_visibilities = inv_visibilities.flip(1)
|
||||
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
|
||||
|
||||
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
tracks[mask] = inv_tracks[mask]
|
||||
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
||||
return tracks, visibilities
|
||||
|
||||
|
||||
class CoTrackerOnlinePredictor(torch.nn.Module):
|
||||
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
|
||||
super().__init__()
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
self.interp_shape = model.model_resolution
|
||||
self.step = model.window_len // 2
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video_chunk,
|
||||
is_first_step: bool = False,
|
||||
queries: torch.Tensor = None,
|
||||
grid_size: int = 10,
|
||||
grid_query_frame: int = 0,
|
||||
add_support_grid=False,
|
||||
):
|
||||
B, T, C, H, W = video_chunk.shape
|
||||
# Initialize online video processing and save queried points
|
||||
# This needs to be done before processing *each new video*
|
||||
if is_first_step:
|
||||
self.model.init_video_online_processing()
|
||||
if queries is not None:
|
||||
B, N, D = queries.shape
|
||||
assert D == 3
|
||||
queries = queries.clone()
|
||||
queries[:, :, 1:] *= queries.new_tensor(
|
||||
[
|
||||
(self.interp_shape[1] - 1) / (W - 1),
|
||||
(self.interp_shape[0] - 1) / (H - 1),
|
||||
]
|
||||
)
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
)
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(
|
||||
self.support_grid_size, self.interp_shape, device=video_chunk.device
|
||||
)
|
||||
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
self.queries = queries
|
||||
return (None, None)
|
||||
|
||||
video_chunk = video_chunk.reshape(B * T, C, H, W)
|
||||
video_chunk = F.interpolate(
|
||||
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
|
||||
)
|
||||
video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
|
||||
|
||||
tracks, visibilities, __ = self.model(
|
||||
video=video_chunk,
|
||||
queries=self.queries,
|
||||
iters=6,
|
||||
is_online=True,
|
||||
)
|
||||
thr = 0.9
|
||||
return (
|
||||
tracks
|
||||
* tracks.new_tensor(
|
||||
[
|
||||
(W - 1) / (self.interp_shape[1] - 1),
|
||||
(H - 1) / (self.interp_shape[0] - 1),
|
||||
]
|
||||
),
|
||||
visibilities > thr,
|
||||
)
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
BIN
cotracker/utils/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
cotracker/utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/utils/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
cotracker/utils/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
cotracker/utils/__pycache__/visualizer.cpython-38.pyc
Normal file
BIN
cotracker/utils/__pycache__/visualizer.cpython-38.pyc
Normal file
Binary file not shown.
BIN
cotracker/utils/__pycache__/visualizer.cpython-39.pyc
Normal file
BIN
cotracker/utils/__pycache__/visualizer.cpython-39.pyc
Normal file
Binary file not shown.
@@ -1,343 +1,343 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import os
|
||||
import numpy as np
|
||||
import imageio
|
||||
import torch
|
||||
|
||||
from matplotlib import cm
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
def read_video_from_path(path):
|
||||
try:
|
||||
reader = imageio.get_reader(path)
|
||||
except Exception as e:
|
||||
print("Error opening video file: ", e)
|
||||
return None
|
||||
frames = []
|
||||
for i, im in enumerate(reader):
|
||||
frames.append(np.array(im))
|
||||
return np.stack(frames)
|
||||
|
||||
|
||||
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
|
||||
# Create a draw object
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
# Calculate the bounding box of the circle
|
||||
left_up_point = (coord[0] - radius, coord[1] - radius)
|
||||
right_down_point = (coord[0] + radius, coord[1] + radius)
|
||||
# Draw the circle
|
||||
draw.ellipse(
|
||||
[left_up_point, right_down_point],
|
||||
fill=tuple(color) if visible else None,
|
||||
outline=tuple(color),
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def draw_line(rgb, coord_y, coord_x, color, linewidth):
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
draw.line(
|
||||
(coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
|
||||
fill=tuple(color),
|
||||
width=linewidth,
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def add_weighted(rgb, alpha, original, beta, gamma):
|
||||
return (rgb * alpha + original * beta + gamma).astype("uint8")
|
||||
|
||||
|
||||
class Visualizer:
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str = "./results",
|
||||
grayscale: bool = False,
|
||||
pad_value: int = 0,
|
||||
fps: int = 10,
|
||||
mode: str = "rainbow", # 'cool', 'optical_flow'
|
||||
linewidth: int = 2,
|
||||
show_first_frame: int = 10,
|
||||
tracks_leave_trace: int = 0, # -1 for infinite
|
||||
):
|
||||
self.mode = mode
|
||||
self.save_dir = save_dir
|
||||
if mode == "rainbow":
|
||||
self.color_map = cm.get_cmap("gist_rainbow")
|
||||
elif mode == "cool":
|
||||
self.color_map = cm.get_cmap(mode)
|
||||
self.show_first_frame = show_first_frame
|
||||
self.grayscale = grayscale
|
||||
self.tracks_leave_trace = tracks_leave_trace
|
||||
self.pad_value = pad_value
|
||||
self.linewidth = linewidth
|
||||
self.fps = fps
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
video: torch.Tensor, # (B,T,C,H,W)
|
||||
tracks: torch.Tensor, # (B,T,N,2)
|
||||
visibility: torch.Tensor = None, # (B, T, N, 1) bool
|
||||
gt_tracks: torch.Tensor = None, # (B,T,N,2)
|
||||
segm_mask: torch.Tensor = None, # (B,1,H,W)
|
||||
filename: str = "video",
|
||||
writer=None, # tensorboard Summary Writer, used for visualization during training
|
||||
step: int = 0,
|
||||
query_frame: int = 0,
|
||||
save_video: bool = True,
|
||||
compensate_for_camera_motion: bool = False,
|
||||
):
|
||||
if compensate_for_camera_motion:
|
||||
assert segm_mask is not None
|
||||
if segm_mask is not None:
|
||||
coords = tracks[0, query_frame].round().long()
|
||||
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
|
||||
|
||||
video = F.pad(
|
||||
video,
|
||||
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
|
||||
"constant",
|
||||
255,
|
||||
)
|
||||
tracks = tracks + self.pad_value
|
||||
|
||||
if self.grayscale:
|
||||
transform = transforms.Grayscale()
|
||||
video = transform(video)
|
||||
video = video.repeat(1, 1, 3, 1, 1)
|
||||
|
||||
res_video = self.draw_tracks_on_video(
|
||||
video=video,
|
||||
tracks=tracks,
|
||||
visibility=visibility,
|
||||
segm_mask=segm_mask,
|
||||
gt_tracks=gt_tracks,
|
||||
query_frame=query_frame,
|
||||
compensate_for_camera_motion=compensate_for_camera_motion,
|
||||
)
|
||||
if save_video:
|
||||
self.save_video(res_video, filename=filename, writer=writer, step=step)
|
||||
return res_video
|
||||
|
||||
def save_video(self, video, filename, writer=None, step=0):
|
||||
if writer is not None:
|
||||
writer.add_video(
|
||||
filename,
|
||||
video.to(torch.uint8),
|
||||
global_step=step,
|
||||
fps=self.fps,
|
||||
)
|
||||
else:
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
wide_list = list(video.unbind(1))
|
||||
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
|
||||
|
||||
# Prepare the video file path
|
||||
save_path = os.path.join(self.save_dir, f"{filename}.mp4")
|
||||
|
||||
# Create a writer object
|
||||
video_writer = imageio.get_writer(save_path, fps=self.fps)
|
||||
|
||||
# Write frames to the video file
|
||||
for frame in wide_list[2:-1]:
|
||||
video_writer.append_data(frame)
|
||||
|
||||
video_writer.close()
|
||||
|
||||
print(f"Video saved to {save_path}")
|
||||
|
||||
def draw_tracks_on_video(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
tracks: torch.Tensor,
|
||||
visibility: torch.Tensor = None,
|
||||
segm_mask: torch.Tensor = None,
|
||||
gt_tracks=None,
|
||||
query_frame: int = 0,
|
||||
compensate_for_camera_motion=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
_, _, N, D = tracks.shape
|
||||
|
||||
assert D == 2
|
||||
assert C == 3
|
||||
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
|
||||
tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
|
||||
if gt_tracks is not None:
|
||||
gt_tracks = gt_tracks[0].detach().cpu().numpy()
|
||||
|
||||
res_video = []
|
||||
|
||||
# process input video
|
||||
for rgb in video:
|
||||
res_video.append(rgb.copy())
|
||||
vector_colors = np.zeros((T, N, 3))
|
||||
|
||||
if self.mode == "optical_flow":
|
||||
import flow_vis
|
||||
|
||||
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
|
||||
elif segm_mask is None:
|
||||
if self.mode == "rainbow":
|
||||
y_min, y_max = (
|
||||
tracks[query_frame, :, 1].min(),
|
||||
tracks[query_frame, :, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
color = self.color_map(norm(tracks[query_frame, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
else:
|
||||
# color changes with time
|
||||
for t in range(T):
|
||||
color = np.array(self.color_map(t / T)[:3])[None] * 255
|
||||
vector_colors[t] = np.repeat(color, N, axis=0)
|
||||
else:
|
||||
if self.mode == "rainbow":
|
||||
vector_colors[:, segm_mask <= 0, :] = 255
|
||||
|
||||
y_min, y_max = (
|
||||
tracks[0, segm_mask > 0, 1].min(),
|
||||
tracks[0, segm_mask > 0, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
if segm_mask[n] > 0:
|
||||
color = self.color_map(norm(tracks[0, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
|
||||
else:
|
||||
# color changes with segm class
|
||||
segm_mask = segm_mask.cpu()
|
||||
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
|
||||
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
|
||||
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
|
||||
vector_colors = np.repeat(color[None], T, axis=0)
|
||||
|
||||
# draw tracks
|
||||
if self.tracks_leave_trace != 0:
|
||||
for t in range(query_frame + 1, T):
|
||||
first_ind = (
|
||||
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
|
||||
)
|
||||
curr_tracks = tracks[first_ind : t + 1]
|
||||
curr_colors = vector_colors[first_ind : t + 1]
|
||||
if compensate_for_camera_motion:
|
||||
diff = (
|
||||
tracks[first_ind : t + 1, segm_mask <= 0]
|
||||
- tracks[t : t + 1, segm_mask <= 0]
|
||||
).mean(1)[:, None]
|
||||
|
||||
curr_tracks = curr_tracks - diff
|
||||
curr_tracks = curr_tracks[:, segm_mask > 0]
|
||||
curr_colors = curr_colors[:, segm_mask > 0]
|
||||
|
||||
res_video[t] = self._draw_pred_tracks(
|
||||
res_video[t],
|
||||
curr_tracks,
|
||||
curr_colors,
|
||||
)
|
||||
if gt_tracks is not None:
|
||||
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
|
||||
|
||||
# draw points
|
||||
for t in range(query_frame, T):
|
||||
img = Image.fromarray(np.uint8(res_video[t]))
|
||||
for i in range(N):
|
||||
coord = (tracks[t, i, 0], tracks[t, i, 1])
|
||||
visibile = True
|
||||
if visibility is not None:
|
||||
visibile = visibility[0, t, i]
|
||||
if coord[0] != 0 and coord[1] != 0:
|
||||
if not compensate_for_camera_motion or (
|
||||
compensate_for_camera_motion and segm_mask[i] > 0
|
||||
):
|
||||
img = draw_circle(
|
||||
img,
|
||||
coord=coord,
|
||||
radius=int(self.linewidth * 2),
|
||||
color=vector_colors[t, i].astype(int),
|
||||
visible=visibile,
|
||||
)
|
||||
res_video[t] = np.array(img)
|
||||
|
||||
# construct the final rgb sequence
|
||||
if self.show_first_frame > 0:
|
||||
res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
|
||||
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
|
||||
|
||||
def _draw_pred_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3
|
||||
tracks: np.ndarray, # T x 2
|
||||
vector_colors: np.ndarray,
|
||||
alpha: float = 0.5,
|
||||
):
|
||||
T, N, _ = tracks.shape
|
||||
rgb = Image.fromarray(np.uint8(rgb))
|
||||
for s in range(T - 1):
|
||||
vector_color = vector_colors[s]
|
||||
original = rgb.copy()
|
||||
alpha = (s / T) ** 2
|
||||
for i in range(N):
|
||||
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
|
||||
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
|
||||
if coord_y[0] != 0 and coord_y[1] != 0:
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
vector_color[i].astype(int),
|
||||
self.linewidth,
|
||||
)
|
||||
if self.tracks_leave_trace > 0:
|
||||
rgb = Image.fromarray(
|
||||
np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
|
||||
)
|
||||
rgb = np.array(rgb)
|
||||
return rgb
|
||||
|
||||
def _draw_gt_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3,
|
||||
gt_tracks: np.ndarray, # T x 2
|
||||
):
|
||||
T, N, _ = gt_tracks.shape
|
||||
color = np.array((211, 0, 0))
|
||||
rgb = Image.fromarray(np.uint8(rgb))
|
||||
for t in range(T):
|
||||
for i in range(N):
|
||||
gt_tracks = gt_tracks[t][i]
|
||||
# draw a red cross
|
||||
if gt_tracks[0] > 0 and gt_tracks[1] > 0:
|
||||
length = self.linewidth * 3
|
||||
coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
)
|
||||
coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
)
|
||||
rgb = np.array(rgb)
|
||||
return rgb
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import os
|
||||
import numpy as np
|
||||
import imageio
|
||||
import torch
|
||||
|
||||
from matplotlib import cm
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
def read_video_from_path(path):
|
||||
try:
|
||||
reader = imageio.get_reader(path)
|
||||
except Exception as e:
|
||||
print("Error opening video file: ", e)
|
||||
return None
|
||||
frames = []
|
||||
for i, im in enumerate(reader):
|
||||
frames.append(np.array(im))
|
||||
return np.stack(frames)
|
||||
|
||||
|
||||
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
|
||||
# Create a draw object
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
# Calculate the bounding box of the circle
|
||||
left_up_point = (coord[0] - radius, coord[1] - radius)
|
||||
right_down_point = (coord[0] + radius, coord[1] + radius)
|
||||
# Draw the circle
|
||||
draw.ellipse(
|
||||
[left_up_point, right_down_point],
|
||||
fill=tuple(color) if visible else None,
|
||||
outline=tuple(color),
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def draw_line(rgb, coord_y, coord_x, color, linewidth):
|
||||
draw = ImageDraw.Draw(rgb)
|
||||
draw.line(
|
||||
(coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
|
||||
fill=tuple(color),
|
||||
width=linewidth,
|
||||
)
|
||||
return rgb
|
||||
|
||||
|
||||
def add_weighted(rgb, alpha, original, beta, gamma):
|
||||
return (rgb * alpha + original * beta + gamma).astype("uint8")
|
||||
|
||||
|
||||
class Visualizer:
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str = "./results",
|
||||
grayscale: bool = False,
|
||||
pad_value: int = 0,
|
||||
fps: int = 10,
|
||||
mode: str = "rainbow", # 'cool', 'optical_flow'
|
||||
linewidth: int = 2,
|
||||
show_first_frame: int = 10,
|
||||
tracks_leave_trace: int = 0, # -1 for infinite
|
||||
):
|
||||
self.mode = mode
|
||||
self.save_dir = save_dir
|
||||
if mode == "rainbow":
|
||||
self.color_map = cm.get_cmap("gist_rainbow")
|
||||
elif mode == "cool":
|
||||
self.color_map = cm.get_cmap(mode)
|
||||
self.show_first_frame = show_first_frame
|
||||
self.grayscale = grayscale
|
||||
self.tracks_leave_trace = tracks_leave_trace
|
||||
self.pad_value = pad_value
|
||||
self.linewidth = linewidth
|
||||
self.fps = fps
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
video: torch.Tensor, # (B,T,C,H,W)
|
||||
tracks: torch.Tensor, # (B,T,N,2)
|
||||
visibility: torch.Tensor = None, # (B, T, N, 1) bool
|
||||
gt_tracks: torch.Tensor = None, # (B,T,N,2)
|
||||
segm_mask: torch.Tensor = None, # (B,1,H,W)
|
||||
filename: str = "video",
|
||||
writer=None, # tensorboard Summary Writer, used for visualization during training
|
||||
step: int = 0,
|
||||
query_frame: int = 0,
|
||||
save_video: bool = True,
|
||||
compensate_for_camera_motion: bool = False,
|
||||
):
|
||||
if compensate_for_camera_motion:
|
||||
assert segm_mask is not None
|
||||
if segm_mask is not None:
|
||||
coords = tracks[0, query_frame].round().long()
|
||||
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
|
||||
|
||||
video = F.pad(
|
||||
video,
|
||||
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
|
||||
"constant",
|
||||
255,
|
||||
)
|
||||
tracks = tracks + self.pad_value
|
||||
|
||||
if self.grayscale:
|
||||
transform = transforms.Grayscale()
|
||||
video = transform(video)
|
||||
video = video.repeat(1, 1, 3, 1, 1)
|
||||
|
||||
res_video = self.draw_tracks_on_video(
|
||||
video=video,
|
||||
tracks=tracks,
|
||||
visibility=visibility,
|
||||
segm_mask=segm_mask,
|
||||
gt_tracks=gt_tracks,
|
||||
query_frame=query_frame,
|
||||
compensate_for_camera_motion=compensate_for_camera_motion,
|
||||
)
|
||||
if save_video:
|
||||
self.save_video(res_video, filename=filename, writer=writer, step=step)
|
||||
return res_video
|
||||
|
||||
def save_video(self, video, filename, writer=None, step=0):
|
||||
if writer is not None:
|
||||
writer.add_video(
|
||||
filename,
|
||||
video.to(torch.uint8),
|
||||
global_step=step,
|
||||
fps=self.fps,
|
||||
)
|
||||
else:
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
wide_list = list(video.unbind(1))
|
||||
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
|
||||
|
||||
# Prepare the video file path
|
||||
save_path = os.path.join(self.save_dir, f"{filename}.mp4")
|
||||
|
||||
# Create a writer object
|
||||
video_writer = imageio.get_writer(save_path, fps=self.fps)
|
||||
|
||||
# Write frames to the video file
|
||||
for frame in wide_list[2:-1]:
|
||||
video_writer.append_data(frame)
|
||||
|
||||
video_writer.close()
|
||||
|
||||
print(f"Video saved to {save_path}")
|
||||
|
||||
def draw_tracks_on_video(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
tracks: torch.Tensor,
|
||||
visibility: torch.Tensor = None,
|
||||
segm_mask: torch.Tensor = None,
|
||||
gt_tracks=None,
|
||||
query_frame: int = 0,
|
||||
compensate_for_camera_motion=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
_, _, N, D = tracks.shape
|
||||
|
||||
assert D == 2
|
||||
assert C == 3
|
||||
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
|
||||
tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
|
||||
if gt_tracks is not None:
|
||||
gt_tracks = gt_tracks[0].detach().cpu().numpy()
|
||||
|
||||
res_video = []
|
||||
|
||||
# process input video
|
||||
for rgb in video:
|
||||
res_video.append(rgb.copy())
|
||||
vector_colors = np.zeros((T, N, 3))
|
||||
|
||||
if self.mode == "optical_flow":
|
||||
import flow_vis
|
||||
|
||||
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
|
||||
elif segm_mask is None:
|
||||
if self.mode == "rainbow":
|
||||
y_min, y_max = (
|
||||
tracks[query_frame, :, 1].min(),
|
||||
tracks[query_frame, :, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
color = self.color_map(norm(tracks[query_frame, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
else:
|
||||
# color changes with time
|
||||
for t in range(T):
|
||||
color = np.array(self.color_map(t / T)[:3])[None] * 255
|
||||
vector_colors[t] = np.repeat(color, N, axis=0)
|
||||
else:
|
||||
if self.mode == "rainbow":
|
||||
vector_colors[:, segm_mask <= 0, :] = 255
|
||||
|
||||
y_min, y_max = (
|
||||
tracks[0, segm_mask > 0, 1].min(),
|
||||
tracks[0, segm_mask > 0, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
if segm_mask[n] > 0:
|
||||
color = self.color_map(norm(tracks[0, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
|
||||
else:
|
||||
# color changes with segm class
|
||||
segm_mask = segm_mask.cpu()
|
||||
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
|
||||
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
|
||||
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
|
||||
vector_colors = np.repeat(color[None], T, axis=0)
|
||||
|
||||
# draw tracks
|
||||
if self.tracks_leave_trace != 0:
|
||||
for t in range(query_frame + 1, T):
|
||||
first_ind = (
|
||||
max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
|
||||
)
|
||||
curr_tracks = tracks[first_ind : t + 1]
|
||||
curr_colors = vector_colors[first_ind : t + 1]
|
||||
if compensate_for_camera_motion:
|
||||
diff = (
|
||||
tracks[first_ind : t + 1, segm_mask <= 0]
|
||||
- tracks[t : t + 1, segm_mask <= 0]
|
||||
).mean(1)[:, None]
|
||||
|
||||
curr_tracks = curr_tracks - diff
|
||||
curr_tracks = curr_tracks[:, segm_mask > 0]
|
||||
curr_colors = curr_colors[:, segm_mask > 0]
|
||||
|
||||
res_video[t] = self._draw_pred_tracks(
|
||||
res_video[t],
|
||||
curr_tracks,
|
||||
curr_colors,
|
||||
)
|
||||
if gt_tracks is not None:
|
||||
res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
|
||||
|
||||
# draw points
|
||||
for t in range(query_frame, T):
|
||||
img = Image.fromarray(np.uint8(res_video[t]))
|
||||
for i in range(N):
|
||||
coord = (tracks[t, i, 0], tracks[t, i, 1])
|
||||
visibile = True
|
||||
if visibility is not None:
|
||||
visibile = visibility[0, t, i]
|
||||
if coord[0] != 0 and coord[1] != 0:
|
||||
if not compensate_for_camera_motion or (
|
||||
compensate_for_camera_motion and segm_mask[i] > 0
|
||||
):
|
||||
img = draw_circle(
|
||||
img,
|
||||
coord=coord,
|
||||
radius=int(self.linewidth * 2),
|
||||
color=vector_colors[t, i].astype(int),
|
||||
visible=visibile,
|
||||
)
|
||||
res_video[t] = np.array(img)
|
||||
|
||||
# construct the final rgb sequence
|
||||
if self.show_first_frame > 0:
|
||||
res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
|
||||
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
|
||||
|
||||
def _draw_pred_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3
|
||||
tracks: np.ndarray, # T x 2
|
||||
vector_colors: np.ndarray,
|
||||
alpha: float = 0.5,
|
||||
):
|
||||
T, N, _ = tracks.shape
|
||||
rgb = Image.fromarray(np.uint8(rgb))
|
||||
for s in range(T - 1):
|
||||
vector_color = vector_colors[s]
|
||||
original = rgb.copy()
|
||||
alpha = (s / T) ** 2
|
||||
for i in range(N):
|
||||
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
|
||||
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
|
||||
if coord_y[0] != 0 and coord_y[1] != 0:
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
vector_color[i].astype(int),
|
||||
self.linewidth,
|
||||
)
|
||||
if self.tracks_leave_trace > 0:
|
||||
rgb = Image.fromarray(
|
||||
np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
|
||||
)
|
||||
rgb = np.array(rgb)
|
||||
return rgb
|
||||
|
||||
def _draw_gt_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3,
|
||||
gt_tracks: np.ndarray, # T x 2
|
||||
):
|
||||
T, N, _ = gt_tracks.shape
|
||||
color = np.array((211, 0, 0))
|
||||
rgb = Image.fromarray(np.uint8(rgb))
|
||||
for t in range(T):
|
||||
for i in range(N):
|
||||
gt_tracks = gt_tracks[t][i]
|
||||
# draw a red cross
|
||||
if gt_tracks[0] > 0 and gt_tracks[1] > 0:
|
||||
length = self.linewidth * 3
|
||||
coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
)
|
||||
coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
|
||||
rgb = draw_line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
)
|
||||
rgb = np.array(rgb)
|
||||
return rgb
|
||||
|
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
__version__ = "2.0.0"
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
__version__ = "2.0.0"
|
||||
|
3
demo.py
3
demo.py
@@ -83,11 +83,12 @@ if __name__ == "__main__":
|
||||
print("computed")
|
||||
|
||||
# save a video with predicted tracks
|
||||
seq_name = args.video_path.split("/")[-1]
|
||||
seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
|
||||
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
||||
vis.visualize(
|
||||
video,
|
||||
pred_tracks,
|
||||
pred_visibility,
|
||||
query_frame=0 if args.backward_tracking else args.grid_query_frame,
|
||||
filename=seq_name,
|
||||
)
|
||||
|
82
demo1.py
Normal file
82
demo1.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from base64 import b64encode
|
||||
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import time
|
||||
|
||||
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
start_time = time.time()
|
||||
print(f'Using device: {device}')
|
||||
print(f'start loading video')
|
||||
video = read_video_from_path('./assets/F1_shorts.mp4')
|
||||
print(f'video shape: {video.shape}')
|
||||
# video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float().to(device)
|
||||
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
|
||||
end_time = time.time()
|
||||
print(f'video shape after permute: {video.shape}')
|
||||
print("Load video Time taken: {:.2f} seconds".format(end_time - start_time))
|
||||
|
||||
from cotracker.predictor import CoTrackerPredictor
|
||||
|
||||
|
||||
model = CoTrackerPredictor(
|
||||
checkpoint=os.path.join(
|
||||
'./checkpoints/cotracker2.pth'
|
||||
)
|
||||
)
|
||||
|
||||
# pred_tracks, pred_visibility = model(video, grid_size=30)
|
||||
|
||||
# vis = Visualizer(save_dir='./videos', pad_value=100)
|
||||
# vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility, filename='teaser');
|
||||
|
||||
grid_query_frame=20
|
||||
|
||||
import torch.nn.functional as F
|
||||
# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
|
||||
interp_size = (720, 1280)
|
||||
video_interp = F.interpolate(video[0], [interp_size[0], interp_size[1]], mode="bilinear")[None].to(device)
|
||||
print(f'video_interp shape: {video_interp.shape}')
|
||||
|
||||
start_time = time.time()
|
||||
# pred_tracks, pred_visibility = model(video_interp,
|
||||
input_mask='./assets/F1_mask.png'
|
||||
segm_mask = Image.open(input_mask)
|
||||
interp_size = (interp_size[1], interp_size[0])
|
||||
segm_mask = segm_mask.resize(interp_size, Image.BILINEAR)
|
||||
segm_mask = np.array(Image.open(input_mask))
|
||||
segm_mask = torch.tensor(segm_mask).to(device)
|
||||
# pred_tracks, pred_visibility = model(video,
|
||||
pred_tracks, pred_visibility = model(video_interp,
|
||||
grid_query_frame=grid_query_frame, backward_tracking=True,
|
||||
segm_mask=segm_mask )
|
||||
end_time = time.time()
|
||||
|
||||
print("Time taken: {:.2f} seconds".format(end_time - start_time))
|
||||
|
||||
start_time = time.time()
|
||||
print(f'start visualizing')
|
||||
vis = Visualizer(
|
||||
save_dir='./videos',
|
||||
pad_value=20,
|
||||
linewidth=1,
|
||||
mode='optical_flow'
|
||||
)
|
||||
print(f'vis initialized')
|
||||
end_time = time.time()
|
||||
print("Time taken: {:.2f} seconds".format(end_time - start_time))
|
||||
start_time = time.time()
|
||||
print(f'start visualize')
|
||||
vis.visualize(
|
||||
video=video_interp,
|
||||
# video=video,
|
||||
tracks=pred_tracks,
|
||||
visibility=pred_visibility,
|
||||
filename='dense2');
|
||||
print(f'done')
|
||||
end_time = time.time()
|
||||
print("Time taken: {:.2f} seconds".format(end_time - start_time))
|
@@ -1,3 +1,10 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import os
|
||||
import torch
|
||||
import gradio as gr
|
||||
@@ -22,7 +29,12 @@ def cotracker_demo(
|
||||
model = model.cuda()
|
||||
load_video = load_video.cuda()
|
||||
|
||||
model(video_chunk=load_video, is_first_step=True, grid_size=grid_size)
|
||||
model(
|
||||
video_chunk=load_video,
|
||||
is_first_step=True,
|
||||
grid_size=grid_size,
|
||||
grid_query_frame=grid_query_frame,
|
||||
)
|
||||
for ind in range(0, load_video.shape[1] - model.step, model.step):
|
||||
pred_tracks, pred_visibility = model(
|
||||
video_chunk=load_video[:, ind : ind + model.step * 2]
|
||||
|
File diff suppressed because one or more lines are too long
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
import imageio.v3 as iio
|
||||
@@ -44,6 +45,9 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.isfile(args.video_path):
|
||||
raise ValueError("Video file does not exist")
|
||||
|
||||
if args.checkpoint is not None:
|
||||
model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
|
||||
else:
|
||||
@@ -52,25 +56,33 @@ if __name__ == "__main__":
|
||||
|
||||
window_frames = []
|
||||
|
||||
def _process_step(window_frames, is_first_step, grid_size):
|
||||
def _process_step(window_frames, is_first_step, grid_size, grid_query_frame):
|
||||
video_chunk = (
|
||||
torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
|
||||
.float()
|
||||
.permute(0, 3, 1, 2)[None]
|
||||
) # (1, T, 3, H, W)
|
||||
return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)
|
||||
return model(
|
||||
video_chunk,
|
||||
is_first_step=is_first_step,
|
||||
grid_size=grid_size,
|
||||
grid_query_frame=grid_query_frame,
|
||||
)
|
||||
|
||||
# Iterating over video frames, processing one window at a time:
|
||||
is_first_step = True
|
||||
for i, frame in enumerate(
|
||||
iio.imiter(
|
||||
"https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4",
|
||||
args.video_path,
|
||||
plugin="FFMPEG",
|
||||
)
|
||||
):
|
||||
if i % model.step == 0 and i != 0:
|
||||
pred_tracks, pred_visibility = _process_step(
|
||||
window_frames, is_first_step, grid_size=args.grid_size
|
||||
window_frames,
|
||||
is_first_step,
|
||||
grid_size=args.grid_size,
|
||||
grid_query_frame=args.grid_query_frame,
|
||||
)
|
||||
is_first_step = False
|
||||
window_frames.append(frame)
|
||||
@@ -79,12 +91,13 @@ if __name__ == "__main__":
|
||||
window_frames[-(i % model.step) - model.step - 1 :],
|
||||
is_first_step,
|
||||
grid_size=args.grid_size,
|
||||
grid_query_frame=args.grid_query_frame,
|
||||
)
|
||||
|
||||
print("Tracks are computed")
|
||||
|
||||
# save a video with predicted tracks
|
||||
seq_name = args.video_path.split("/")[-1]
|
||||
seq_name = os.path.splitext(args.video_path.split("/")[-1])[0]
|
||||
video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
|
||||
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
||||
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
|
||||
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame, filename=seq_name)
|
||||
|
Reference in New Issue
Block a user