Initial commit
cotracker/__init__.py  (Normal file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/datasets/__init__.py  (Normal file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/datasets/badja_dataset.py  (Normal file, 390 lines)
							| @@ -0,0 +1,390 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import numpy as np | ||||
| import os | ||||
|  | ||||
| import json | ||||
| import imageio | ||||
| import cv2 | ||||
|  | ||||
| from enum import Enum | ||||
|  | ||||
| from cotracker.datasets.utils import CoTrackerData, resize_sample | ||||
|  | ||||
| IGNORE_ANIMALS = [ | ||||
|     # "bear.json", | ||||
|     # "camel.json", | ||||
|     "cat_jump.json" | ||||
|     # "cows.json", | ||||
|     # "dog.json", | ||||
|     # "dog-agility.json", | ||||
|     # "horsejump-high.json", | ||||
|     # "horsejump-low.json", | ||||
|     # "impala0.json", | ||||
|     # "rs_dog.json" | ||||
|     "tiger.json" | ||||
| ] | ||||
|  | ||||
|  | ||||
| class SMALJointCatalog(Enum): | ||||
|     # body_0 = 0 | ||||
|     # body_1 = 1 | ||||
|     # body_2 = 2 | ||||
|     # body_3 = 3 | ||||
|     # body_4 = 4 | ||||
|     # body_5 = 5 | ||||
|     # body_6 = 6 | ||||
|     # upper_right_0 = 7 | ||||
|     upper_right_1 = 8 | ||||
|     upper_right_2 = 9 | ||||
|     upper_right_3 = 10 | ||||
|     # upper_left_0 = 11 | ||||
|     upper_left_1 = 12 | ||||
|     upper_left_2 = 13 | ||||
|     upper_left_3 = 14 | ||||
|     neck_lower = 15 | ||||
|     # neck_upper = 16 | ||||
|     # lower_right_0 = 17 | ||||
|     lower_right_1 = 18 | ||||
|     lower_right_2 = 19 | ||||
|     lower_right_3 = 20 | ||||
|     # lower_left_0 = 21 | ||||
|     lower_left_1 = 22 | ||||
|     lower_left_2 = 23 | ||||
|     lower_left_3 = 24 | ||||
|     tail_0 = 25 | ||||
|     # tail_1 = 26 | ||||
|     # tail_2 = 27 | ||||
|     tail_3 = 28 | ||||
|     # tail_4 = 29 | ||||
|     # tail_5 = 30 | ||||
|     tail_6 = 31 | ||||
|     jaw = 32 | ||||
|     nose = 33  # ADDED JOINT FOR VERTEX 1863 | ||||
|     # chin = 34 # ADDED JOINT FOR VERTEX 26 | ||||
|     right_ear = 35  # ADDED JOINT FOR VERTEX 149 | ||||
|     left_ear = 36  # ADDED JOINT FOR VERTEX 2124 | ||||
|  | ||||
|  | ||||
| class SMALJointInfo: | ||||
|     def __init__(self): | ||||
|         # These are the SMAL joint indices that carry BADJA annotations | ||||
|         self.annotated_classes = np.array( | ||||
|             [ | ||||
|                 8, | ||||
|                 9, | ||||
|                 10,  # upper_right | ||||
|                 12, | ||||
|                 13, | ||||
|                 14,  # upper_left | ||||
|                 15,  # neck | ||||
|                 18, | ||||
|                 19, | ||||
|                 20,  # lower_right | ||||
|                 22, | ||||
|                 23, | ||||
|                 24,  # lower_left | ||||
|                 25, | ||||
|                 28, | ||||
|                 31,  # tail | ||||
|                 32, | ||||
|                 33,  # head | ||||
|                 35,  # right_ear | ||||
|                 36, | ||||
|             ] | ||||
|         )  # left_ear | ||||
|  | ||||
|         self.annotated_markers = np.array( | ||||
|             [ | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_TRIANGLE_DOWN, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_STAR, | ||||
|                 cv2.MARKER_CROSS, | ||||
|                 cv2.MARKER_CROSS, | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         self.joint_regions = np.array( | ||||
|             [ | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 0, | ||||
|                 1, | ||||
|                 1, | ||||
|                 1, | ||||
|                 1, | ||||
|                 2, | ||||
|                 2, | ||||
|                 2, | ||||
|                 2, | ||||
|                 3, | ||||
|                 3, | ||||
|                 4, | ||||
|                 4, | ||||
|                 4, | ||||
|                 4, | ||||
|                 5, | ||||
|                 5, | ||||
|                 5, | ||||
|                 5, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 6, | ||||
|                 7, | ||||
|                 7, | ||||
|                 7, | ||||
|                 8, | ||||
|                 9, | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         self.annotated_joint_region = self.joint_regions[self.annotated_classes] | ||||
|         self.region_colors = np.array( | ||||
|             [ | ||||
|                 [250, 190, 190],  # body, light pink | ||||
|                 [60, 180, 75],  # upper_right, green | ||||
|                 [230, 25, 75],  # upper_left, red | ||||
|                 [128, 0, 0],  # neck, maroon | ||||
|                 [0, 130, 200],  # lower_right, blue | ||||
|                 [255, 255, 25],  # lower_left, yellow | ||||
|                 [240, 50, 230],  # tail, magenta | ||||
|                 [245, 130, 48],  # jaw / nose / chin, orange | ||||
|                 [29, 98, 115],  # right_ear, turquoise | ||||
|                 [255, 153, 204], | ||||
|             ] | ||||
|         )  # left_ear, pink | ||||
|  | ||||
|         self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region] | ||||
|  | ||||
|  | ||||
| class BADJAData: | ||||
|     def __init__(self, data_root, complete=False): | ||||
|         annotations_path = os.path.join(data_root, "joint_annotations") | ||||
|  | ||||
|         self.animal_dict = {} | ||||
|         self.animal_count = 0 | ||||
|         self.smal_joint_info = SMALJointInfo() | ||||
|         for __, animal_json in enumerate(sorted(os.listdir(annotations_path))): | ||||
|             if animal_json not in IGNORE_ANIMALS: | ||||
|                 json_path = os.path.join(annotations_path, animal_json) | ||||
|                 with open(json_path) as json_data: | ||||
|                     animal_joint_data = json.load(json_data) | ||||
|  | ||||
|                 filenames = [] | ||||
|                 segnames = [] | ||||
|                 joints = [] | ||||
|                 visible = [] | ||||
|  | ||||
|                 first_path = animal_joint_data[0]["segmentation_path"] | ||||
|                 last_path = animal_joint_data[-1]["segmentation_path"] | ||||
|                 first_frame = first_path.split("/")[-1] | ||||
|                 last_frame = last_path.split("/")[-1] | ||||
|  | ||||
|                 if "extra_videos" not in first_path: | ||||
|                     animal = first_path.split("/")[-2] | ||||
|  | ||||
|                     first_frame_int = int(first_frame.split(".")[0]) | ||||
|                     last_frame_int = int(last_frame.split(".")[0]) | ||||
|  | ||||
|                     for fr in range(first_frame_int, last_frame_int + 1): | ||||
|                         ref_file_name = os.path.join( | ||||
|                             data_root, | ||||
|                             "DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg" | ||||
|                             % (animal, fr), | ||||
|                         ) | ||||
|                         ref_seg_name = os.path.join( | ||||
|                             data_root, | ||||
|                             "DAVIS/Annotations/Full-Resolution/%s/%05d.png" | ||||
|                             % (animal, fr), | ||||
|                         ) | ||||
|  | ||||
|                         foundit = False | ||||
|                         for ind, image_annotation in enumerate(animal_joint_data): | ||||
|                             file_name = os.path.join( | ||||
|                                 data_root, image_annotation["image_path"] | ||||
|                             ) | ||||
|                             seg_name = os.path.join( | ||||
|                                 data_root, image_annotation["segmentation_path"] | ||||
|                             ) | ||||
|  | ||||
|                             if file_name == ref_file_name: | ||||
|                                 foundit = True | ||||
|                                 label_ind = ind | ||||
|  | ||||
|                         if foundit: | ||||
|                             image_annotation = animal_joint_data[label_ind] | ||||
|                             file_name = os.path.join( | ||||
|                                 data_root, image_annotation["image_path"] | ||||
|                             ) | ||||
|                             seg_name = os.path.join( | ||||
|                                 data_root, image_annotation["segmentation_path"] | ||||
|                             ) | ||||
|                             joint = np.array(image_annotation["joints"]) | ||||
|                             vis = np.array(image_annotation["visibility"]) | ||||
|                         else: | ||||
|                             file_name = ref_file_name | ||||
|                             seg_name = ref_seg_name | ||||
|                             joint = None | ||||
|                             vis = None | ||||
|  | ||||
|                         filenames.append(file_name) | ||||
|                         segnames.append(seg_name) | ||||
|                         joints.append(joint) | ||||
|                         visible.append(vis) | ||||
|  | ||||
|                 if len(filenames): | ||||
|                     self.animal_dict[self.animal_count] = ( | ||||
|                         filenames, | ||||
|                         segnames, | ||||
|                         joints, | ||||
|                         visible, | ||||
|                     ) | ||||
|                     self.animal_count += 1 | ||||
|         print("Loaded BADJA dataset") | ||||
|  | ||||
|     def get_loader(self): | ||||
|         for __ in range(int(1e6)): | ||||
|             animal_id = np.random.choice(len(self.animal_dict.keys())) | ||||
|             filenames, segnames, joints, visible = self.animal_dict[animal_id] | ||||
|  | ||||
|             image_id = np.random.randint(0, len(filenames)) | ||||
|  | ||||
|             seg_file = segnames[image_id] | ||||
|             image_file = filenames[image_id] | ||||
|  | ||||
|             joints = joints[image_id].copy() | ||||
|             joints = joints[self.smal_joint_info.annotated_classes] | ||||
|             visible = visible[image_id][self.smal_joint_info.annotated_classes] | ||||
|  | ||||
|             rgb_img = imageio.imread(image_file)  # , mode='RGB') | ||||
|             sil_img = imageio.imread(seg_file)  # , mode='RGB') | ||||
|  | ||||
|             rgb_h, rgb_w, _ = rgb_img.shape | ||||
|             sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), interpolation=cv2.INTER_NEAREST) | ||||
|  | ||||
|             yield rgb_img, sil_img, joints, visible, image_file | ||||
|  | ||||
|     def get_video(self, animal_id): | ||||
|         filenames, segnames, joint, visible = self.animal_dict[animal_id] | ||||
|  | ||||
|         rgbs = [] | ||||
|         segs = [] | ||||
|         joints = [] | ||||
|         visibles = [] | ||||
|  | ||||
|         for s in range(len(filenames)): | ||||
|             image_file = filenames[s] | ||||
|             rgb_img = imageio.imread(image_file)  # , mode='RGB') | ||||
|             rgb_h, rgb_w, _ = rgb_img.shape | ||||
|  | ||||
|             seg_file = segnames[s] | ||||
|             sil_img = imageio.imread(seg_file)  # , mode='RGB') | ||||
|             sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), interpolation=cv2.INTER_NEAREST) | ||||
|  | ||||
|             jo = joint[s] | ||||
|  | ||||
|             if jo is not None: | ||||
|                 joi = joint[s].copy() | ||||
|                 joi = joi[self.smal_joint_info.annotated_classes] | ||||
|                 vis = visible[s][self.smal_joint_info.annotated_classes] | ||||
|             else: | ||||
|                 joi = None | ||||
|                 vis = None | ||||
|  | ||||
|             rgbs.append(rgb_img) | ||||
|             segs.append(sil_img) | ||||
|             joints.append(joi) | ||||
|             visibles.append(vis) | ||||
|  | ||||
|         return rgbs, segs, joints, visibles, filenames[0] | ||||
|  | ||||
|  | ||||
| class BadjaDataset(torch.utils.data.Dataset): | ||||
|     def __init__( | ||||
|         self, data_root, max_seq_len=1000, dataset_resolution=(384, 512) | ||||
|     ): | ||||
|  | ||||
|         self.data_root = data_root | ||||
|         self.badja_data = BADJAData(data_root) | ||||
|         self.max_seq_len = max_seq_len | ||||
|         self.dataset_resolution = dataset_resolution | ||||
|         print( | ||||
|             "found %d unique videos in %s" | ||||
|             % (self.badja_data.animal_count, self.data_root) | ||||
|         ) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|  | ||||
|         rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index) | ||||
|         S = len(rgbs) | ||||
|         H, W, __ = rgbs[0].shape | ||||
|         H, W, __ = segs[0].shape | ||||
|  | ||||
|         N, __ = joints[0].shape | ||||
|  | ||||
|         # let's eliminate the Nones | ||||
|         # note the first one is guaranteed present | ||||
|         for s in range(1, S): | ||||
|             if joints[s] is None: | ||||
|                 joints[s] = np.zeros_like(joints[0]) | ||||
|                 visibles[s] = np.zeros_like(visibles[0]) | ||||
|  | ||||
|         # eliminate the mystery dim | ||||
|         segs = [seg[:, :, 0] for seg in segs] | ||||
|  | ||||
|         rgbs = np.stack(rgbs, 0) | ||||
|         segs = np.stack(segs, 0) | ||||
|         trajs = np.stack(joints, 0) | ||||
|         visibles = np.stack(visibles, 0) | ||||
|  | ||||
|         rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float() | ||||
|         segs = torch.from_numpy(segs).reshape(S, 1, H, W).float() | ||||
|         trajs = torch.from_numpy(trajs).reshape(S, N, 2).float() | ||||
|         visibles = torch.from_numpy(visibles).reshape(S, N) | ||||
|  | ||||
|         rgbs = rgbs[: self.max_seq_len] | ||||
|         segs = segs[: self.max_seq_len] | ||||
|         trajs = trajs[: self.max_seq_len] | ||||
|         visibles = visibles[: self.max_seq_len] | ||||
|         # apparently the coords are in yx order | ||||
|         trajs = torch.flip(trajs, [2]) | ||||
|  | ||||
|         if "extra_videos" in filename: | ||||
|             seq_name = filename.split("/")[-3] | ||||
|         else: | ||||
|             seq_name = filename.split("/")[-2] | ||||
|  | ||||
|         rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution) | ||||
|  | ||||
|         return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.badja_data.animal_count | ||||
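
BadjaDataset returns one full CoTrackerData clip per index, so PyTorch's default collation will not work on it. A minimal, hedged usage sketch (the data_root path is a placeholder, and collate_fn comes from cotracker/datasets/utils.py further down in this commit):

import torch
from cotracker.datasets.badja_dataset import BadjaDataset
from cotracker.datasets.utils import collate_fn

# Placeholder path; BADJAData expects joint_annotations/ and DAVIS/ under this root.
eval_dataset = BadjaDataset(data_root="/path/to/BADJA")
eval_loader = torch.utils.data.DataLoader(
    eval_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn
)
sample = next(iter(eval_loader))
print(sample.video.shape)       # (1, S, 3, 384, 512) after resize_sample
print(sample.trajectory.shape)  # (1, S, N, 2); joints are flipped from (y, x) to (x, y)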
							
								
								
									
cotracker/datasets/fast_capture_dataset.py  (Normal file, 72 lines)
							| @@ -0,0 +1,72 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import torch | ||||
|  | ||||
| # from PIL import Image | ||||
| import imageio | ||||
| import numpy as np | ||||
| from cotracker.datasets.utils import CoTrackerData, resize_sample | ||||
|  | ||||
|  | ||||
| class FastCaptureDataset(torch.utils.data.Dataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         data_root, | ||||
|         max_seq_len=50, | ||||
|         max_num_points=20, | ||||
|         dataset_resolution=(384, 512), | ||||
|     ): | ||||
|  | ||||
|         self.data_root = data_root | ||||
|         self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm")) | ||||
|         self.pth_dir = os.path.join(data_root, "zju_tracking") | ||||
|         self.max_seq_len = max_seq_len | ||||
|         self.max_num_points = max_num_points | ||||
|         self.dataset_resolution = dataset_resolution | ||||
|         print("found %d unique videos in %s" % (len(self.seq_names), self.data_root)) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         seq_name = self.seq_names[index] | ||||
|         spath = os.path.join(self.data_root, "renders_local_rm", seq_name) | ||||
|         pthpath = os.path.join(self.pth_dir, seq_name + ".pth") | ||||
|  | ||||
|         rgbs = [] | ||||
|         img_paths = sorted(os.listdir(spath)) | ||||
|         for i, img_path in enumerate(img_paths): | ||||
|             if i < self.max_seq_len: | ||||
|                 rgbs.append(imageio.imread(os.path.join(spath, img_path))) | ||||
|  | ||||
|         annot_dict = torch.load(pthpath) | ||||
|         traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len] | ||||
|         visibility = annot_dict["visibility"][:, : self.max_seq_len] | ||||
|  | ||||
|         S = len(rgbs) | ||||
|         H, W, __ = rgbs[0].shape | ||||
|         *_, S = traj_2d.shape | ||||
|         visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[ | ||||
|             :, 0 | ||||
|         ] | ||||
|         torch.manual_seed(0) | ||||
|         point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[ | ||||
|             : self.max_num_points | ||||
|         ] | ||||
|         visible_inds_sampled = visibile_pts_first_frame_inds[point_inds] | ||||
|  | ||||
|         rgbs = np.stack(rgbs, 0) | ||||
|         rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float() | ||||
|  | ||||
|         segs = torch.ones(S, 1, H, W).float() | ||||
|         trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float() | ||||
|         visibles = visibility[visible_inds_sampled].permute(1, 0) | ||||
|  | ||||
|         rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution) | ||||
|  | ||||
|         return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.seq_names) | ||||
							
								
								
									
cotracker/datasets/kubric_movif_dataset.py  (Normal file, 494 lines)
							| @@ -0,0 +1,494 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import torch | ||||
|  | ||||
| import imageio | ||||
| import numpy as np | ||||
|  | ||||
| from cotracker.datasets.utils import CoTrackerData | ||||
| from torchvision.transforms import ColorJitter, GaussianBlur | ||||
| from PIL import Image | ||||
| import cv2 | ||||
|  | ||||
|  | ||||
| class CoTrackerDataset(torch.utils.data.Dataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         data_root, | ||||
|         crop_size=(384, 512), | ||||
|         seq_len=24, | ||||
|         traj_per_sample=768, | ||||
|         sample_vis_1st_frame=False, | ||||
|         use_augs=False, | ||||
|     ): | ||||
|         super(CoTrackerDataset, self).__init__() | ||||
|         np.random.seed(0) | ||||
|         torch.manual_seed(0) | ||||
|         self.data_root = data_root | ||||
|         self.seq_len = seq_len | ||||
|         self.traj_per_sample = traj_per_sample | ||||
|         self.sample_vis_1st_frame = sample_vis_1st_frame | ||||
|         self.use_augs = use_augs | ||||
|         self.crop_size = crop_size | ||||
|  | ||||
|         # photometric augmentation | ||||
|         self.photo_aug = ColorJitter( | ||||
|             brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14 | ||||
|         ) | ||||
|         self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0)) | ||||
|  | ||||
|         self.blur_aug_prob = 0.25 | ||||
|         self.color_aug_prob = 0.25 | ||||
|  | ||||
|         # occlusion augmentation | ||||
|         self.eraser_aug_prob = 0.5 | ||||
|         self.eraser_bounds = [2, 100] | ||||
|         self.eraser_max = 10 | ||||
|  | ||||
|         # occlusion augmentation | ||||
|         self.replace_aug_prob = 0.5 | ||||
|         self.replace_bounds = [2, 100] | ||||
|         self.replace_max = 10 | ||||
|  | ||||
|         # spatial augmentations | ||||
|         self.pad_bounds = [0, 100] | ||||
|         self.crop_size = crop_size | ||||
|         self.resize_lim = [0.25, 2.0]  # sample resizes from here | ||||
|         self.resize_delta = 0.2 | ||||
|         self.max_crop_offset = 50 | ||||
|  | ||||
|         self.do_flip = True | ||||
|         self.h_flip_prob = 0.5 | ||||
|         self.v_flip_prob = 0.5 | ||||
|  | ||||
|     def getitem_helper(self, index): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         gotit = False | ||||
|  | ||||
|         sample, gotit = self.getitem_helper(index) | ||||
|         if not gotit: | ||||
|             print("warning: sampling failed") | ||||
|             # fake sample, so we can still collate | ||||
|             sample = CoTrackerData( | ||||
|                 video=torch.zeros( | ||||
|                     (self.seq_len, 3, self.crop_size[0], self.crop_size[1]) | ||||
|                 ), | ||||
|                 segmentation=torch.zeros( | ||||
|                     (self.seq_len, 1, self.crop_size[0], self.crop_size[1]) | ||||
|                 ), | ||||
|                 trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)), | ||||
|                 visibility=torch.zeros((self.seq_len, self.traj_per_sample)), | ||||
|                 valid=torch.zeros((self.seq_len, self.traj_per_sample)), | ||||
|             ) | ||||
|  | ||||
|         return sample, gotit | ||||
|  | ||||
|     def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True): | ||||
|         T, N, _ = trajs.shape | ||||
|  | ||||
|         S = len(rgbs) | ||||
|         H, W = rgbs[0].shape[:2] | ||||
|         assert S == T | ||||
|  | ||||
|         if eraser: | ||||
|             ############ eraser transform (per image after the first) ############ | ||||
|             rgbs = [rgb.astype(np.float32) for rgb in rgbs] | ||||
|             for i in range(1, S): | ||||
|                 if np.random.rand() < self.eraser_aug_prob: | ||||
|                     for _ in range( | ||||
|                         np.random.randint(1, self.eraser_max + 1) | ||||
|                     ):  # number of times to occlude | ||||
|  | ||||
|                         xc = np.random.randint(0, W) | ||||
|                         yc = np.random.randint(0, H) | ||||
|                         dx = np.random.randint( | ||||
|                             self.eraser_bounds[0], self.eraser_bounds[1] | ||||
|                         ) | ||||
|                         dy = np.random.randint( | ||||
|                             self.eraser_bounds[0], self.eraser_bounds[1] | ||||
|                         ) | ||||
|                         x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) | ||||
|                         x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) | ||||
|                         y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) | ||||
|                         y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) | ||||
|  | ||||
|                         mean_color = np.mean( | ||||
|                             rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0 | ||||
|                         ) | ||||
|                         rgbs[i][y0:y1, x0:x1, :] = mean_color | ||||
|  | ||||
|                         occ_inds = np.logical_and( | ||||
|                             np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1), | ||||
|                             np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1), | ||||
|                         ) | ||||
|                         visibles[i, occ_inds] = 0 | ||||
|             rgbs = [rgb.astype(np.uint8) for rgb in rgbs] | ||||
|  | ||||
|         if replace: | ||||
|  | ||||
|             rgbs_alt = [ | ||||
|                 np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||
|                 for rgb in rgbs | ||||
|             ] | ||||
|             rgbs_alt = [ | ||||
|                 np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||
|                 for rgb in rgbs_alt | ||||
|             ] | ||||
|  | ||||
|             ############ replace transform (per image after the first) ############ | ||||
|             rgbs = [rgb.astype(np.float32) for rgb in rgbs] | ||||
|             rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt] | ||||
|             for i in range(1, S): | ||||
|                 if np.random.rand() < self.replace_aug_prob: | ||||
|                     for _ in range( | ||||
|                         np.random.randint(1, self.replace_max + 1) | ||||
|                     ):  # number of times to occlude | ||||
|                         xc = np.random.randint(0, W) | ||||
|                         yc = np.random.randint(0, H) | ||||
|                         dx = np.random.randint( | ||||
|                             self.replace_bounds[0], self.replace_bounds[1] | ||||
|                         ) | ||||
|                         dy = np.random.randint( | ||||
|                             self.replace_bounds[0], self.replace_bounds[1] | ||||
|                         ) | ||||
|                         x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32) | ||||
|                         x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32) | ||||
|                         y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32) | ||||
|                         y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32) | ||||
|  | ||||
|                         wid = x1 - x0 | ||||
|                         hei = y1 - y0 | ||||
|                         y00 = np.random.randint(0, H - hei) | ||||
|                         x00 = np.random.randint(0, W - wid) | ||||
|                         fr = np.random.randint(0, S) | ||||
|                         rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :] | ||||
|                         rgbs[i][y0:y1, x0:x1, :] = rep | ||||
|  | ||||
|                         occ_inds = np.logical_and( | ||||
|                             np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1), | ||||
|                             np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1), | ||||
|                         ) | ||||
|                         visibles[i, occ_inds] = 0 | ||||
|             rgbs = [rgb.astype(np.uint8) for rgb in rgbs] | ||||
|  | ||||
|         ############ photometric augmentation ############ | ||||
|         if np.random.rand() < self.color_aug_prob: | ||||
|             # random per-frame amount of aug | ||||
|             rgbs = [ | ||||
|                 np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||
|                 for rgb in rgbs | ||||
|             ] | ||||
|  | ||||
|         if np.random.rand() < self.blur_aug_prob: | ||||
|             # random per-frame amount of blur | ||||
|             rgbs = [ | ||||
|                 np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) | ||||
|                 for rgb in rgbs | ||||
|             ] | ||||
|  | ||||
|         return rgbs, trajs, visibles | ||||
|  | ||||
|     def add_spatial_augs(self, rgbs, trajs, visibles): | ||||
|         T, N, __ = trajs.shape | ||||
|  | ||||
|         S = len(rgbs) | ||||
|         H, W = rgbs[0].shape[:2] | ||||
|         assert S == T | ||||
|  | ||||
|         rgbs = [rgb.astype(np.float32) for rgb in rgbs] | ||||
|  | ||||
|         ############ spatial transform ############ | ||||
|  | ||||
|         # padding | ||||
|         pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||
|         pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||
|         pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||
|         pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1]) | ||||
|  | ||||
|         rgbs = [ | ||||
|             np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs | ||||
|         ] | ||||
|         trajs[:, :, 0] += pad_x0 | ||||
|         trajs[:, :, 1] += pad_y0 | ||||
|         H, W = rgbs[0].shape[:2] | ||||
|  | ||||
|         # scaling + stretching | ||||
|         scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1]) | ||||
|         scale_x = scale | ||||
|         scale_y = scale | ||||
|         H_new = H | ||||
|         W_new = W | ||||
|  | ||||
|         scale_delta_x = 0.0 | ||||
|         scale_delta_y = 0.0 | ||||
|  | ||||
|         rgbs_scaled = [] | ||||
|         for s in range(S): | ||||
|             if s == 1: | ||||
|                 scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta) | ||||
|                 scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta) | ||||
|             elif s > 1: | ||||
|                 scale_delta_x = ( | ||||
|                     scale_delta_x * 0.8 | ||||
|                     + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2 | ||||
|                 ) | ||||
|                 scale_delta_y = ( | ||||
|                     scale_delta_y * 0.8 | ||||
|                     + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2 | ||||
|                 ) | ||||
|             scale_x = scale_x + scale_delta_x | ||||
|             scale_y = scale_y + scale_delta_y | ||||
|  | ||||
|             # bring h/w closer | ||||
|             scale_xy = (scale_x + scale_y) * 0.5 | ||||
|             scale_x = scale_x * 0.5 + scale_xy * 0.5 | ||||
|             scale_y = scale_y * 0.5 + scale_xy * 0.5 | ||||
|  | ||||
|             # don't get too crazy | ||||
|             scale_x = np.clip(scale_x, 0.2, 2.0) | ||||
|             scale_y = np.clip(scale_y, 0.2, 2.0) | ||||
|  | ||||
|             H_new = int(H * scale_y) | ||||
|             W_new = int(W * scale_x) | ||||
|  | ||||
|             # make it at least slightly bigger than the crop area, | ||||
|             # so that the random cropping can add diversity | ||||
|             H_new = np.clip(H_new, self.crop_size[0] + 10, None) | ||||
|             W_new = np.clip(W_new, self.crop_size[1] + 10, None) | ||||
|             # recompute scale in case we clipped | ||||
|             scale_x = W_new / float(W) | ||||
|             scale_y = H_new / float(H) | ||||
|  | ||||
|             rgbs_scaled.append( | ||||
|                 cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR) | ||||
|             ) | ||||
|             trajs[s, :, 0] *= scale_x | ||||
|             trajs[s, :, 1] *= scale_y | ||||
|         rgbs = rgbs_scaled | ||||
|  | ||||
|         ok_inds = visibles[0, :] > 0 | ||||
|         vis_trajs = trajs[:, ok_inds]  # S,?,2 | ||||
|  | ||||
|         if vis_trajs.shape[1] > 0: | ||||
|             mid_x = np.mean(vis_trajs[0, :, 0]) | ||||
|             mid_y = np.mean(vis_trajs[0, :, 1]) | ||||
|         else: | ||||
|             mid_y = self.crop_size[0] | ||||
|             mid_x = self.crop_size[1] | ||||
|  | ||||
|         x0 = int(mid_x - self.crop_size[1] // 2) | ||||
|         y0 = int(mid_y - self.crop_size[0] // 2) | ||||
|  | ||||
|         offset_x = 0 | ||||
|         offset_y = 0 | ||||
|  | ||||
|         for s in range(S): | ||||
|             # on each frame, shift a bit more | ||||
|             if s == 1: | ||||
|                 offset_x = np.random.randint( | ||||
|                     -self.max_crop_offset, self.max_crop_offset | ||||
|                 ) | ||||
|                 offset_y = np.random.randint( | ||||
|                     -self.max_crop_offset, self.max_crop_offset | ||||
|                 ) | ||||
|             elif s > 1: | ||||
|                 offset_x = int( | ||||
|                     offset_x * 0.8 | ||||
|                     + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) | ||||
|                     * 0.2 | ||||
|                 ) | ||||
|                 offset_y = int( | ||||
|                     offset_y * 0.8 | ||||
|                     + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) | ||||
|                     * 0.2 | ||||
|                 ) | ||||
|             x0 = x0 + offset_x | ||||
|             y0 = y0 + offset_y | ||||
|  | ||||
|             H_new, W_new = rgbs[s].shape[:2] | ||||
|             if H_new == self.crop_size[0]: | ||||
|                 y0 = 0 | ||||
|             else: | ||||
|                 y0 = min(max(0, y0), H_new - self.crop_size[0] - 1) | ||||
|  | ||||
|             if W_new == self.crop_size[1]: | ||||
|                 x0 = 0 | ||||
|             else: | ||||
|                 x0 = min(max(0, x0), W_new - self.crop_size[1] - 1) | ||||
|  | ||||
|             rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] | ||||
|             trajs[s, :, 0] -= x0 | ||||
|             trajs[s, :, 1] -= y0 | ||||
|  | ||||
|         H_new = self.crop_size[0] | ||||
|         W_new = self.crop_size[1] | ||||
|  | ||||
|         # flip | ||||
|         h_flipped = False | ||||
|         v_flipped = False | ||||
|         if self.do_flip: | ||||
|             # h flip | ||||
|             if np.random.rand() < self.h_flip_prob: | ||||
|                 h_flipped = True | ||||
|                 rgbs = [rgb[:, ::-1] for rgb in rgbs] | ||||
|             # v flip | ||||
|             if np.random.rand() < self.v_flip_prob: | ||||
|                 v_flipped = True | ||||
|                 rgbs = [rgb[::-1] for rgb in rgbs] | ||||
|         if h_flipped: | ||||
|             trajs[:, :, 0] = W_new - trajs[:, :, 0] | ||||
|         if v_flipped: | ||||
|             trajs[:, :, 1] = H_new - trajs[:, :, 1] | ||||
|  | ||||
|         return rgbs, trajs | ||||
|  | ||||
|     def crop(self, rgbs, trajs): | ||||
|         T, N, _ = trajs.shape | ||||
|  | ||||
|         S = len(rgbs) | ||||
|         H, W = rgbs[0].shape[:2] | ||||
|         assert S == T | ||||
|  | ||||
|         ############ spatial transform ############ | ||||
|  | ||||
|         H_new = H | ||||
|         W_new = W | ||||
|  | ||||
|         # simple random crop | ||||
|         y0 = ( | ||||
|             0 | ||||
|             if self.crop_size[0] >= H_new | ||||
|             else np.random.randint(0, H_new - self.crop_size[0]) | ||||
|         ) | ||||
|         x0 = ( | ||||
|             0 | ||||
|             if self.crop_size[1] >= W_new | ||||
|             else np.random.randint(0, W_new - self.crop_size[1]) | ||||
|         ) | ||||
|         rgbs = [ | ||||
|             rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] | ||||
|             for rgb in rgbs | ||||
|         ] | ||||
|  | ||||
|         trajs[:, :, 0] -= x0 | ||||
|         trajs[:, :, 1] -= y0 | ||||
|  | ||||
|         return rgbs, trajs | ||||
|  | ||||
|  | ||||
| class KubricMovifDataset(CoTrackerDataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         data_root, | ||||
|         crop_size=(384, 512), | ||||
|         seq_len=24, | ||||
|         traj_per_sample=768, | ||||
|         sample_vis_1st_frame=False, | ||||
|         use_augs=False, | ||||
|     ): | ||||
|         super(KubricMovifDataset, self).__init__( | ||||
|             data_root=data_root, | ||||
|             crop_size=crop_size, | ||||
|             seq_len=seq_len, | ||||
|             traj_per_sample=traj_per_sample, | ||||
|             sample_vis_1st_frame=sample_vis_1st_frame, | ||||
|             use_augs=use_augs, | ||||
|         ) | ||||
|  | ||||
|         self.pad_bounds = [0, 25] | ||||
|         self.resize_lim = [0.75, 1.25]  # sample resizes from here | ||||
|         self.resize_delta = 0.05 | ||||
|         self.max_crop_offset = 15 | ||||
|         self.seq_names = [ | ||||
|             fname | ||||
|             for fname in os.listdir(data_root) | ||||
|             if os.path.isdir(os.path.join(data_root, fname)) | ||||
|         ] | ||||
|         print("found %d unique videos in %s" % (len(self.seq_names), self.data_root)) | ||||
|  | ||||
|     def getitem_helper(self, index): | ||||
|         gotit = True | ||||
|         seq_name = self.seq_names[index] | ||||
|  | ||||
|         npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy") | ||||
|         rgb_path = os.path.join(self.data_root, seq_name, "frames") | ||||
|  | ||||
|         img_paths = sorted(os.listdir(rgb_path)) | ||||
|         rgbs = [] | ||||
|         for i, img_path in enumerate(img_paths): | ||||
|             rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path))) | ||||
|  | ||||
|         rgbs = np.stack(rgbs) | ||||
|         annot_dict = np.load(npy_path, allow_pickle=True).item() | ||||
|         traj_2d = annot_dict["coords"] | ||||
|         visibility = annot_dict["visibility"] | ||||
|  | ||||
|         # random crop | ||||
|         assert self.seq_len <= len(rgbs) | ||||
|         if self.seq_len < len(rgbs): | ||||
|             start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0] | ||||
|  | ||||
|             rgbs = rgbs[start_ind : start_ind + self.seq_len] | ||||
|             traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len] | ||||
|             visibility = visibility[:, start_ind : start_ind + self.seq_len] | ||||
|  | ||||
|         traj_2d = np.transpose(traj_2d, (1, 0, 2)) | ||||
|         visibility = np.transpose(np.logical_not(visibility), (1, 0)) | ||||
|         if self.use_augs: | ||||
|             rgbs, traj_2d, visibility = self.add_photometric_augs( | ||||
|                 rgbs, traj_2d, visibility | ||||
|             ) | ||||
|             rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility) | ||||
|         else: | ||||
|             rgbs, traj_2d = self.crop(rgbs, traj_2d) | ||||
|  | ||||
|         visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False | ||||
|         visibility[traj_2d[:, :, 0] < 0] = False | ||||
|         visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False | ||||
|         visibility[traj_2d[:, :, 1] < 0] = False | ||||
|  | ||||
|         visibility = torch.from_numpy(visibility) | ||||
|         traj_2d = torch.from_numpy(traj_2d) | ||||
|  | ||||
|         visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0] | ||||
|  | ||||
|         if self.sample_vis_1st_frame: | ||||
|             visibile_pts_inds = visibile_pts_first_frame_inds | ||||
|         else: | ||||
|             visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero( | ||||
|                 as_tuple=False | ||||
|             )[:, 0] | ||||
|             visibile_pts_inds = torch.cat( | ||||
|                 (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0 | ||||
|             ) | ||||
|         point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample] | ||||
|         if len(point_inds) < self.traj_per_sample: | ||||
|             gotit = False | ||||
|  | ||||
|         visible_inds_sampled = visibile_pts_inds[point_inds] | ||||
|  | ||||
|         trajs = traj_2d[:, visible_inds_sampled].float() | ||||
|         visibles = visibility[:, visible_inds_sampled] | ||||
|         valids = torch.ones((self.seq_len, self.traj_per_sample)) | ||||
|  | ||||
|         rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float() | ||||
|         segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1])) | ||||
|         sample = CoTrackerData( | ||||
|             video=rgbs, | ||||
|             segmentation=segs, | ||||
|             trajectory=trajs, | ||||
|             visibility=visibles, | ||||
|             valid=valids, | ||||
|             seq_name=seq_name, | ||||
|         ) | ||||
|         return sample, gotit | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.seq_names) | ||||
							
								
								
									
cotracker/datasets/tap_vid_datasets.py  (Normal file, 218 lines)
							| @@ -0,0 +1,218 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import io | ||||
| import glob | ||||
| import torch | ||||
| import pickle | ||||
| import numpy as np | ||||
| import mediapy as media | ||||
|  | ||||
| from PIL import Image | ||||
| from typing import Mapping, Tuple, Union | ||||
|  | ||||
| from cotracker.datasets.utils import CoTrackerData | ||||
|  | ||||
| DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]] | ||||
|  | ||||
|  | ||||
| def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: | ||||
|     """Resize a video to output_size.""" | ||||
|     # If you have a GPU, consider replacing this with a GPU-enabled resize op, | ||||
|     # such as a jitted jax.image.resize.  It will make things faster. | ||||
|     return media.resize_video(video, output_size) | ||||
|  | ||||
|  | ||||
| def sample_queries_first( | ||||
|     target_occluded: np.ndarray, | ||||
|     target_points: np.ndarray, | ||||
|     frames: np.ndarray, | ||||
| ) -> Mapping[str, np.ndarray]: | ||||
|     """Package a set of frames and tracks for use in TAPNet evaluations. | ||||
|     Given a set of frames and tracks with no query points, use the first | ||||
|     visible point in each track as the query. | ||||
|     Args: | ||||
|       target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], | ||||
|         where True indicates occluded. | ||||
|       target_points: Position, of shape [n_tracks, n_frames, 2], where each point | ||||
|         is [x,y] scaled between 0 and 1. | ||||
|       frames: Video tensor, of shape [n_frames, height, width, 3].  Scaled between | ||||
|         -1 and 1. | ||||
|     Returns: | ||||
|       A dict with the keys: | ||||
|         video: Video tensor of shape [1, n_frames, height, width, 3] | ||||
|         query_points: Query points of shape [1, n_queries, 3] where | ||||
|           each point is [t, y, x] scaled to the range [-1, 1] | ||||
|         target_points: Target points of shape [1, n_queries, n_frames, 2] where | ||||
|           each point is [x, y] scaled to the range [-1, 1] | ||||
|     """ | ||||
|     valid = np.sum(~target_occluded, axis=1) > 0 | ||||
|     target_points = target_points[valid, :] | ||||
|     target_occluded = target_occluded[valid, :] | ||||
|  | ||||
|     query_points = [] | ||||
|     for i in range(target_points.shape[0]): | ||||
|         index = np.where(target_occluded[i] == 0)[0][0] | ||||
|         x, y = target_points[i, index, 0], target_points[i, index, 1] | ||||
|         query_points.append(np.array([index, y, x]))  # [t, y, x] | ||||
|     query_points = np.stack(query_points, axis=0) | ||||
|  | ||||
|     return { | ||||
|         "video": frames[np.newaxis, ...], | ||||
|         "query_points": query_points[np.newaxis, ...], | ||||
|         "target_points": target_points[np.newaxis, ...], | ||||
|         "occluded": target_occluded[np.newaxis, ...], | ||||
|     } | ||||
|  | ||||
|  | ||||
| def sample_queries_strided( | ||||
|     target_occluded: np.ndarray, | ||||
|     target_points: np.ndarray, | ||||
|     frames: np.ndarray, | ||||
|     query_stride: int = 5, | ||||
| ) -> Mapping[str, np.ndarray]: | ||||
|     """Package a set of frames and tracks for use in TAPNet evaluations. | ||||
|  | ||||
|     Given a set of frames and tracks with no query points, sample queries | ||||
|     strided every query_stride frames, ignoring points that are not visible | ||||
|     at the selected frames. | ||||
|  | ||||
|     Args: | ||||
|       target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], | ||||
|         where True indicates occluded. | ||||
|       target_points: Position, of shape [n_tracks, n_frames, 2], where each point | ||||
|         is [x,y] scaled between 0 and 1. | ||||
|       frames: Video tensor, of shape [n_frames, height, width, 3].  Scaled between | ||||
|         -1 and 1. | ||||
|       query_stride: When sampling query points, search for un-occluded points | ||||
|         every query_stride frames and convert each one into a query. | ||||
|  | ||||
|     Returns: | ||||
|       A dict with the keys: | ||||
|         video: Video tensor of shape [1, n_frames, height, width, 3].  The video | ||||
|           has floats scaled to the range [-1, 1]. | ||||
|         query_points: Query points of shape [1, n_queries, 3] where | ||||
|           each point is [t, y, x] scaled to the range [-1, 1]. | ||||
|         target_points: Target points of shape [1, n_queries, n_frames, 2] where | ||||
|           each point is [x, y] scaled to the range [-1, 1]. | ||||
|         trackgroup: Index of the original track that each query point was | ||||
|           sampled from.  This is useful for visualization. | ||||
|     """ | ||||
|     tracks = [] | ||||
|     occs = [] | ||||
|     queries = [] | ||||
|     trackgroups = [] | ||||
|     total = 0 | ||||
|     trackgroup = np.arange(target_occluded.shape[0]) | ||||
|     for i in range(0, target_occluded.shape[1], query_stride): | ||||
|         mask = target_occluded[:, i] == 0 | ||||
|         query = np.stack( | ||||
|             [ | ||||
|                 i * np.ones(target_occluded.shape[0:1]), | ||||
|                 target_points[:, i, 1], | ||||
|                 target_points[:, i, 0], | ||||
|             ], | ||||
|             axis=-1, | ||||
|         ) | ||||
|         queries.append(query[mask]) | ||||
|         tracks.append(target_points[mask]) | ||||
|         occs.append(target_occluded[mask]) | ||||
|         trackgroups.append(trackgroup[mask]) | ||||
|         total += np.array(np.sum(target_occluded[:, i] == 0)) | ||||
|  | ||||
|     return { | ||||
|         "video": frames[np.newaxis, ...], | ||||
|         "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...], | ||||
|         "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...], | ||||
|         "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...], | ||||
|         "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...], | ||||
|     } | ||||
|  | ||||
|  | ||||
| class TapVidDataset(torch.utils.data.Dataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         data_root, | ||||
|         dataset_type="davis", | ||||
|         resize_to_256=True, | ||||
|         queried_first=True, | ||||
|     ): | ||||
|         self.dataset_type = dataset_type | ||||
|         self.resize_to_256 = resize_to_256 | ||||
|         self.queried_first = queried_first | ||||
|         if self.dataset_type == "kinetics": | ||||
|             all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl")) | ||||
|             points_dataset = [] | ||||
|             for pickle_path in all_paths: | ||||
|                 with open(pickle_path, "rb") as f: | ||||
|                     data = pickle.load(f) | ||||
|                     points_dataset = points_dataset + data | ||||
|             self.points_dataset = points_dataset | ||||
|         else: | ||||
|             with open(data_root, "rb") as f: | ||||
|                 self.points_dataset = pickle.load(f) | ||||
|             if self.dataset_type == "davis": | ||||
|                 self.video_names = list(self.points_dataset.keys()) | ||||
|         print("found %d unique videos in %s" % (len(self.points_dataset), data_root)) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         if self.dataset_type == "davis": | ||||
|             video_name = self.video_names[index] | ||||
|         else: | ||||
|             video_name = index | ||||
|         video = self.points_dataset[video_name] | ||||
|         frames = video["video"] | ||||
|  | ||||
|         if isinstance(frames[0], bytes): | ||||
|             # TAP-Vid is stored as JPEG bytes rather than `np.ndarray`s. | ||||
|             def decode(frame): | ||||
|                 byteio = io.BytesIO(frame) | ||||
|                 img = Image.open(byteio) | ||||
|                 return np.array(img) | ||||
|  | ||||
|             frames = np.array([decode(frame) for frame in frames]) | ||||
|  | ||||
|         target_points = self.points_dataset[video_name]["points"] | ||||
|         if self.resize_to_256: | ||||
|             frames = resize_video(frames, [256, 256]) | ||||
|             target_points *= np.array([256, 256]) | ||||
|         else: | ||||
|             target_points *= np.array([frames.shape[2], frames.shape[1]]) | ||||
|  | ||||
|         T, H, W, C = frames.shape | ||||
|         N, T, D = target_points.shape | ||||
|  | ||||
|         target_occ = self.points_dataset[video_name]["occluded"] | ||||
|         if self.queried_first: | ||||
|             converted = sample_queries_first(target_occ, target_points, frames) | ||||
|         else: | ||||
|             converted = sample_queries_strided(target_occ, target_points, frames) | ||||
|         assert converted["target_points"].shape[1] == converted["query_points"].shape[1] | ||||
|  | ||||
|         trajs = ( | ||||
|             torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() | ||||
|         )  # T, N, D | ||||
|  | ||||
|         rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float() | ||||
|         segs = torch.ones(T, 1, H, W).float() | ||||
|         visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[ | ||||
|             0 | ||||
|         ].permute( | ||||
|             1, 0 | ||||
|         )  # T, N | ||||
|         query_points = torch.from_numpy(converted["query_points"])[0]  # N, 3 as (t, y, x) | ||||
|         return CoTrackerData( | ||||
|             rgbs, | ||||
|             segs, | ||||
|             trajs, | ||||
|             visibles, | ||||
|             seq_name=str(video_name), | ||||
|             query_points=query_points, | ||||
|         ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.points_dataset) | ||||
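
A hedged instantiation sketch for the DAVIS split of TAP-Vid (the pickle path is a placeholder; queried_first toggles between the sample_queries_first and sample_queries_strided protocols defined above):

from cotracker.datasets.tap_vid_datasets import TapVidDataset

tapvid = TapVidDataset(
    data_root="/path/to/tapvid_davis.pkl",  # placeholder path to the TAP-Vid DAVIS pickle
    dataset_type="davis",
    resize_to_256=True,
    queried_first=True,  # False switches to the strided protocol
)
item = tapvid[0]
print(item.video.shape)         # (T, 3, 256, 256)
print(item.query_points.shape)  # (n_queries, 3), each row is (t, y, x)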
							
								
								
									
cotracker/datasets/utils.py  (Normal file, 114 lines)
							| @@ -0,0 +1,114 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
|  | ||||
| import torch | ||||
| import dataclasses | ||||
| import torch.nn.functional as F | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Optional | ||||
|  | ||||
|  | ||||
| @dataclass(eq=False) | ||||
| class CoTrackerData: | ||||
|     """ | ||||
|     Dataclass for storing video tracks data. | ||||
|     """ | ||||
|  | ||||
|     video: torch.Tensor  # B, S, C, H, W | ||||
|     segmentation: torch.Tensor  # B, S, 1, H, W | ||||
|     trajectory: torch.Tensor  # B, S, N, 2 | ||||
|     visibility: torch.Tensor  # B, S, N | ||||
|     # optional data | ||||
|     valid: Optional[torch.Tensor] = None  # B, S, N | ||||
|     seq_name: Optional[str] = None | ||||
|     query_points: Optional[torch.Tensor] = None  # TapVID evaluation format | ||||
|  | ||||
|  | ||||
| def collate_fn(batch): | ||||
|     """ | ||||
|     Collate function for video tracks data. | ||||
|     """ | ||||
|     video = torch.stack([b.video for b in batch], dim=0) | ||||
|     segmentation = torch.stack([b.segmentation for b in batch], dim=0) | ||||
|     trajectory = torch.stack([b.trajectory for b in batch], dim=0) | ||||
|     visibility = torch.stack([b.visibility for b in batch], dim=0) | ||||
|     query_points = None | ||||
|     if batch[0].query_points is not None: | ||||
|         query_points = torch.stack([b.query_points for b in batch], dim=0) | ||||
|     seq_name = [b.seq_name for b in batch] | ||||
|  | ||||
|     return CoTrackerData( | ||||
|         video, | ||||
|         segmentation, | ||||
|         trajectory, | ||||
|         visibility, | ||||
|         seq_name=seq_name, | ||||
|         query_points=query_points, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def collate_fn_train(batch): | ||||
|     """ | ||||
|     Collate function for video tracks data during training. | ||||
|     """ | ||||
|     gotit = [gotit for _, gotit in batch] | ||||
|     video = torch.stack([b.video for b, _ in batch], dim=0) | ||||
|     segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0) | ||||
|     trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0) | ||||
|     visibility = torch.stack([b.visibility for b, _ in batch], dim=0) | ||||
|     valid = torch.stack([b.valid for b, _ in batch], dim=0) | ||||
|     seq_name = [b.seq_name for b, _ in batch] | ||||
|     return ( | ||||
|         CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name), | ||||
|         gotit, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def try_to_cuda(t: Any) -> Any: | ||||
|     """ | ||||
|     Try to move the input variable `t` to a cuda device. | ||||
|  | ||||
|     Args: | ||||
|         t: Input. | ||||
|  | ||||
|     Returns: | ||||
|         t_cuda: `t` moved to a cuda device, if supported. | ||||
|     """ | ||||
|     try: | ||||
|         t = t.float().cuda() | ||||
|     except AttributeError: | ||||
|         pass | ||||
|     return t | ||||
|  | ||||
|  | ||||
| def dataclass_to_cuda_(obj): | ||||
|     """ | ||||
|     Move all contents of a dataclass to cuda inplace if supported. | ||||
|  | ||||
|     Args: | ||||
|         obj: Input dataclass. | ||||
|  | ||||
|     Returns: | ||||
|         obj: `obj` with all tensor fields moved to a cuda device, if supported. | ||||
|     """ | ||||
|     for f in dataclasses.fields(obj): | ||||
|         setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) | ||||
|     return obj | ||||
|  | ||||
|  | ||||
| def resize_sample(rgbs, trajs_g, segs, interp_shape): | ||||
|     S, C, H, W = rgbs.shape | ||||
|     S, N, D = trajs_g.shape | ||||
|  | ||||
|     assert D == 2 | ||||
|  | ||||
|     rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear") | ||||
|     segs = F.interpolate(segs, interp_shape, mode="nearest") | ||||
|  | ||||
|     trajs_g[:, :, 0] *= interp_shape[1] / W | ||||
|     trajs_g[:, :, 1] *= interp_shape[0] / H | ||||
|     return rgbs, trajs_g, segs | ||||
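
A minimal usage sketch of these helpers (editorial addition, not part of the commit): a dataset whose __getitem__ returns a CoTrackerData instance can be batched with collate_fn and moved to the GPU with dataclass_to_cuda_. Here my_track_dataset is a placeholder for any such dataset, e.g. the BADJA or TAP-Vid readers defined elsewhere in this commit.

    import torch
    from cotracker.datasets.utils import collate_fn, dataclass_to_cuda_

    loader = torch.utils.data.DataLoader(
        my_track_dataset,        # placeholder dataset yielding CoTrackerData samples
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,   # stacks video / segmentation / trajectory / visibility
    )
    for batch in loader:
        batch = dataclass_to_cuda_(batch)  # moves every tensor field to a CUDA device (requires a GPU)
        video = batch.video                # B, S, C, H, W
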
							
								
								
									
cotracker/evaluation/__init__.py  (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/evaluation/configs/eval_badja.yaml  (new file, 6 lines)
							| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: badja | ||||
|  | ||||
|     | ||||
							
								
								
									
cotracker/evaluation/configs/eval_fastcapture.yaml  (new file, 6 lines)
							| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: fastcapture | ||||
|  | ||||
|     | ||||
| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: tapvid_davis_first | ||||
|  | ||||
|     | ||||
| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: tapvid_davis_strided | ||||
|  | ||||
|     | ||||
| @@ -0,0 +1,6 @@ | ||||
| defaults: | ||||
|   - default_config_eval | ||||
| exp_dir: ./outputs/cotracker | ||||
| dataset_name: tapvid_kinetics_first | ||||
|  | ||||
|     | ||||
							
								
								
									
cotracker/evaluation/core/__init__.py  (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/evaluation/core/eval_utils.py  (new file, 144 lines)
							| @@ -0,0 +1,144 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from typing import Iterable, Mapping, Tuple, Union | ||||
|  | ||||
|  | ||||
| def compute_tapvid_metrics( | ||||
|     query_points: np.ndarray, | ||||
|     gt_occluded: np.ndarray, | ||||
|     gt_tracks: np.ndarray, | ||||
|     pred_occluded: np.ndarray, | ||||
|     pred_tracks: np.ndarray, | ||||
|     query_mode: str, | ||||
| ) -> Mapping[str, np.ndarray]: | ||||
|     """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.) | ||||
|     See the TAP-Vid paper for details on the metric computation.  All inputs are | ||||
|     given in raster coordinates.  The first three arguments should be the direct | ||||
|     outputs of the reader: the 'query_points', 'occluded', and 'target_points'. | ||||
|     The paper metrics assume these are scaled relative to 256x256 images. | ||||
|     pred_occluded and pred_tracks are your algorithm's predictions. | ||||
|     This function takes a batch of inputs, and computes metrics separately for | ||||
|     each video.  The metrics for the full benchmark are a simple mean of the | ||||
|     metrics across the full set of videos.  These numbers are between 0 and 1, | ||||
|     but the paper multiplies them by 100 to ease reading. | ||||
|     Args: | ||||
|        query_points: The query points, in the format [t, y, x].  Its size is | ||||
|          [b, n, 3], where b is the batch size and n is the number of queries | ||||
|        gt_occluded: A boolean array of shape [b, n, t], where t is the number | ||||
|          of frames.  True indicates that the point is occluded. | ||||
|        gt_tracks: The target points, of shape [b, n, t, 2].  Each point is | ||||
|          in the format [x, y] | ||||
|        pred_occluded: A boolean array of predicted occlusions, in the same | ||||
|          format as gt_occluded. | ||||
|        pred_tracks: An array of track predictions from your algorithm, in the | ||||
|          same format as gt_tracks. | ||||
|        query_mode: Either 'first' or 'strided', depending on how queries are | ||||
|          sampled.  If 'first', we assume the prior knowledge that all points | ||||
|          before the query point are occluded, and these are removed from the | ||||
|          evaluation. | ||||
|     Returns: | ||||
|         A dict with the following keys: | ||||
|         occlusion_accuracy: Accuracy at predicting occlusion. | ||||
|         pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points | ||||
|           predicted to be within the given pixel threshold, ignoring occlusion | ||||
|           prediction. | ||||
|         jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given | ||||
|           threshold | ||||
|         average_pts_within_thresh: average across pts_within_{x} | ||||
|         average_jaccard: average across jaccard_{x} | ||||
|     """ | ||||
|  | ||||
|     metrics = {} | ||||
|  | ||||
|     # Don't evaluate the query point.  Numpy doesn't have one_hot, so we | ||||
|     # replicate it by indexing into an identity matrix. | ||||
|     one_hot_eye = np.eye(gt_tracks.shape[2]) | ||||
|     query_frame = query_points[..., 0] | ||||
|     query_frame = np.round(query_frame).astype(np.int32) | ||||
|     evaluation_points = one_hot_eye[query_frame] == 0 | ||||
|  | ||||
|     # If we're using the first point on the track as a query, don't evaluate the | ||||
|     # other points. | ||||
|     if query_mode == "first": | ||||
|         for i in range(gt_occluded.shape[0]): | ||||
|             index = np.where(gt_occluded[i] == 0)[0][0] | ||||
|             evaluation_points[i, :index] = False | ||||
|     elif query_mode != "strided": | ||||
|         raise ValueError("Unknown query mode " + query_mode) | ||||
|  | ||||
|     # Occlusion accuracy is simply how often the predicted occlusion equals the | ||||
|     # ground truth. | ||||
|     occ_acc = ( | ||||
|         np.sum( | ||||
|             np.equal(pred_occluded, gt_occluded) & evaluation_points, | ||||
|             axis=(1, 2), | ||||
|         ) | ||||
|         / np.sum(evaluation_points, axis=(1, 2)) | ||||
|     ) | ||||
|     metrics["occlusion_accuracy"] = occ_acc | ||||
|  | ||||
|     # Next, convert the predictions and ground truth positions into pixel | ||||
|     # coordinates. | ||||
|     visible = np.logical_not(gt_occluded) | ||||
|     pred_visible = np.logical_not(pred_occluded) | ||||
|     all_frac_within = [] | ||||
|     all_jaccard = [] | ||||
|     for thresh in [1, 2, 4, 8, 16]: | ||||
|         # True positives are points that are within the threshold and where both | ||||
|         # the prediction and the ground truth are listed as visible. | ||||
|         within_dist = ( | ||||
|             np.sum( | ||||
|                 np.square(pred_tracks - gt_tracks), | ||||
|                 axis=-1, | ||||
|             ) | ||||
|             < np.square(thresh) | ||||
|         ) | ||||
|         is_correct = np.logical_and(within_dist, visible) | ||||
|  | ||||
|         # Compute the frac_within_threshold, which is the fraction of points | ||||
|         # within the threshold among points that are visible in the ground truth, | ||||
|         # ignoring whether they're predicted to be visible. | ||||
|         count_correct = np.sum( | ||||
|             is_correct & evaluation_points, | ||||
|             axis=(1, 2), | ||||
|         ) | ||||
|         count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2)) | ||||
|         frac_correct = count_correct / count_visible_points | ||||
|         metrics["pts_within_" + str(thresh)] = frac_correct | ||||
|         all_frac_within.append(frac_correct) | ||||
|  | ||||
|         true_positives = np.sum( | ||||
|             is_correct & pred_visible & evaluation_points, axis=(1, 2) | ||||
|         ) | ||||
|  | ||||
|         # The denominator of the jaccard metric is the true positives plus | ||||
|         # false positives plus false negatives.  However, note that true positives | ||||
|         # plus false negatives is simply the number of points in the ground truth | ||||
|         # which is easier to compute than trying to compute all three quantities. | ||||
|         # Thus we just add the number of points in the ground truth to the number | ||||
|         # of false positives. | ||||
|         # | ||||
|         # False positives are simply points that are predicted to be visible, | ||||
|         # but the ground truth is not visible or too far from the prediction. | ||||
|         gt_positives = np.sum(visible & evaluation_points, axis=(1, 2)) | ||||
|         false_positives = (~visible) & pred_visible | ||||
|         false_positives = false_positives | ((~within_dist) & pred_visible) | ||||
|         false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2)) | ||||
|         jaccard = true_positives / (gt_positives + false_positives) | ||||
|         metrics["jaccard_" + str(thresh)] = jaccard | ||||
|         all_jaccard.append(jaccard) | ||||
|     metrics["average_jaccard"] = np.mean( | ||||
|         np.stack(all_jaccard, axis=1), | ||||
|         axis=1, | ||||
|     ) | ||||
|     metrics["average_pts_within_thresh"] = np.mean( | ||||
|         np.stack(all_frac_within, axis=1), | ||||
|         axis=1, | ||||
|     ) | ||||
|     return metrics | ||||
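
A quick sanity check of the interface (editorial sketch, assuming compute_tapvid_metrics above is in scope): with predictions identical to the ground truth and no occlusions, every reported metric should come out at 1.0 for the single toy video.

    # Toy example; points are assumed to already live in 256x256 raster coordinates.
    import numpy as np

    b, n, t = 1, 4, 10
    query_points = np.zeros((b, n, 3))            # one (t, y, x) query per track
    gt_tracks = np.random.rand(b, n, t, 2) * 256  # (x, y) positions per frame
    gt_occluded = np.zeros((b, n, t), dtype=bool)

    metrics = compute_tapvid_metrics(
        query_points,
        gt_occluded,
        gt_tracks,
        pred_occluded=gt_occluded.copy(),   # predictions identical to ground truth
        pred_tracks=gt_tracks.copy(),
        query_mode="first",
    )
    print(metrics["average_jaccard"], metrics["average_pts_within_thresh"])  # both arrays of 1.0
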
							
								
								
									
cotracker/evaluation/core/evaluator.py  (new file, 252 lines)
							| @@ -0,0 +1,252 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| from collections import defaultdict | ||||
| import os | ||||
| from typing import Optional | ||||
| import torch | ||||
| from tqdm import tqdm | ||||
| import numpy as np | ||||
|  | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| from cotracker.datasets.utils import dataclass_to_cuda_ | ||||
| from cotracker.utils.visualizer import Visualizer | ||||
| from cotracker.models.core.model_utils import reduce_masked_mean | ||||
| from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics | ||||
|  | ||||
| import logging | ||||
|  | ||||
|  | ||||
| class Evaluator: | ||||
|     """ | ||||
|     A class defining the CoTracker evaluator. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, exp_dir) -> None: | ||||
|         # Visualization | ||||
|         self.exp_dir = exp_dir | ||||
|         os.makedirs(exp_dir, exist_ok=True) | ||||
|         self.visualization_filepaths = defaultdict(lambda: defaultdict(list)) | ||||
|         self.visualize_dir = os.path.join(exp_dir, "visualisations") | ||||
|  | ||||
|     def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name): | ||||
|         if isinstance(pred_trajectory, tuple): | ||||
|             pred_trajectory, pred_visibility = pred_trajectory | ||||
|         else: | ||||
|             pred_visibility = None | ||||
|         if dataset_name == "badja": | ||||
|             sample.segmentation = (sample.segmentation > 0).float() | ||||
|             *_, N, _ = sample.trajectory.shape | ||||
|             accs = [] | ||||
|             accs_3px = [] | ||||
|             for s1 in range(1, sample.video.shape[1]):  # target frame | ||||
|                 for n in range(N): | ||||
|                     vis = sample.visibility[0, s1, n] | ||||
|                     if vis > 0: | ||||
|                         coord_e = pred_trajectory[0, s1, n]  # 2 | ||||
|                         coord_g = sample.trajectory[0, s1, n]  # 2 | ||||
|                         dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0)) | ||||
|                         area = torch.sum(sample.segmentation[0, s1]) | ||||
|                         # print_('0.2*sqrt(area)', 0.2*torch.sqrt(area)) | ||||
|                         thr = 0.2 * torch.sqrt(area) | ||||
|                         # correct = | ||||
|                         accs.append((dist < thr).float()) | ||||
|                         # print('thr',thr) | ||||
|                         accs_3px.append((dist < 3.0).float()) | ||||
|  | ||||
|             res = torch.mean(torch.stack(accs)) * 100.0 | ||||
|             res_3px = torch.mean(torch.stack(accs_3px)) * 100.0 | ||||
|             metrics[sample.seq_name[0]] = res.item() | ||||
|             metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item() | ||||
|             print(metrics) | ||||
|             print( | ||||
|                 "avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k]) | ||||
|             ) | ||||
|             print( | ||||
|                 "avg acc 3px", | ||||
|                 np.mean([v for k, v in metrics.items() if "accuracy" in k]), | ||||
|             ) | ||||
|         elif dataset_name == "fastcapture" or ("kubric" in dataset_name): | ||||
|             *_, N, _ = sample.trajectory.shape | ||||
|             accs = [] | ||||
|             for s1 in range(1, sample.video.shape[1]):  # target frame | ||||
|                 for n in range(N): | ||||
|                     vis = sample.visibility[0, s1, n] | ||||
|                     if vis > 0: | ||||
|                         coord_e = pred_trajectory[0, s1, n]  # 2 | ||||
|                         coord_g = sample.trajectory[0, s1, n]  # 2 | ||||
|                         dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0)) | ||||
|                         thr = 3 | ||||
|                         correct = (dist < thr).float() | ||||
|                         accs.append(correct) | ||||
|  | ||||
|             res = torch.mean(torch.stack(accs)) * 100.0 | ||||
|             metrics[sample.seq_name[0] + "_accuracy"] = res.item() | ||||
|             print(metrics) | ||||
|             print("avg", np.mean([v for v in metrics.values()])) | ||||
|         elif "tapvid" in dataset_name: | ||||
|             B, T, N, D = sample.trajectory.shape | ||||
|             traj = sample.trajectory.clone() | ||||
|             thr = 0.9 | ||||
|  | ||||
|             if pred_visibility is None: | ||||
|                 logging.warning("visibility is NONE") | ||||
|                 pred_visibility = torch.zeros_like(sample.visibility) | ||||
|  | ||||
|             if not pred_visibility.dtype == torch.bool: | ||||
|                 pred_visibility = pred_visibility > thr | ||||
|  | ||||
|             # pred_trajectory | ||||
|             query_points = sample.query_points.clone().cpu().numpy() | ||||
|  | ||||
|             pred_visibility = pred_visibility[:, :, :N] | ||||
|             pred_trajectory = pred_trajectory[:, :, :N] | ||||
|  | ||||
|             gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy() | ||||
|             gt_occluded = ( | ||||
|                 torch.logical_not(sample.visibility.clone().permute(0, 2, 1)) | ||||
|                 .cpu() | ||||
|                 .numpy() | ||||
|             ) | ||||
|  | ||||
|             pred_occluded = ( | ||||
|                 torch.logical_not(pred_visibility.clone().permute(0, 2, 1)) | ||||
|                 .cpu() | ||||
|                 .numpy() | ||||
|             ) | ||||
|             pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy() | ||||
|  | ||||
|             out_metrics = compute_tapvid_metrics( | ||||
|                 query_points, | ||||
|                 gt_occluded, | ||||
|                 gt_tracks, | ||||
|                 pred_occluded, | ||||
|                 pred_tracks, | ||||
|                 query_mode="strided" if "strided" in dataset_name else "first", | ||||
|             ) | ||||
|  | ||||
|             metrics[sample.seq_name[0]] = out_metrics | ||||
|             for metric_name in out_metrics.keys(): | ||||
|                 if "avg" not in metrics: | ||||
|                     metrics["avg"] = {} | ||||
|                 metrics["avg"][metric_name] = np.mean( | ||||
|                     [v[metric_name] for k, v in metrics.items() if k != "avg"] | ||||
|                 ) | ||||
|  | ||||
|             logging.info(f"Metrics: {out_metrics}") | ||||
|             logging.info(f"avg: {metrics['avg']}") | ||||
|             print("metrics", out_metrics) | ||||
|             print("avg", metrics["avg"]) | ||||
|         else: | ||||
|             rgbs = sample.video | ||||
|             trajs_g = sample.trajectory | ||||
|             valids = sample.valid | ||||
|             vis_g = sample.visibility | ||||
|  | ||||
|             B, S, C, H, W = rgbs.shape | ||||
|             assert C == 3 | ||||
|             B, S, N, D = trajs_g.shape | ||||
|  | ||||
|             assert torch.sum(valids) == B * S * N | ||||
|  | ||||
|             vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1) | ||||
|  | ||||
|             ate = torch.norm(pred_trajectory - trajs_g, dim=-1)  # B, S, N | ||||
|  | ||||
|             metrics["things_all"] = reduce_masked_mean(ate, valids).item() | ||||
|             metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item() | ||||
|             metrics["things_occ"] = reduce_masked_mean( | ||||
|                 ate, valids * (1.0 - vis_g) | ||||
|             ).item() | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def evaluate_sequence( | ||||
|         self, | ||||
|         model, | ||||
|         test_dataloader: torch.utils.data.DataLoader, | ||||
|         dataset_name: str, | ||||
|         train_mode=False, | ||||
|         writer: Optional[SummaryWriter] = None, | ||||
|         step: Optional[int] = 0, | ||||
|     ): | ||||
|         metrics = {} | ||||
|  | ||||
|         vis = Visualizer( | ||||
|             save_dir=self.exp_dir, | ||||
|             fps=7, | ||||
|         ) | ||||
|  | ||||
|         for ind, sample in enumerate(tqdm(test_dataloader)): | ||||
|             if isinstance(sample, tuple): | ||||
|                 sample, gotit = sample | ||||
|                 if not all(gotit): | ||||
|                     print("batch is None") | ||||
|                     continue | ||||
|             dataclass_to_cuda_(sample) | ||||
|  | ||||
|             if ( | ||||
|                 not train_mode | ||||
|                 and hasattr(model, "sequence_len") | ||||
|                 and (sample.visibility[:, : model.sequence_len].sum() == 0) | ||||
|             ): | ||||
|                 print(f"skipping batch {ind}") | ||||
|                 continue | ||||
|  | ||||
|             if "tapvid" in dataset_name: | ||||
|                 queries = sample.query_points.clone().float() | ||||
|  | ||||
|                 queries = torch.stack( | ||||
|                     [ | ||||
|                         queries[:, :, 0], | ||||
|                         queries[:, :, 2], | ||||
|                         queries[:, :, 1], | ||||
|                     ], | ||||
|                     dim=2, | ||||
|                 ) | ||||
|             else: | ||||
|                 queries = torch.cat( | ||||
|                     [ | ||||
|                         torch.zeros_like(sample.trajectory[:, 0, :, :1]), | ||||
|                         sample.trajectory[:, 0], | ||||
|                     ], | ||||
|                     dim=2, | ||||
|                 ) | ||||
|  | ||||
|             pred_tracks = model(sample.video, queries) | ||||
|             if "strided" in dataset_name: | ||||
|  | ||||
|                 inv_video = sample.video.flip(1).clone() | ||||
|                 inv_queries = queries.clone() | ||||
|                 inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 | ||||
|  | ||||
|                 pred_trj, pred_vsb = pred_tracks | ||||
|                 inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries) | ||||
|  | ||||
|                 inv_pred_trj = inv_pred_trj.flip(1) | ||||
|                 inv_pred_vsb = inv_pred_vsb.flip(1) | ||||
|  | ||||
|                 mask = pred_trj == 0 | ||||
|  | ||||
|                 pred_trj[mask] = inv_pred_trj[mask] | ||||
|                 pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]] | ||||
|  | ||||
|                 pred_tracks = pred_trj, pred_vsb | ||||
|  | ||||
|             if dataset_name == "badja" or dataset_name == "fastcapture": | ||||
|                 seq_name = sample.seq_name[0] | ||||
|             else: | ||||
|                 seq_name = str(ind) | ||||
|  | ||||
|             vis.visualize( | ||||
|                 sample.video, | ||||
|                 pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks, | ||||
|                 filename=dataset_name + "_" + seq_name, | ||||
|                 writer=writer, | ||||
|                 step=step, | ||||
|             ) | ||||
|  | ||||
|             self.compute_metrics(metrics, sample, pred_tracks, dataset_name) | ||||
|         return metrics | ||||
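
To make the two query layouts used in evaluate_sequence concrete, here is a small illustration (editorial, not part of the commit): TAP-Vid query points arrive as (t, y, x) and are reordered to (t, x, y) before being passed to the model, while the other datasets build queries from the first-frame ground-truth positions with a time index of zero.

    import torch

    # TAP-Vid: sample.query_points is B, N, 3 in (t, y, x); reorder to (t, x, y).
    tapvid_queries = torch.tensor([[[3.0, 120.0, 200.0]]])       # one query at frame 3, y=120, x=200
    tapvid_queries = torch.stack(
        [tapvid_queries[:, :, 0], tapvid_queries[:, :, 2], tapvid_queries[:, :, 1]], dim=2
    )  # -> (3, 200, 120)

    # Other datasets: prepend t=0 to the first-frame (x, y) ground-truth positions.
    first_frame_xy = torch.tensor([[[200.0, 120.0]]])            # B, N, 2
    other_queries = torch.cat(
        [torch.zeros_like(first_frame_xy[:, :, :1]), first_frame_xy], dim=2
    )  # B, N, 3 as (0, x, y)
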
							
								
								
									
cotracker/evaluation/evaluate.py  (new file, 179 lines)
							| @@ -0,0 +1,179 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import json | ||||
| import os | ||||
| from dataclasses import dataclass, field | ||||
|  | ||||
| import hydra | ||||
| import numpy as np | ||||
|  | ||||
| import torch | ||||
| from omegaconf import OmegaConf | ||||
|  | ||||
| from cotracker.datasets.badja_dataset import BadjaDataset | ||||
| from cotracker.datasets.fast_capture_dataset import FastCaptureDataset | ||||
| from cotracker.datasets.tap_vid_datasets import TapVidDataset | ||||
| from cotracker.datasets.utils import collate_fn | ||||
|  | ||||
| from cotracker.models.evaluation_predictor import EvaluationPredictor | ||||
|  | ||||
| from cotracker.evaluation.core.evaluator import Evaluator | ||||
| from cotracker.models.build_cotracker import ( | ||||
|     build_cotracker, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass(eq=False) | ||||
| class DefaultConfig: | ||||
|     # Directory where all outputs of the experiment will be saved. | ||||
|     exp_dir: str = "./outputs" | ||||
|  | ||||
|     # Name of the dataset to be used for the evaluation. | ||||
|     dataset_name: str = "badja" | ||||
|     # The root directory of the dataset. | ||||
|     dataset_root: str = "./" | ||||
|  | ||||
|     # Path to the pre-trained model checkpoint to be used for the evaluation. | ||||
|     # The default value is the path to a specific CoTracker model checkpoint. | ||||
|     # Other available options are commented. | ||||
|     checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth" | ||||
|     # cotracker_stride_4_wind_12 | ||||
|     # cotracker_stride_8_wind_16 | ||||
|  | ||||
|     # EvaluationPredictor parameters | ||||
|     # The size (N) of the support grid used in the predictor. | ||||
|     # The total number of points is (N*N). | ||||
|     grid_size: int = 6 | ||||
|     # The size (N) of the local support grid. | ||||
|     local_grid_size: int = 6 | ||||
|     # A flag indicating whether to evaluate one ground truth point at a time. | ||||
|     single_point: bool = True | ||||
|     # The number of iterative updates for each sliding window. | ||||
|     n_iters: int = 6 | ||||
|  | ||||
|     seed: int = 0 | ||||
|     gpu_idx: int = 0 | ||||
|  | ||||
|     # Override hydra's working directory to current working dir, | ||||
|     # also disable storing the .hydra logs: | ||||
|     hydra: dict = field( | ||||
|         default_factory=lambda: { | ||||
|             "run": {"dir": "."}, | ||||
|             "output_subdir": None, | ||||
|         } | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def run_eval(cfg: DefaultConfig): | ||||
|     """ | ||||
|     The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration. | ||||
|  | ||||
|     Args: | ||||
|         cfg (DefaultConfig): An instance of DefaultConfig class which includes: | ||||
|             - exp_dir (str): The directory path for the experiment. | ||||
|             - dataset_name (str): The name of the dataset to be used. | ||||
|             - dataset_root (str): The root directory of the dataset. | ||||
|             - checkpoint (str): The path to the CoTracker model's checkpoint. | ||||
|             - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time. | ||||
|             - n_iters (int): The number of iterative updates for each sliding window. | ||||
|             - seed (int): The seed for setting the random state for reproducibility. | ||||
|             - gpu_idx (int): The index of the GPU to be used. | ||||
|     """ | ||||
|     # Creating the experiment directory if it doesn't exist | ||||
|     os.makedirs(cfg.exp_dir, exist_ok=True) | ||||
|  | ||||
|     # Saving the experiment configuration to a .yaml file in the experiment directory | ||||
|     cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml") | ||||
|     with open(cfg_file, "w") as f: | ||||
|         OmegaConf.save(config=cfg, f=f) | ||||
|  | ||||
|     evaluator = Evaluator(cfg.exp_dir) | ||||
|     cotracker_model = build_cotracker(cfg.checkpoint) | ||||
|  | ||||
|     # Creating the EvaluationPredictor object | ||||
|     predictor = EvaluationPredictor( | ||||
|         cotracker_model, | ||||
|         grid_size=cfg.grid_size, | ||||
|         local_grid_size=cfg.local_grid_size, | ||||
|         single_point=cfg.single_point, | ||||
|         n_iters=cfg.n_iters, | ||||
|     ) | ||||
|  | ||||
|     # Setting the random seeds | ||||
|     torch.manual_seed(cfg.seed) | ||||
|     np.random.seed(cfg.seed) | ||||
|  | ||||
|     # Constructing the specified dataset | ||||
|     curr_collate_fn = collate_fn | ||||
|     if cfg.dataset_name == "badja": | ||||
|         test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA")) | ||||
|     elif cfg.dataset_name == "fastcapture": | ||||
|         test_dataset = FastCaptureDataset( | ||||
|             data_root=os.path.join(cfg.dataset_root, "fastcapture"), | ||||
|             max_seq_len=100, | ||||
|             max_num_points=20, | ||||
|         ) | ||||
|     elif "tapvid" in cfg.dataset_name: | ||||
|         dataset_type = cfg.dataset_name.split("_")[1] | ||||
|         if dataset_type == "davis": | ||||
|             data_root = os.path.join(cfg.dataset_root, "tapvid_davis/tapvid_davis.pkl") | ||||
|         elif dataset_type == "kinetics": | ||||
|             data_root = os.path.join( | ||||
|                 cfg.dataset_root, "kinetics/kinetics-dataset/k700-2020/tapvid_kinetics" | ||||
|             ) | ||||
|         test_dataset = TapVidDataset( | ||||
|             dataset_type=dataset_type, | ||||
|             data_root=data_root, | ||||
|             queried_first="strided" not in cfg.dataset_name, | ||||
|         ) | ||||
|  | ||||
|     # Creating the DataLoader object | ||||
|     test_dataloader = torch.utils.data.DataLoader( | ||||
|         test_dataset, | ||||
|         batch_size=1, | ||||
|         shuffle=False, | ||||
|         num_workers=14, | ||||
|         collate_fn=curr_collate_fn, | ||||
|     ) | ||||
|  | ||||
|     # Timing and conducting the evaluation | ||||
|     import time | ||||
|  | ||||
|     start = time.time() | ||||
|     evaluate_result = evaluator.evaluate_sequence( | ||||
|         predictor, | ||||
|         test_dataloader, | ||||
|         dataset_name=cfg.dataset_name, | ||||
|     ) | ||||
|     end = time.time() | ||||
|     print(end - start) | ||||
|  | ||||
|     # Saving the evaluation results to a .json file | ||||
|     if "tapvid" not in cfg.dataset_name: | ||||
|         print("evaluate_result", evaluate_result) | ||||
|     else: | ||||
|         evaluate_result = evaluate_result["avg"] | ||||
|     result_file = os.path.join(cfg.exp_dir, "result_eval_.json") | ||||
|     evaluate_result["time"] = end - start | ||||
|     print(f"Dumping eval results to {result_file}.") | ||||
|     with open(result_file, "w") as f: | ||||
|         json.dump(evaluate_result, f) | ||||
|  | ||||
|  | ||||
| cs = hydra.core.config_store.ConfigStore.instance() | ||||
| cs.store(name="default_config_eval", node=DefaultConfig) | ||||
|  | ||||
|  | ||||
| @hydra.main(config_path="./configs/", config_name="default_config_eval") | ||||
| def evaluate(cfg: DefaultConfig) -> None: | ||||
|     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | ||||
|     os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx) | ||||
|     run_eval(cfg) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     evaluate() | ||||
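
For reference, a hedged sketch of driving this evaluation programmatically rather than through the hydra entry point; the paths below are placeholders and the call assumes a CUDA-capable GPU plus a downloaded checkpoint and dataset.

    from cotracker.evaluation.evaluate import DefaultConfig, run_eval

    cfg = DefaultConfig(
        exp_dir="./outputs/badja_eval",                             # placeholder output directory
        dataset_name="badja",
        dataset_root="./data",                                      # expects ./data/BADJA
        checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth",
    )
    run_eval(cfg)  # writes expconfig.yaml and result_eval_.json into exp_dir
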
							
								
								
									
cotracker/models/__init__.py  (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/models/build_cotracker.py  (new file, 70 lines)
							| @@ -0,0 +1,70 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from cotracker.models.core.cotracker.cotracker import CoTracker | ||||
|  | ||||
|  | ||||
| def build_cotracker( | ||||
|     checkpoint: str, | ||||
| ): | ||||
|     model_name = checkpoint.split("/")[-1].split(".")[0] | ||||
|     if model_name == "cotracker_stride_4_wind_8": | ||||
|         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint) | ||||
|     elif model_name == "cotracker_stride_4_wind_12": | ||||
|         return build_cotracker_stride_4_wind_12(checkpoint=checkpoint) | ||||
|     elif model_name == "cotracker_stride_8_wind_16": | ||||
|         return build_cotracker_stride_8_wind_16(checkpoint=checkpoint) | ||||
|     else: | ||||
|         raise ValueError(f"Unknown model name {model_name}") | ||||
|  | ||||
|  | ||||
| # model used to produce the results in the paper | ||||
| def build_cotracker_stride_4_wind_8(checkpoint=None): | ||||
|     return _build_cotracker( | ||||
|         stride=4, | ||||
|         sequence_len=8, | ||||
|         checkpoint=checkpoint, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def build_cotracker_stride_4_wind_12(checkpoint=None): | ||||
|     return _build_cotracker( | ||||
|         stride=4, | ||||
|         sequence_len=12, | ||||
|         checkpoint=checkpoint, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| # the fastest model | ||||
| def build_cotracker_stride_8_wind_16(checkpoint=None): | ||||
|     return _build_cotracker( | ||||
|         stride=8, | ||||
|         sequence_len=16, | ||||
|         checkpoint=checkpoint, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def _build_cotracker( | ||||
|     stride, | ||||
|     sequence_len, | ||||
|     checkpoint=None, | ||||
| ): | ||||
|     cotracker = CoTracker( | ||||
|         stride=stride, | ||||
|         S=sequence_len, | ||||
|         add_space_attn=True, | ||||
|         space_depth=6, | ||||
|         time_depth=6, | ||||
|     ) | ||||
|     if checkpoint is not None: | ||||
|         with open(checkpoint, "rb") as f: | ||||
|             state_dict = torch.load(f, map_location="cpu") | ||||
|             if "model" in state_dict: | ||||
|                 state_dict = state_dict["model"] | ||||
|         cotracker.load_state_dict(state_dict) | ||||
|     return cotracker | ||||
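
A short usage sketch (editorial, with a placeholder checkpoint path): the builder dispatches on the checkpoint file name, so the file must be named after one of the three variants above or build_cotracker raises ValueError.

    import torch
    from cotracker.models.build_cotracker import build_cotracker

    model = build_cotracker("./checkpoints/cotracker_stride_4_wind_8.pth")  # placeholder path
    model = model.eval()
    print(sum(p.numel() for p in model.parameters()))  # rough parameter-count check
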
							
								
								
									
cotracker/models/core/__init__.py  (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/models/core/cotracker/__init__.py  (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/models/core/cotracker/blocks.py  (new file, 400 lines)
							| @@ -0,0 +1,400 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from einops import rearrange | ||||
| from timm.models.vision_transformer import Attention, Mlp | ||||
|  | ||||
|  | ||||
| class ResidualBlock(nn.Module): | ||||
|     def __init__(self, in_planes, planes, norm_fn="group", stride=1): | ||||
|         super(ResidualBlock, self).__init__() | ||||
|  | ||||
|         self.conv1 = nn.Conv2d( | ||||
|             in_planes, | ||||
|             planes, | ||||
|             kernel_size=3, | ||||
|             padding=1, | ||||
|             stride=stride, | ||||
|             padding_mode="zeros", | ||||
|         ) | ||||
|         self.conv2 = nn.Conv2d( | ||||
|             planes, planes, kernel_size=3, padding=1, padding_mode="zeros" | ||||
|         ) | ||||
|         self.relu = nn.ReLU(inplace=True) | ||||
|  | ||||
|         num_groups = planes // 8 | ||||
|  | ||||
|         if norm_fn == "group": | ||||
|             self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) | ||||
|             self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) | ||||
|             if not stride == 1: | ||||
|                 self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) | ||||
|  | ||||
|         elif norm_fn == "batch": | ||||
|             self.norm1 = nn.BatchNorm2d(planes) | ||||
|             self.norm2 = nn.BatchNorm2d(planes) | ||||
|             if not stride == 1: | ||||
|                 self.norm3 = nn.BatchNorm2d(planes) | ||||
|  | ||||
|         elif norm_fn == "instance": | ||||
|             self.norm1 = nn.InstanceNorm2d(planes) | ||||
|             self.norm2 = nn.InstanceNorm2d(planes) | ||||
|             if not stride == 1: | ||||
|                 self.norm3 = nn.InstanceNorm2d(planes) | ||||
|  | ||||
|         elif norm_fn == "none": | ||||
|             self.norm1 = nn.Sequential() | ||||
|             self.norm2 = nn.Sequential() | ||||
|             if not stride == 1: | ||||
|                 self.norm3 = nn.Sequential() | ||||
|  | ||||
|         if stride == 1: | ||||
|             self.downsample = None | ||||
|  | ||||
|         else: | ||||
|             self.downsample = nn.Sequential( | ||||
|                 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 | ||||
|             ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         y = x | ||||
|         y = self.relu(self.norm1(self.conv1(y))) | ||||
|         y = self.relu(self.norm2(self.conv2(y))) | ||||
|  | ||||
|         if self.downsample is not None: | ||||
|             x = self.downsample(x) | ||||
|  | ||||
|         return self.relu(x + y) | ||||
|  | ||||
|  | ||||
| class BasicEncoder(nn.Module): | ||||
|     def __init__( | ||||
|         self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0 | ||||
|     ): | ||||
|         super(BasicEncoder, self).__init__() | ||||
|         self.stride = stride | ||||
|         self.norm_fn = norm_fn | ||||
|         self.in_planes = 64 | ||||
|  | ||||
|         if self.norm_fn == "group": | ||||
|             self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) | ||||
|             self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) | ||||
|  | ||||
|         elif self.norm_fn == "batch": | ||||
|             self.norm1 = nn.BatchNorm2d(self.in_planes) | ||||
|             self.norm2 = nn.BatchNorm2d(output_dim * 2) | ||||
|  | ||||
|         elif self.norm_fn == "instance": | ||||
|             self.norm1 = nn.InstanceNorm2d(self.in_planes) | ||||
|             self.norm2 = nn.InstanceNorm2d(output_dim * 2) | ||||
|  | ||||
|         elif self.norm_fn == "none": | ||||
|             self.norm1 = nn.Sequential() | ||||
|             self.norm2 = nn.Sequential() | ||||
|  | ||||
|         self.conv1 = nn.Conv2d( | ||||
|             input_dim, | ||||
|             self.in_planes, | ||||
|             kernel_size=7, | ||||
|             stride=2, | ||||
|             padding=3, | ||||
|             padding_mode="zeros", | ||||
|         ) | ||||
|         self.relu1 = nn.ReLU(inplace=True) | ||||
|  | ||||
|         self.shallow = False | ||||
|         if self.shallow: | ||||
|             self.layer1 = self._make_layer(64, stride=1) | ||||
|             self.layer2 = self._make_layer(96, stride=2) | ||||
|             self.layer3 = self._make_layer(128, stride=2) | ||||
|             self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1) | ||||
|         else: | ||||
|             self.layer1 = self._make_layer(64, stride=1) | ||||
|             self.layer2 = self._make_layer(96, stride=2) | ||||
|             self.layer3 = self._make_layer(128, stride=2) | ||||
|             self.layer4 = self._make_layer(128, stride=2) | ||||
|  | ||||
|             self.conv2 = nn.Conv2d( | ||||
|                 128 + 128 + 96 + 64, | ||||
|                 output_dim * 2, | ||||
|                 kernel_size=3, | ||||
|                 padding=1, | ||||
|                 padding_mode="zeros", | ||||
|             ) | ||||
|             self.relu2 = nn.ReLU(inplace=True) | ||||
|             self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) | ||||
|  | ||||
|         self.dropout = None | ||||
|         if dropout > 0: | ||||
|             self.dropout = nn.Dropout2d(p=dropout) | ||||
|  | ||||
|         for m in self.modules(): | ||||
|             if isinstance(m, nn.Conv2d): | ||||
|                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") | ||||
|             elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): | ||||
|                 if m.weight is not None: | ||||
|                     nn.init.constant_(m.weight, 1) | ||||
|                 if m.bias is not None: | ||||
|                     nn.init.constant_(m.bias, 0) | ||||
|  | ||||
|     def _make_layer(self, dim, stride=1): | ||||
|         layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) | ||||
|         layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) | ||||
|         layers = (layer1, layer2) | ||||
|  | ||||
|         self.in_planes = dim | ||||
|         return nn.Sequential(*layers) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         _, _, H, W = x.shape | ||||
|  | ||||
|         x = self.conv1(x) | ||||
|         x = self.norm1(x) | ||||
|         x = self.relu1(x) | ||||
|  | ||||
|         if self.shallow: | ||||
|             a = self.layer1(x) | ||||
|             b = self.layer2(a) | ||||
|             c = self.layer3(b) | ||||
|             a = F.interpolate( | ||||
|                 a, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             b = F.interpolate( | ||||
|                 b, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             c = F.interpolate( | ||||
|                 c, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             x = self.conv2(torch.cat([a, b, c], dim=1)) | ||||
|         else: | ||||
|             a = self.layer1(x) | ||||
|             b = self.layer2(a) | ||||
|             c = self.layer3(b) | ||||
|             d = self.layer4(c) | ||||
|             a = F.interpolate( | ||||
|                 a, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             b = F.interpolate( | ||||
|                 b, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             c = F.interpolate( | ||||
|                 c, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             d = F.interpolate( | ||||
|                 d, | ||||
|                 (H // self.stride, W // self.stride), | ||||
|                 mode="bilinear", | ||||
|                 align_corners=True, | ||||
|             ) | ||||
|             x = self.conv2(torch.cat([a, b, c, d], dim=1)) | ||||
|             x = self.norm2(x) | ||||
|             x = self.relu2(x) | ||||
|             x = self.conv3(x) | ||||
|  | ||||
|         if self.training and self.dropout is not None: | ||||
|             x = self.dropout(x) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class AttnBlock(nn.Module): | ||||
|     """ | ||||
|     A DiT-style Transformer block (pre-LayerNorm attention + MLP); unlike DiT, no adaLN-Zero conditioning is applied here. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): | ||||
|         super().__init__() | ||||
|         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||||
|         self.attn = Attention( | ||||
|             hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs | ||||
|         ) | ||||
|  | ||||
|         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) | ||||
|         mlp_hidden_dim = int(hidden_size * mlp_ratio) | ||||
|         approx_gelu = lambda: nn.GELU(approximate="tanh") | ||||
|         self.mlp = Mlp( | ||||
|             in_features=hidden_size, | ||||
|             hidden_features=mlp_hidden_dim, | ||||
|             act_layer=approx_gelu, | ||||
|             drop=0, | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = x + self.attn(self.norm1(x)) | ||||
|         x = x + self.mlp(self.norm2(x)) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| def bilinear_sampler(img, coords, mode="bilinear", mask=False): | ||||
|     """Wrapper for grid_sample, uses pixel coordinates""" | ||||
|     H, W = img.shape[-2:] | ||||
|     xgrid, ygrid = coords.split([1, 1], dim=-1) | ||||
|     # go to 0,1 then 0,2 then -1,1 | ||||
|     xgrid = 2 * xgrid / (W - 1) - 1 | ||||
|     ygrid = 2 * ygrid / (H - 1) - 1 | ||||
|  | ||||
|     grid = torch.cat([xgrid, ygrid], dim=-1) | ||||
|     img = F.grid_sample(img, grid, align_corners=True) | ||||
|  | ||||
|     if mask: | ||||
|         mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) | ||||
|         return img, mask.float() | ||||
|  | ||||
|     return img | ||||
|  | ||||
|  | ||||
| class CorrBlock: | ||||
|     def __init__(self, fmaps, num_levels=4, radius=4): | ||||
|         B, S, C, H, W = fmaps.shape | ||||
|         self.S, self.C, self.H, self.W = S, C, H, W | ||||
|  | ||||
|         self.num_levels = num_levels | ||||
|         self.radius = radius | ||||
|         self.fmaps_pyramid = [] | ||||
|  | ||||
|         self.fmaps_pyramid.append(fmaps) | ||||
|         for i in range(self.num_levels - 1): | ||||
|             fmaps_ = fmaps.reshape(B * S, C, H, W) | ||||
|             fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) | ||||
|             _, _, H, W = fmaps_.shape | ||||
|             fmaps = fmaps_.reshape(B, S, C, H, W) | ||||
|             self.fmaps_pyramid.append(fmaps) | ||||
|  | ||||
|     def sample(self, coords): | ||||
|         r = self.radius | ||||
|         B, S, N, D = coords.shape | ||||
|         assert D == 2 | ||||
|  | ||||
|         H, W = self.H, self.W | ||||
|         out_pyramid = [] | ||||
|         for i in range(self.num_levels): | ||||
|             corrs = self.corrs_pyramid[i]  # B, S, N, H, W | ||||
|             _, _, _, H, W = corrs.shape | ||||
|  | ||||
|             dx = torch.linspace(-r, r, 2 * r + 1) | ||||
|             dy = torch.linspace(-r, r, 2 * r + 1) | ||||
|             delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( | ||||
|                 coords.device | ||||
|             ) | ||||
|  | ||||
|             centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i | ||||
|             delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) | ||||
|             coords_lvl = centroid_lvl + delta_lvl | ||||
|  | ||||
|             corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl) | ||||
|             corrs = corrs.view(B, S, N, -1) | ||||
|             out_pyramid.append(corrs) | ||||
|  | ||||
|         out = torch.cat(out_pyramid, dim=-1)  # B, S, N, num_levels * (2*r+1)**2 | ||||
|         return out.contiguous().float() | ||||
|  | ||||
|     def corr(self, targets): | ||||
|         B, S, N, C = targets.shape | ||||
|         assert C == self.C | ||||
|         assert S == self.S | ||||
|  | ||||
|         fmap1 = targets | ||||
|  | ||||
|         self.corrs_pyramid = [] | ||||
|         for fmaps in self.fmaps_pyramid: | ||||
|             _, _, _, H, W = fmaps.shape | ||||
|             fmap2s = fmaps.view(B, S, C, H * W) | ||||
|             corrs = torch.matmul(fmap1, fmap2s) | ||||
|             corrs = corrs.view(B, S, N, H, W) | ||||
|             corrs = corrs / torch.sqrt(torch.tensor(C).float()) | ||||
|             self.corrs_pyramid.append(corrs) | ||||
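
    # Editorial shape sketch for CorrBlock (not part of the original file; the tensors
    # below are random placeholders). corr() builds a per-level cost volume between the
    # per-point target features and the feature-map pyramid, and sample() reads out a
    # (2*radius+1)^2 neighbourhood around each track position at every pyramid level.
    import torch

    B, S, C, H, W = 1, 8, 128, 48, 64
    fmaps = torch.randn(B, S, C, H, W)
    corr_block = CorrBlock(fmaps, num_levels=4, radius=3)

    targets = torch.randn(B, S, 16, C)           # one feature vector per point and frame
    corr_block.corr(targets)                     # fills corr_block.corrs_pyramid

    coords = torch.rand(B, S, 16, 2) * torch.tensor([W, H], dtype=torch.float32)
    corr_features = corr_block.sample(coords)    # B, S, 16, 4 * 7 * 7 = B, S, 16, 196
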
|  | ||||
|  | ||||
| class UpdateFormer(nn.Module): | ||||
|     """ | ||||
|     Transformer model that updates track estimates. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         space_depth=12, | ||||
|         time_depth=12, | ||||
|         input_dim=320, | ||||
|         hidden_size=384, | ||||
|         num_heads=8, | ||||
|         output_dim=130, | ||||
|         mlp_ratio=4.0, | ||||
|         add_space_attn=True, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.out_channels = 2 | ||||
|         self.num_heads = num_heads | ||||
|         self.hidden_size = hidden_size | ||||
|         self.add_space_attn = add_space_attn | ||||
|         self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) | ||||
|         self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) | ||||
|  | ||||
|         self.time_blocks = nn.ModuleList( | ||||
|             [ | ||||
|                 AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) | ||||
|                 for _ in range(time_depth) | ||||
|             ] | ||||
|         ) | ||||
|  | ||||
|         if add_space_attn: | ||||
|             self.space_blocks = nn.ModuleList( | ||||
|                 [ | ||||
|                     AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) | ||||
|                     for _ in range(space_depth) | ||||
|                 ] | ||||
|             ) | ||||
|             assert len(self.time_blocks) >= len(self.space_blocks) | ||||
|         self.initialize_weights() | ||||
|  | ||||
|     def initialize_weights(self): | ||||
|         def _basic_init(module): | ||||
|             if isinstance(module, nn.Linear): | ||||
|                 torch.nn.init.xavier_uniform_(module.weight) | ||||
|                 if module.bias is not None: | ||||
|                     nn.init.constant_(module.bias, 0) | ||||
|  | ||||
|         self.apply(_basic_init) | ||||
|  | ||||
|     def forward(self, input_tensor): | ||||
|         x = self.input_transform(input_tensor) | ||||
|  | ||||
|         j = 0 | ||||
|         for i in range(len(self.time_blocks)): | ||||
|             B, N, T, _ = x.shape | ||||
|             x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N) | ||||
|             x_time = self.time_blocks[i](x_time) | ||||
|  | ||||
|             x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N) | ||||
|             if self.add_space_attn and ( | ||||
|                 i % (len(self.time_blocks) // len(self.space_blocks)) == 0 | ||||
|             ): | ||||
|                 x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N) | ||||
|                 x_space = self.space_blocks[j](x_space) | ||||
|                 x = rearrange(x_space, "(b t) n c -> b n t c  ", b=B, t=T, n=N) | ||||
|                 j += 1 | ||||
|  | ||||
|         flow = self.flow_head(x) | ||||
|         return flow | ||||
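
A quick shape check for the update transformer (editorial sketch; the sizes mirror how CoTracker instantiates it later in this commit, and the batch, track, and frame counts are arbitrary placeholders).

    import torch

    updateformer = UpdateFormer(
        space_depth=6, time_depth=6, input_dim=456, hidden_size=384,
        num_heads=8, output_dim=130, mlp_ratio=4.0, add_space_attn=True,
    )
    tokens = torch.randn(1, 64, 8, 456)   # B, N, T, input_dim: one token per track and frame
    delta = updateformer(tokens)          # B, N, T, output_dim -> torch.Size([1, 64, 8, 130])
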
							
								
								
									
cotracker/models/core/cotracker/cotracker.py  (new file, 351 lines)
							| @@ -0,0 +1,351 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from einops import rearrange | ||||
|  | ||||
| from cotracker.models.core.cotracker.blocks import ( | ||||
|     BasicEncoder, | ||||
|     CorrBlock, | ||||
|     UpdateFormer, | ||||
| ) | ||||
|  | ||||
| from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat | ||||
| from cotracker.models.core.embeddings import ( | ||||
|     get_2d_embedding, | ||||
|     get_1d_sincos_pos_embed_from_grid, | ||||
|     get_2d_sincos_pos_embed, | ||||
| ) | ||||
|  | ||||
|  | ||||
| torch.manual_seed(0) | ||||
|  | ||||
|  | ||||
| def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)): | ||||
|     if grid_size == 1: | ||||
|         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[ | ||||
|             None, None | ||||
|         ].cuda() | ||||
|  | ||||
|     grid_y, grid_x = meshgrid2d( | ||||
|         1, grid_size, grid_size, stack=False, norm=False, device="cuda" | ||||
|     ) | ||||
|     step = interp_shape[1] // 64 | ||||
|     if grid_center[0] != 0 or grid_center[1] != 0: | ||||
|         grid_y = grid_y - grid_size / 2.0 | ||||
|         grid_x = grid_x - grid_size / 2.0 | ||||
|     grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * ( | ||||
|         interp_shape[0] - step * 2 | ||||
|     ) | ||||
|     grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * ( | ||||
|         interp_shape[1] - step * 2 | ||||
|     ) | ||||
|  | ||||
|     grid_y = grid_y + grid_center[0] | ||||
|     grid_x = grid_x + grid_center[1] | ||||
|     xy = torch.stack([grid_x, grid_y], dim=-1).cuda() | ||||
|     return xy | ||||
|  | ||||
|  | ||||
| def sample_pos_embed(grid_size, embed_dim, coords): | ||||
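|     # Builds a 2D sine-cosine positional embedding over the feature grid and bilinearly | ||||
|     # samples it at the first-frame location of every track. | ||||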
|     pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size) | ||||
|     pos_embed = ( | ||||
|         torch.from_numpy(pos_embed) | ||||
|         .reshape(grid_size[0], grid_size[1], embed_dim) | ||||
|         .float() | ||||
|         .unsqueeze(0) | ||||
|         .to(coords.device) | ||||
|     ) | ||||
|     sampled_pos_embed = bilinear_sample2d( | ||||
|         pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1] | ||||
|     ) | ||||
|     return sampled_pos_embed | ||||
|  | ||||
|  | ||||
| class CoTracker(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         S=8, | ||||
|         stride=8, | ||||
|         add_space_attn=True, | ||||
|         num_heads=8, | ||||
|         hidden_size=384, | ||||
|         space_depth=12, | ||||
|         time_depth=12, | ||||
|     ): | ||||
|         super(CoTracker, self).__init__() | ||||
|         self.S = S | ||||
|         self.stride = stride | ||||
|         self.hidden_dim = 256 | ||||
|         self.latent_dim = latent_dim = 128 | ||||
|         self.corr_levels = 4 | ||||
|         self.corr_radius = 3 | ||||
|         self.add_space_attn = add_space_attn | ||||
|         self.fnet = BasicEncoder( | ||||
|             output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride | ||||
|         ) | ||||
|  | ||||
|         self.updateformer = UpdateFormer( | ||||
|             space_depth=space_depth, | ||||
|             time_depth=time_depth, | ||||
|             input_dim=456, | ||||
|             hidden_size=hidden_size, | ||||
|             num_heads=num_heads, | ||||
|             output_dim=latent_dim + 2, | ||||
|             mlp_ratio=4.0, | ||||
|             add_space_attn=add_space_attn, | ||||
|         ) | ||||
|  | ||||
|         self.norm = nn.GroupNorm(1, self.latent_dim) | ||||
|         self.ffeat_updater = nn.Sequential( | ||||
|             nn.Linear(self.latent_dim, self.latent_dim), | ||||
|             nn.GELU(), | ||||
|         ) | ||||
|         self.vis_predictor = nn.Sequential( | ||||
|             nn.Linear(self.latent_dim, 1), | ||||
|         ) | ||||
|  | ||||
|     def forward_iteration( | ||||
|         self, | ||||
|         fmaps, | ||||
|         coords_init, | ||||
|         feat_init=None, | ||||
|         vis_init=None, | ||||
|         track_mask=None, | ||||
|         iters=4, | ||||
|     ): | ||||
|         B, S_init, N, D = coords_init.shape | ||||
|         assert D == 2 | ||||
|         assert B == 1 | ||||
|  | ||||
|         B, S, __, H8, W8 = fmaps.shape | ||||
|  | ||||
|         device = fmaps.device | ||||
|  | ||||
|         if S_init < S: | ||||
|             coords = torch.cat( | ||||
|                 [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 | ||||
|             ) | ||||
|             vis_init = torch.cat( | ||||
|                 [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 | ||||
|             ) | ||||
|         else: | ||||
|             coords = coords_init.clone() | ||||
|  | ||||
|         fcorr_fn = CorrBlock( | ||||
|             fmaps, num_levels=self.corr_levels, radius=self.corr_radius | ||||
|         ) | ||||
|  | ||||
|         ffeats = feat_init.clone() | ||||
|  | ||||
|         times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1) | ||||
|  | ||||
|         pos_embed = sample_pos_embed( | ||||
|             grid_size=(H8, W8), | ||||
|             embed_dim=456, | ||||
|             coords=coords, | ||||
|         ) | ||||
|         pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1) | ||||
|         times_embed = ( | ||||
|             torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None] | ||||
|             .repeat(B, 1, 1) | ||||
|             .float() | ||||
|             .to(device) | ||||
|         ) | ||||
|         coord_predictions = [] | ||||
|  | ||||
|         for __ in range(iters): | ||||
|             coords = coords.detach() | ||||
|             fcorr_fn.corr(ffeats) | ||||
|  | ||||
|             fcorrs = fcorr_fn.sample(coords)  # B, S, N, LRR | ||||
|             LRR = fcorrs.shape[3] | ||||
|  | ||||
|             fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR) | ||||
|             flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) | ||||
|  | ||||
|             flows_cat = get_2d_embedding(flows_, 64, cat_coords=True) | ||||
|             ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) | ||||
|  | ||||
|             if track_mask.shape[1] < vis_init.shape[1]: | ||||
|                 track_mask = torch.cat( | ||||
|                     [ | ||||
|                         track_mask, | ||||
|                         torch.zeros_like(track_mask[:, 0]).repeat( | ||||
|                             1, vis_init.shape[1] - track_mask.shape[1], 1, 1 | ||||
|                         ), | ||||
|                     ], | ||||
|                     dim=1, | ||||
|                 ) | ||||
|             concat = ( | ||||
|                 torch.cat([track_mask, vis_init], dim=2) | ||||
|                 .permute(0, 2, 1, 3) | ||||
|                 .reshape(B * N, S, 2) | ||||
|             ) | ||||
|  | ||||
|             transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2) | ||||
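|             # Input dimension 456 = 130 (flow embedding: 2*64 + 2 coords) | ||||
|             #                     + 196 (correlation features, LRR = 4 levels x 7x7) | ||||
|             #                     + 128 (track features) + 2 (track mask, visibility). | ||||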
|             x = transformer_input + pos_embed + times_embed | ||||
|  | ||||
|             x = rearrange(x, "(b n) t d -> b n t d", b=B) | ||||
|  | ||||
|             delta = self.updateformer(x) | ||||
|  | ||||
|             delta = rearrange(delta, " b n t d -> (b n) t d") | ||||
|  | ||||
|             delta_coords_ = delta[:, :, :2] | ||||
|             delta_feats_ = delta[:, :, 2:] | ||||
|  | ||||
|             delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) | ||||
|             ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim) | ||||
|  | ||||
|             ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_ | ||||
|  | ||||
|             ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute( | ||||
|                 0, 2, 1, 3 | ||||
|             )  # B,S,N,C | ||||
|  | ||||
|             coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) | ||||
|             coord_predictions.append(coords * self.stride) | ||||
|  | ||||
|         vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape( | ||||
|             B, S, N | ||||
|         ) | ||||
|         return coord_predictions, vis_e, feat_init | ||||
|  | ||||
|     def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False): | ||||
|         B, T, C, H, W = rgbs.shape | ||||
|         B, N, __ = queries.shape | ||||
|  | ||||
|         device = rgbs.device | ||||
|         assert B == 1 | ||||
|         # INIT for the first sequence | ||||
|         # Sort points by the first frame in which they become visible, so they can be appended to the tensor of tracked points consecutively | ||||
|         first_positive_inds = queries[:, :, 0].long() | ||||
|  | ||||
|         __, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False) | ||||
|         inv_sort_inds = torch.argsort(sort_inds, dim=0) | ||||
|         first_positive_sorted_inds = first_positive_inds[0][sort_inds] | ||||
|  | ||||
|         assert torch.allclose( | ||||
|             first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds] | ||||
|         ) | ||||
|  | ||||
|         coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat( | ||||
|             1, self.S, 1, 1 | ||||
|         ) / float(self.stride) | ||||
|  | ||||
|         rgbs = 2 * (rgbs / 255.0) - 1.0 | ||||
|  | ||||
|         traj_e = torch.zeros((B, T, N, 2), device=device) | ||||
|         vis_e = torch.zeros((B, T, N), device=device) | ||||
|  | ||||
|         ind_array = torch.arange(T, device=device) | ||||
|         ind_array = ind_array[None, :, None].repeat(B, 1, N) | ||||
|  | ||||
|         track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1) | ||||
|         # these are logits, so we initialize visibility with something that gives a value close to 1 after the sigmoid | ||||
|         vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10 | ||||
|  | ||||
|         ind = 0 | ||||
|  | ||||
|         track_mask_ = track_mask[:, :, sort_inds].clone() | ||||
|         coords_init_ = coords_init[:, :, sort_inds].clone() | ||||
|         vis_init_ = vis_init[:, :, sort_inds].clone() | ||||
|  | ||||
|         prev_wind_idx = 0 | ||||
|         fmaps_ = None | ||||
|         vis_predictions = [] | ||||
|         coord_predictions = [] | ||||
|         wind_inds = [] | ||||
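|         # Slide a window of S frames over the video with stride S // 2. Queries join | ||||
|         # once their first visible frame enters the window; coordinates and visibility | ||||
|         # from the second half of the previous window initialize the next one. | ||||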
|         while ind < T - self.S // 2: | ||||
|             rgbs_seq = rgbs[:, ind : ind + self.S] | ||||
|  | ||||
|             S = S_local = rgbs_seq.shape[1] | ||||
|             if S < self.S: | ||||
|                 rgbs_seq = torch.cat( | ||||
|                     [rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)], | ||||
|                     dim=1, | ||||
|                 ) | ||||
|                 S = rgbs_seq.shape[1] | ||||
|             rgbs_ = rgbs_seq.reshape(B * S, C, H, W) | ||||
|  | ||||
|             if fmaps_ is None: | ||||
|                 fmaps_ = self.fnet(rgbs_) | ||||
|             else: | ||||
|                 fmaps_ = torch.cat( | ||||
|                     [fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0 | ||||
|                 ) | ||||
|             fmaps = fmaps_.reshape( | ||||
|                 B, S, self.latent_dim, H // self.stride, W // self.stride | ||||
|             ) | ||||
|  | ||||
|             curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S) | ||||
|             if curr_wind_points.shape[0] == 0: | ||||
|                 ind = ind + self.S // 2 | ||||
|                 continue | ||||
|             wind_idx = curr_wind_points[-1] + 1 | ||||
|  | ||||
|             if wind_idx - prev_wind_idx > 0: | ||||
|                 fmaps_sample = fmaps[ | ||||
|                     :, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind | ||||
|                 ] | ||||
|  | ||||
|                 feat_init_ = bilinear_sample2d( | ||||
|                     fmaps_sample, | ||||
|                     coords_init_[:, 0, prev_wind_idx:wind_idx, 0], | ||||
|                     coords_init_[:, 0, prev_wind_idx:wind_idx, 1], | ||||
|                 ).permute(0, 2, 1) | ||||
|  | ||||
|                 feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1) | ||||
|                 feat_init = smart_cat(feat_init, feat_init_, dim=2) | ||||
|  | ||||
|             if prev_wind_idx > 0: | ||||
|                 new_coords = coords[-1][:, self.S // 2 :] / float(self.stride) | ||||
|  | ||||
|                 coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords | ||||
|                 coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[ | ||||
|                     :, -1 | ||||
|                 ].repeat(1, self.S // 2, 1, 1) | ||||
|  | ||||
|                 new_vis = vis[:, self.S // 2 :].unsqueeze(-1) | ||||
|                 vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis | ||||
|                 vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat( | ||||
|                     1, self.S // 2, 1, 1 | ||||
|                 ) | ||||
|  | ||||
|             coords, vis, __ = self.forward_iteration( | ||||
|                 fmaps=fmaps, | ||||
|                 coords_init=coords_init_[:, :, :wind_idx], | ||||
|                 feat_init=feat_init[:, :, :wind_idx], | ||||
|                 vis_init=vis_init_[:, :, :wind_idx], | ||||
|                 track_mask=track_mask_[:, ind : ind + self.S, :wind_idx], | ||||
|                 iters=iters, | ||||
|             ) | ||||
|             if is_train: | ||||
|                 vis_predictions.append(torch.sigmoid(vis[:, :S_local])) | ||||
|                 coord_predictions.append([coord[:, :S_local] for coord in coords]) | ||||
|                 wind_inds.append(wind_idx) | ||||
|  | ||||
|             traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local] | ||||
|             vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local] | ||||
|  | ||||
|             track_mask_[:, : ind + self.S, :wind_idx] = 0.0 | ||||
|             ind = ind + self.S // 2 | ||||
|  | ||||
|             prev_wind_idx = wind_idx | ||||
|  | ||||
|         traj_e = traj_e[:, :, inv_sort_inds] | ||||
|         vis_e = vis_e[:, :, inv_sort_inds] | ||||
|  | ||||
|         vis_e = torch.sigmoid(vis_e) | ||||
|  | ||||
|         train_data = ( | ||||
|             (vis_predictions, coord_predictions, wind_inds, sort_inds) | ||||
|             if is_train | ||||
|             else None | ||||
|         ) | ||||
|         return traj_e, feat_init, vis_e, train_data | ||||
							
								
								
									
61  cotracker/models/core/cotracker/losses.py  Normal file
							| @@ -0,0 +1,61 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from cotracker.models.core.model_utils import reduce_masked_mean | ||||
|  | ||||
| EPS = 1e-6 | ||||
|  | ||||
|  | ||||
| def balanced_ce_loss(pred, gt, valid=None): | ||||
|     total_balanced_loss = 0.0 | ||||
|     for j in range(len(gt)): | ||||
|         B, S, N = gt[j].shape | ||||
|         # pred and gt are the same shape | ||||
|         for (a, b) in zip(pred[j].size(), gt[j].size()): | ||||
|             assert a == b  # some shape mismatch! | ||||
|         # valid is required here even though it defaults to None | ||||
|         for (a, b) in zip(pred[j].size(), valid[j].size()): | ||||
|             assert a == b  # some shape mismatch! | ||||
|  | ||||
|         pos = (gt[j] > 0.95).float() | ||||
|         neg = (gt[j] < 0.05).float() | ||||
|  | ||||
|         label = pos * 2.0 - 1.0 | ||||
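|         # Numerically stable binary cross-entropy on logits via the log-sum-exp trick: | ||||
|         # loss = log(1 + exp(-label * pred)) with label in {-1, +1}. | ||||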
|         a = -label * pred[j] | ||||
|         b = F.relu(a) | ||||
|         loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) | ||||
|  | ||||
|         pos_loss = reduce_masked_mean(loss, pos * valid[j]) | ||||
|         neg_loss = reduce_masked_mean(loss, neg * valid[j]) | ||||
|  | ||||
|         balanced_loss = pos_loss + neg_loss | ||||
|         total_balanced_loss += balanced_loss / float(N) | ||||
|     return total_balanced_loss | ||||
|  | ||||
|  | ||||
| def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8): | ||||
|     """Loss function defined over sequence of flow predictions""" | ||||
|     total_flow_loss = 0.0 | ||||
|     for j in range(len(flow_gt)): | ||||
|         B, S, N, D = flow_gt[j].shape | ||||
|         assert D == 2 | ||||
|         B, S1, N = vis[j].shape | ||||
|         B, S2, N = valids[j].shape | ||||
|         assert S == S1 | ||||
|         assert S == S2 | ||||
|         n_predictions = len(flow_preds[j]) | ||||
|         flow_loss = 0.0 | ||||
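|         # Later refinement iterations receive exponentially larger weights (gamma < 1). | ||||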
|         for i in range(n_predictions): | ||||
|             i_weight = gamma ** (n_predictions - i - 1) | ||||
|             flow_pred = flow_preds[j][i] | ||||
|             i_loss = (flow_pred - flow_gt[j]).abs()  # B, S, N, 2 | ||||
|             i_loss = torch.mean(i_loss, dim=3)  # B, S, N | ||||
|             flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j]) | ||||
|         flow_loss = flow_loss / n_predictions | ||||
|         total_flow_loss += flow_loss / float(N) | ||||
|     return total_flow_loss | ||||
							
								
								
									
154  cotracker/models/core/embeddings.py  Normal file
							| @@ -0,0 +1,154 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): | ||||
|     """ | ||||
|     grid_size: int (or (h, w) tuple) giving the grid height and width | ||||
|     return: | ||||
|     pos_embed: [grid_size*grid_size, embed_dim] or [extra_tokens+grid_size*grid_size, embed_dim] (w/ or w/o extra tokens) | ||||
|     """ | ||||
|     if isinstance(grid_size, tuple): | ||||
|         grid_size_h, grid_size_w = grid_size | ||||
|     else: | ||||
|         grid_size_h = grid_size_w = grid_size | ||||
|     grid_h = np.arange(grid_size_h, dtype=np.float32) | ||||
|     grid_w = np.arange(grid_size_w, dtype=np.float32) | ||||
|     grid = np.meshgrid(grid_w, grid_h)  # here w goes first | ||||
|     grid = np.stack(grid, axis=0) | ||||
|  | ||||
|     grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) | ||||
|     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) | ||||
|     if cls_token and extra_tokens > 0: | ||||
|         pos_embed = np.concatenate( | ||||
|             [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 | ||||
|         ) | ||||
|     return pos_embed | ||||
|  | ||||
|  | ||||
| def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): | ||||
|     assert embed_dim % 2 == 0 | ||||
|  | ||||
|     # use half of dimensions to encode grid_h | ||||
|     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2) | ||||
|     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2) | ||||
|  | ||||
|     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D) | ||||
|     return emb | ||||
|  | ||||
|  | ||||
| def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): | ||||
|     """ | ||||
|     embed_dim: output dimension for each position | ||||
|     pos: a list of positions to be encoded: size (M,) | ||||
|     out: (M, D) | ||||
|     """ | ||||
|     assert embed_dim % 2 == 0 | ||||
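|     # Frequencies follow omega_k = 1 / 10000**(2k / D); each position is encoded as | ||||
|     # [sin(pos * omega), cos(pos * omega)] concatenated along the feature dimension. | ||||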
|     omega = np.arange(embed_dim // 2, dtype=np.float64) | ||||
|     omega /= embed_dim / 2.0 | ||||
|     omega = 1.0 / 10000 ** omega  # (D/2,) | ||||
|  | ||||
|     pos = pos.reshape(-1)  # (M,) | ||||
|     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product | ||||
|  | ||||
|     emb_sin = np.sin(out)  # (M, D/2) | ||||
|     emb_cos = np.cos(out)  # (M, D/2) | ||||
|  | ||||
|     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D) | ||||
|     return emb | ||||
|  | ||||
|  | ||||
| def get_2d_embedding(xy, C, cat_coords=True): | ||||
|     B, N, D = xy.shape | ||||
|     assert D == 2 | ||||
|  | ||||
|     x = xy[:, :, 0:1] | ||||
|     y = xy[:, :, 1:2] | ||||
|     div_term = ( | ||||
|         torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) | ||||
|     ).reshape(1, 1, int(C / 2)) | ||||
|  | ||||
|     pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) | ||||
|     pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) | ||||
|  | ||||
|     pe_x[:, :, 0::2] = torch.sin(x * div_term) | ||||
|     pe_x[:, :, 1::2] = torch.cos(x * div_term) | ||||
|  | ||||
|     pe_y[:, :, 0::2] = torch.sin(y * div_term) | ||||
|     pe_y[:, :, 1::2] = torch.cos(y * div_term) | ||||
|  | ||||
|     pe = torch.cat([pe_x, pe_y], dim=2)  # B, N, C*2 | ||||
|     if cat_coords: | ||||
|         pe = torch.cat([xy, pe], dim=2)  # B, N, C*2+2 | ||||
|     return pe | ||||
|  | ||||
|  | ||||
| def get_3d_embedding(xyz, C, cat_coords=True): | ||||
|     B, N, D = xyz.shape | ||||
|     assert D == 3 | ||||
|  | ||||
|     x = xyz[:, :, 0:1] | ||||
|     y = xyz[:, :, 1:2] | ||||
|     z = xyz[:, :, 2:3] | ||||
|     div_term = ( | ||||
|         torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C) | ||||
|     ).reshape(1, 1, int(C / 2)) | ||||
|  | ||||
|     pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | ||||
|     pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | ||||
|     pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) | ||||
|  | ||||
|     pe_x[:, :, 0::2] = torch.sin(x * div_term) | ||||
|     pe_x[:, :, 1::2] = torch.cos(x * div_term) | ||||
|  | ||||
|     pe_y[:, :, 0::2] = torch.sin(y * div_term) | ||||
|     pe_y[:, :, 1::2] = torch.cos(y * div_term) | ||||
|  | ||||
|     pe_z[:, :, 0::2] = torch.sin(z * div_term) | ||||
|     pe_z[:, :, 1::2] = torch.cos(z * div_term) | ||||
|  | ||||
|     pe = torch.cat([pe_x, pe_y, pe_z], dim=2)  # B, N, C*3 | ||||
|     if cat_coords: | ||||
|         pe = torch.cat([pe, xyz], dim=2)  # B, N, C*3+3 | ||||
|     return pe | ||||
|  | ||||
|  | ||||
| def get_4d_embedding(xyzw, C, cat_coords=True): | ||||
|     B, N, D = xyzw.shape | ||||
|     assert D == 4 | ||||
|  | ||||
|     x = xyzw[:, :, 0:1] | ||||
|     y = xyzw[:, :, 1:2] | ||||
|     z = xyzw[:, :, 2:3] | ||||
|     w = xyzw[:, :, 3:4] | ||||
|     div_term = ( | ||||
|         torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C) | ||||
|     ).reshape(1, 1, int(C / 2)) | ||||
|  | ||||
|     pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||
|     pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||
|     pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||
|     pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) | ||||
|  | ||||
|     pe_x[:, :, 0::2] = torch.sin(x * div_term) | ||||
|     pe_x[:, :, 1::2] = torch.cos(x * div_term) | ||||
|  | ||||
|     pe_y[:, :, 0::2] = torch.sin(y * div_term) | ||||
|     pe_y[:, :, 1::2] = torch.cos(y * div_term) | ||||
|  | ||||
|     pe_z[:, :, 0::2] = torch.sin(z * div_term) | ||||
|     pe_z[:, :, 1::2] = torch.cos(z * div_term) | ||||
|  | ||||
|     pe_w[:, :, 0::2] = torch.sin(w * div_term) | ||||
|     pe_w[:, :, 1::2] = torch.cos(w * div_term) | ||||
|  | ||||
|     pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2)  # B, N, C*4 | ||||
|     if cat_coords: | ||||
|         pe = torch.cat([pe, xyzw], dim=2)  # B, N, C*4+4 | ||||
|     return pe | ||||
							
								
								
									
169  cotracker/models/core/model_utils.py  Normal file
							| @@ -0,0 +1,169 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
|  | ||||
| EPS = 1e-6 | ||||
|  | ||||
|  | ||||
| def smart_cat(tensor1, tensor2, dim): | ||||
|     if tensor1 is None: | ||||
|         return tensor2 | ||||
|     return torch.cat([tensor1, tensor2], dim=dim) | ||||
|  | ||||
|  | ||||
| def normalize_single(d): | ||||
|     # d is a torch tensor of arbitrary shape | ||||
|     dmin = torch.min(d) | ||||
|     dmax = torch.max(d) | ||||
|     d = (d - dmin) / (EPS + (dmax - dmin)) | ||||
|     return d | ||||
|  | ||||
|  | ||||
| def normalize(d): | ||||
|     # d is B x whatever. normalize within each element of the batch | ||||
|     out = torch.zeros(d.size()) | ||||
|     if d.is_cuda: | ||||
|         out = out.cuda() | ||||
|     B = list(d.size())[0] | ||||
|     for b in list(range(B)): | ||||
|         out[b] = normalize_single(d[b]) | ||||
|     return out | ||||
|  | ||||
|  | ||||
| def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"): | ||||
|     # returns a meshgrid sized B x Y x X | ||||
|  | ||||
|     grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device)) | ||||
|     grid_y = torch.reshape(grid_y, [1, Y, 1]) | ||||
|     grid_y = grid_y.repeat(B, 1, X) | ||||
|  | ||||
|     grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device)) | ||||
|     grid_x = torch.reshape(grid_x, [1, 1, X]) | ||||
|     grid_x = grid_x.repeat(B, Y, 1) | ||||
|  | ||||
|     if stack: | ||||
|         # note we stack in xy order | ||||
|         # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) | ||||
|         grid = torch.stack([grid_x, grid_y], dim=-1) | ||||
|         return grid | ||||
|     else: | ||||
|         return grid_y, grid_x | ||||
|  | ||||
|  | ||||
| def reduce_masked_mean(x, mask, dim=None, keepdim=False): | ||||
|     # x and mask must have exactly the same shape (broadcasting is deliberately disallowed) | ||||
|     # returns the mean of x over the entries where mask is nonzero (a scalar if dim is None) | ||||
|     # dim can be a single axis or a list of axes | ||||
|     for (a, b) in zip(x.size(), mask.size()): | ||||
|         assert a == b  # some shape mismatch! | ||||
|     prod = x * mask | ||||
|     if dim is None: | ||||
|         numer = torch.sum(prod) | ||||
|         denom = EPS + torch.sum(mask) | ||||
|     else: | ||||
|         numer = torch.sum(prod, dim=dim, keepdim=keepdim) | ||||
|         denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim) | ||||
|  | ||||
|     mean = numer / denom | ||||
|     return mean | ||||
|  | ||||
|  | ||||
| def bilinear_sample2d(im, x, y, return_inbounds=False): | ||||
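|     # Standard bilinear interpolation: gather the four integer neighbors of each (x, y), | ||||
|     # clamped at the image border, and blend them with area weights that sum to 1. | ||||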
|     # x and y are each B, N | ||||
|     # output is B, C, N | ||||
|     if len(im.shape) == 5: | ||||
|         B, N, C, H, W = list(im.shape) | ||||
|     else: | ||||
|         B, C, H, W = list(im.shape) | ||||
|     N = list(x.shape)[1] | ||||
|  | ||||
|     x = x.float() | ||||
|     y = y.float() | ||||
|     H_f = torch.tensor(H, dtype=torch.float32) | ||||
|     W_f = torch.tensor(W, dtype=torch.float32) | ||||
|  | ||||
|     # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float() | ||||
|  | ||||
|     max_y = (H_f - 1).int() | ||||
|     max_x = (W_f - 1).int() | ||||
|  | ||||
|     x0 = torch.floor(x).int() | ||||
|     x1 = x0 + 1 | ||||
|     y0 = torch.floor(y).int() | ||||
|     y1 = y0 + 1 | ||||
|  | ||||
|     x0_clip = torch.clamp(x0, 0, max_x) | ||||
|     x1_clip = torch.clamp(x1, 0, max_x) | ||||
|     y0_clip = torch.clamp(y0, 0, max_y) | ||||
|     y1_clip = torch.clamp(y1, 0, max_y) | ||||
|     dim2 = W | ||||
|     dim1 = W * H | ||||
|  | ||||
|     base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1 | ||||
|     base = torch.reshape(base, [B, 1]).repeat([1, N]) | ||||
|  | ||||
|     base_y0 = base + y0_clip * dim2 | ||||
|     base_y1 = base + y1_clip * dim2 | ||||
|  | ||||
|     idx_y0_x0 = base_y0 + x0_clip | ||||
|     idx_y0_x1 = base_y0 + x1_clip | ||||
|     idx_y1_x0 = base_y1 + x0_clip | ||||
|     idx_y1_x1 = base_y1 + x1_clip | ||||
|  | ||||
|     # use the indices to lookup pixels in the flat image | ||||
|     # im is B x C x H x W | ||||
|     # move C out to last dim | ||||
|     if len(im.shape) == 5: | ||||
|         im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C) | ||||
|         i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute( | ||||
|             0, 2, 1 | ||||
|         ) | ||||
|         i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute( | ||||
|             0, 2, 1 | ||||
|         ) | ||||
|         i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute( | ||||
|             0, 2, 1 | ||||
|         ) | ||||
|         i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute( | ||||
|             0, 2, 1 | ||||
|         ) | ||||
|     else: | ||||
|         im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C) | ||||
|         i_y0_x0 = im_flat[idx_y0_x0.long()] | ||||
|         i_y0_x1 = im_flat[idx_y0_x1.long()] | ||||
|         i_y1_x0 = im_flat[idx_y1_x0.long()] | ||||
|         i_y1_x1 = im_flat[idx_y1_x1.long()] | ||||
|  | ||||
|     # Finally calculate interpolated values. | ||||
|     x0_f = x0.float() | ||||
|     x1_f = x1.float() | ||||
|     y0_f = y0.float() | ||||
|     y1_f = y1.float() | ||||
|  | ||||
|     w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2) | ||||
|     w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2) | ||||
|     w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2) | ||||
|     w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2) | ||||
|  | ||||
|     output = ( | ||||
|         w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1 | ||||
|     ) | ||||
|     # output is B*N x C | ||||
|     output = output.view(B, -1, C) | ||||
|     output = output.permute(0, 2, 1) | ||||
|     # output is B x C x N | ||||
|  | ||||
|     if return_inbounds: | ||||
|         x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte() | ||||
|         y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() | ||||
|         inbounds = (x_valid & y_valid).float() | ||||
|         inbounds = inbounds.reshape( | ||||
|             B, N | ||||
|         )  # NOTE: this reshape appears to misbehave for B > 1 (it errors here, or downstream if -1 is used) | ||||
|         return output, inbounds | ||||
|  | ||||
|     return output  # B, C, N | ||||
							
								
								
									
103  cotracker/models/evaluation_predictor.py  Normal file
							| @@ -0,0 +1,103 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from typing import Tuple | ||||
|  | ||||
| from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid | ||||
|  | ||||
|  | ||||
| class EvaluationPredictor(torch.nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         cotracker_model: CoTracker, | ||||
|         interp_shape: Tuple[int, int] = (384, 512), | ||||
|         grid_size: int = 6, | ||||
|         local_grid_size: int = 6, | ||||
|         single_point: bool = True, | ||||
|         n_iters: int = 6, | ||||
|     ) -> None: | ||||
|         super(EvaluationPredictor, self).__init__() | ||||
|         self.grid_size = grid_size | ||||
|         self.local_grid_size = local_grid_size | ||||
|         self.single_point = single_point | ||||
|         self.interp_shape = interp_shape | ||||
|         self.n_iters = n_iters | ||||
|  | ||||
|         self.model = cotracker_model | ||||
|         self.model.to("cuda") | ||||
|         self.model.eval() | ||||
|  | ||||
|     def forward(self, video, queries): | ||||
|         queries = queries.clone().cuda() | ||||
|         B, T, C, H, W = video.shape | ||||
|         B, N, D = queries.shape | ||||
|  | ||||
|         assert D == 3 | ||||
|         assert B == 1 | ||||
|  | ||||
|         rgbs = video.reshape(B * T, C, H, W) | ||||
|         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear") | ||||
|         rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda() | ||||
|  | ||||
|         queries[:, :, 1] *= self.interp_shape[1] / W | ||||
|         queries[:, :, 2] *= self.interp_shape[0] / H | ||||
|  | ||||
|         if self.single_point: | ||||
|             traj_e = torch.zeros((B, T, N, 2)).cuda() | ||||
|             vis_e = torch.zeros((B, T, N)).cuda() | ||||
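|             # Each query point is tracked independently, starting from the frame in | ||||
|             # which it first appears; only its own track is written to the output. | ||||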
|             for pind in range((N)): | ||||
|                 query = queries[:, pind : pind + 1] | ||||
|  | ||||
|                 t = query[0, 0, 0].long() | ||||
|  | ||||
|                 traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query) | ||||
|                 traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1] | ||||
|                 vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1] | ||||
|         else: | ||||
|             if self.grid_size > 0: | ||||
|                 xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) | ||||
|                 xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  # | ||||
|                 queries = torch.cat([queries, xy], dim=1)  # | ||||
|  | ||||
|             traj_e, __, vis_e, __ = self.model( | ||||
|                 rgbs=rgbs, | ||||
|                 queries=queries, | ||||
|                 iters=self.n_iters, | ||||
|             ) | ||||
|  | ||||
|         traj_e[:, :, :, 0] *= W / float(self.interp_shape[1]) | ||||
|         traj_e[:, :, :, 1] *= H / float(self.interp_shape[0]) | ||||
|         return traj_e, vis_e | ||||
|  | ||||
|     def _process_one_point(self, rgbs, query): | ||||
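|         # A single query is padded with a local grid of points around it and a global | ||||
|         # support grid; the caller keeps only the first returned track. | ||||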
|         t = query[0, 0, 0].long() | ||||
|  | ||||
|         device = rgbs.device | ||||
|         if self.local_grid_size > 0: | ||||
|             xy_target = get_points_on_a_grid( | ||||
|                 self.local_grid_size, | ||||
|                 (50, 50), | ||||
|                 [query[0, 0, 2], query[0, 0, 1]], | ||||
|             ) | ||||
|  | ||||
|             xy_target = torch.cat( | ||||
|                 [torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2 | ||||
|             )  # | ||||
|             query = torch.cat([query, xy_target], dim=1).to(device)  # | ||||
|  | ||||
|         if self.grid_size > 0: | ||||
|             xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) | ||||
|             xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  # | ||||
|             query = torch.cat([query, xy], dim=1).to(device)  # | ||||
|         # crop the video to start from the queried frame | ||||
|         query[0, 0, 0] = 0 | ||||
|         traj_e_pind, __, vis_e_pind, __ = self.model( | ||||
|             rgbs=rgbs[:, t:], queries=query, iters=self.n_iters | ||||
|         ) | ||||
|  | ||||
|         return traj_e_pind, vis_e_pind | ||||
							
								
								
									
178  cotracker/predictor.py  Normal file
							| @@ -0,0 +1,178 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
|  | ||||
| from tqdm import tqdm | ||||
| from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid | ||||
| from cotracker.models.core.model_utils import smart_cat | ||||
| from cotracker.models.build_cotracker import ( | ||||
|     build_cotracker, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class CoTrackerPredictor(torch.nn.Module): | ||||
|     def __init__( | ||||
|         self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth" | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.interp_shape = (384, 512) | ||||
|         self.support_grid_size = 6 | ||||
|         model = build_cotracker(checkpoint) | ||||
|  | ||||
|         self.model = model | ||||
|         self.model.to("cuda") | ||||
|         self.model.eval() | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def forward( | ||||
|         self, | ||||
|         video,  # (1, T, 3, H, W) | ||||
|         # input prompt types: | ||||
|         # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. | ||||
|         # *backward_tracking=True* will compute tracks in both directions. | ||||
|         # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates. | ||||
|         # - grid_size. Grid of N*N points from the first frame; if segm_mask is provided, tracks are computed only inside the mask. | ||||
|         # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. | ||||
|         queries: torch.Tensor = None, | ||||
|         segm_mask: torch.Tensor = None,  # Segmentation mask of shape (B, 1, H, W) | ||||
|         grid_size: int = 0, | ||||
|         grid_query_frame: int = 0,  # only for dense and regular grid tracks | ||||
|         backward_tracking: bool = False, | ||||
|     ): | ||||
|  | ||||
|         if queries is None and grid_size == 0: | ||||
|             tracks, visibilities = self._compute_dense_tracks( | ||||
|                 video, | ||||
|                 grid_query_frame=grid_query_frame, | ||||
|                 backward_tracking=backward_tracking, | ||||
|             ) | ||||
|         else: | ||||
|             tracks, visibilities = self._compute_sparse_tracks( | ||||
|                 video, | ||||
|                 queries, | ||||
|                 segm_mask, | ||||
|                 grid_size, | ||||
|                 add_support_grid=(grid_size == 0 or segm_mask is not None), | ||||
|                 grid_query_frame=grid_query_frame, | ||||
|                 backward_tracking=backward_tracking, | ||||
|             ) | ||||
|  | ||||
|         return tracks, visibilities | ||||
|  | ||||
|     def _compute_dense_tracks( | ||||
|         self, video, grid_query_frame, grid_size=50, backward_tracking=False | ||||
|     ): | ||||
|         *_, H, W = video.shape | ||||
|         grid_step = W // grid_size | ||||
|         grid_width = W // grid_step | ||||
|         grid_height = H // grid_step | ||||
|         tracks = visibilities = None | ||||
|         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda") | ||||
|         grid_pts[0, :, 0] = grid_query_frame | ||||
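|         # Dense tracking is tiled: each iteration queries one (ox, oy) offset of a | ||||
|         # grid_step x grid_step lattice, so the offsets collectively cover nearly every pixel. | ||||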
|         for offset in tqdm(range(grid_step * grid_step)): | ||||
|             ox = offset % grid_step | ||||
|             oy = offset // grid_step | ||||
|             grid_pts[0, :, 1] = ( | ||||
|                 torch.arange(grid_width).repeat(grid_height) * grid_step + ox | ||||
|             ) | ||||
|             grid_pts[0, :, 2] = ( | ||||
|                 torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy | ||||
|             ) | ||||
|             tracks_step, visibilities_step = self._compute_sparse_tracks( | ||||
|                 video=video, | ||||
|                 queries=grid_pts, | ||||
|                 backward_tracking=backward_tracking, | ||||
|             ) | ||||
|             tracks = smart_cat(tracks, tracks_step, dim=2) | ||||
|             visibilities = smart_cat(visibilities, visibilities_step, dim=2) | ||||
|  | ||||
|         return tracks, visibilities | ||||
|  | ||||
|     def _compute_sparse_tracks( | ||||
|         self, | ||||
|         video, | ||||
|         queries, | ||||
|         segm_mask=None, | ||||
|         grid_size=0, | ||||
|         add_support_grid=False, | ||||
|         grid_query_frame=0, | ||||
|         backward_tracking=False, | ||||
|     ): | ||||
|         B, T, C, H, W = video.shape | ||||
|         assert B == 1 | ||||
|  | ||||
|         video = video.reshape(B * T, C, H, W) | ||||
|         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda() | ||||
|         video = video.reshape( | ||||
|             B, T, 3, self.interp_shape[0], self.interp_shape[1] | ||||
|         ).cuda() | ||||
|  | ||||
|         if queries is not None: | ||||
|             queries = queries.clone() | ||||
|             B, N, D = queries.shape | ||||
|             assert D == 3 | ||||
|             queries[:, :, 1] *= self.interp_shape[1] / W | ||||
|             queries[:, :, 2] *= self.interp_shape[0] / H | ||||
|         elif grid_size > 0: | ||||
|             grid_pts = get_points_on_a_grid(grid_size, self.interp_shape) | ||||
|             if segm_mask is not None: | ||||
|                 segm_mask = F.interpolate( | ||||
|                     segm_mask, tuple(self.interp_shape), mode="nearest" | ||||
|                 ) | ||||
|                 point_mask = segm_mask[0, 0][ | ||||
|                     (grid_pts[0, :, 1]).round().long().cpu(), | ||||
|                     (grid_pts[0, :, 0]).round().long().cpu(), | ||||
|                 ].bool() | ||||
|                 grid_pts = grid_pts[:, point_mask] | ||||
|  | ||||
|             queries = torch.cat( | ||||
|                 [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], | ||||
|                 dim=2, | ||||
|             ) | ||||
|  | ||||
|         if add_support_grid: | ||||
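|             # A small support grid of extra points is appended to give the model | ||||
|             # additional context; it is stripped from the returned tracks below. | ||||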
|             grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape) | ||||
|             grid_pts = torch.cat( | ||||
|                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 | ||||
|             ) | ||||
|             queries = torch.cat([queries, grid_pts], dim=1) | ||||
|  | ||||
|         tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6) | ||||
|  | ||||
|         if backward_tracking: | ||||
|             tracks, visibilities = self._compute_backward_tracks( | ||||
|                 video, queries, tracks, visibilities | ||||
|             ) | ||||
|             if add_support_grid: | ||||
|                 queries[:, -self.support_grid_size ** 2 :, 0] = T - 1 | ||||
|         if add_support_grid: | ||||
|             tracks = tracks[:, :, : -self.support_grid_size ** 2] | ||||
|             visibilities = visibilities[:, :, : -self.support_grid_size ** 2] | ||||
|         thr = 0.9 | ||||
|         visibilities = visibilities > thr | ||||
|         tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) | ||||
|         tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) | ||||
|         return tracks, visibilities | ||||
|  | ||||
|     def _compute_backward_tracks(self, video, queries, tracks, visibilities): | ||||
|         inv_video = video.flip(1).clone() | ||||
|         inv_queries = queries.clone() | ||||
|         inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 | ||||
|  | ||||
|         inv_tracks, __, inv_visibilities, __ = self.model( | ||||
|             rgbs=inv_video, queries=inv_queries, iters=6 | ||||
|         ) | ||||
|  | ||||
|         inv_tracks = inv_tracks.flip(1) | ||||
|         inv_visibilities = inv_visibilities.flip(1) | ||||
|  | ||||
|         mask = tracks == 0 | ||||
|  | ||||
|         tracks[mask] = inv_tracks[mask] | ||||
|         visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] | ||||
|         return tracks, visibilities | ||||
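| # Example usage (a minimal sketch; assumes a CUDA device and the default checkpoint path): | ||||
| #     model = CoTrackerPredictor() | ||||
| #     video = torch.zeros(1, 24, 3, 384, 512)  # (B, T, C, H, W), float frames in [0, 255] | ||||
| #     pred_tracks, pred_visibility = model(video, grid_size=10) | ||||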
							
								
								
									
5  cotracker/utils/__init__.py  Normal file
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
291  cotracker/utils/visualizer.py  Normal file
							| @@ -0,0 +1,291 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import os | ||||
| import numpy as np | ||||
| import cv2 | ||||
| import torch | ||||
| import flow_vis | ||||
|  | ||||
| from matplotlib import cm | ||||
| import torch.nn.functional as F | ||||
| import torchvision.transforms as transforms | ||||
| from moviepy.editor import ImageSequenceClip | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
|  | ||||
| class Visualizer: | ||||
|     def __init__( | ||||
|         self, | ||||
|         save_dir: str = "./results", | ||||
|         grayscale: bool = False, | ||||
|         pad_value: int = 0, | ||||
|         fps: int = 10, | ||||
|         mode: str = "rainbow",  # 'cool', 'optical_flow' | ||||
|         linewidth: int = 2, | ||||
|         show_first_frame: int = 10, | ||||
|         tracks_leave_trace: int = 0,  # -1 for infinite | ||||
|     ): | ||||
|         self.mode = mode | ||||
|         self.save_dir = save_dir | ||||
|         if mode == "rainbow": | ||||
|             self.color_map = cm.get_cmap("gist_rainbow") | ||||
|         elif mode == "cool": | ||||
|             self.color_map = cm.get_cmap(mode) | ||||
|         self.show_first_frame = show_first_frame | ||||
|         self.grayscale = grayscale | ||||
|         self.tracks_leave_trace = tracks_leave_trace | ||||
|         self.pad_value = pad_value | ||||
|         self.linewidth = linewidth | ||||
|         self.fps = fps | ||||
|  | ||||
|     def visualize( | ||||
|         self, | ||||
|         video: torch.Tensor,  # (B,T,C,H,W) | ||||
|         tracks: torch.Tensor,  # (B,T,N,2) | ||||
|         gt_tracks: torch.Tensor = None,  # (B,T,N,2) | ||||
|         segm_mask: torch.Tensor = None,  # (B,1,H,W) | ||||
|         filename: str = "video", | ||||
|         writer: SummaryWriter = None, | ||||
|         step: int = 0, | ||||
|         query_frame: int = 0, | ||||
|         save_video: bool = True, | ||||
|         compensate_for_camera_motion: bool = False, | ||||
|     ): | ||||
|         if compensate_for_camera_motion: | ||||
|             assert segm_mask is not None | ||||
|         if segm_mask is not None: | ||||
|             coords = tracks[0, query_frame].round().long() | ||||
|             segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long() | ||||
|  | ||||
|         video = F.pad( | ||||
|             video, | ||||
|             (self.pad_value, self.pad_value, self.pad_value, self.pad_value), | ||||
|             "constant", | ||||
|             255, | ||||
|         ) | ||||
|         tracks = tracks + self.pad_value | ||||
|  | ||||
|         if self.grayscale: | ||||
|             transform = transforms.Grayscale() | ||||
|             video = transform(video) | ||||
|             video = video.repeat(1, 1, 3, 1, 1) | ||||
|  | ||||
|         res_video = self.draw_tracks_on_video( | ||||
|             video=video, | ||||
|             tracks=tracks, | ||||
|             segm_mask=segm_mask, | ||||
|             gt_tracks=gt_tracks, | ||||
|             query_frame=query_frame, | ||||
|             compensate_for_camera_motion=compensate_for_camera_motion, | ||||
|         ) | ||||
|         if save_video: | ||||
|             self.save_video(res_video, filename=filename, writer=writer, step=step) | ||||
|         return res_video | ||||
|  | ||||
|     def save_video(self, video, filename, writer=None, step=0): | ||||
|         if writer is not None: | ||||
|             writer.add_video( | ||||
|                 f"{filename}_pred_track", | ||||
|                 video.to(torch.uint8), | ||||
|                 global_step=step, | ||||
|                 fps=self.fps, | ||||
|             ) | ||||
|         else: | ||||
|             os.makedirs(self.save_dir, exist_ok=True) | ||||
|             wide_list = list(video.unbind(1)) | ||||
|             wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list] | ||||
|             clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps) | ||||
|  | ||||
|             # Write the video file | ||||
|             save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4") | ||||
|             clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None) | ||||
|  | ||||
|             print(f"Video saved to {save_path}") | ||||
|  | ||||
|     def draw_tracks_on_video( | ||||
|         self, | ||||
|         video: torch.Tensor, | ||||
|         tracks: torch.Tensor, | ||||
|         segm_mask: torch.Tensor = None, | ||||
|         gt_tracks=None, | ||||
|         query_frame: int = 0, | ||||
|         compensate_for_camera_motion=False, | ||||
|     ): | ||||
|         B, T, C, H, W = video.shape | ||||
|         _, _, N, D = tracks.shape | ||||
|  | ||||
|         assert D == 2 | ||||
|         assert C == 3 | ||||
|         video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C | ||||
|         tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2 | ||||
|         if gt_tracks is not None: | ||||
|             gt_tracks = gt_tracks[0].detach().cpu().numpy() | ||||
|  | ||||
|         res_video = [] | ||||
|  | ||||
|         # process input video | ||||
|         for rgb in video: | ||||
|             res_video.append(rgb.copy()) | ||||
|  | ||||
|         vector_colors = np.zeros((T, N, 3)) | ||||
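|         # Point colors: "optical_flow" maps each track's displacement from the query | ||||
|         # frame onto a flow color wheel; "rainbow" colors points by their y coordinate; | ||||
|         # otherwise color encodes time, or the segmentation class when a mask is given. | ||||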
|         if self.mode == "optical_flow": | ||||
|             vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None]) | ||||
|         elif segm_mask is None: | ||||
|             if self.mode == "rainbow": | ||||
|                 y_min, y_max = ( | ||||
|                     tracks[query_frame, :, 1].min(), | ||||
|                     tracks[query_frame, :, 1].max(), | ||||
|                 ) | ||||
|                 norm = plt.Normalize(y_min, y_max) | ||||
|                 for n in range(N): | ||||
|                     color = self.color_map(norm(tracks[query_frame, n, 1])) | ||||
|                     color = np.array(color[:3])[None] * 255 | ||||
|                     vector_colors[:, n] = np.repeat(color, T, axis=0) | ||||
|             else: | ||||
|                 # color changes with time | ||||
|                 for t in range(T): | ||||
|                     color = np.array(self.color_map(t / T)[:3])[None] * 255 | ||||
|                     vector_colors[t] = np.repeat(color, N, axis=0) | ||||
|         else: | ||||
|             if self.mode == "rainbow": | ||||
|                 vector_colors[:, segm_mask <= 0, :] = 255 | ||||
|  | ||||
|                 y_min, y_max = ( | ||||
|                     tracks[0, segm_mask > 0, 1].min(), | ||||
|                     tracks[0, segm_mask > 0, 1].max(), | ||||
|                 ) | ||||
|                 norm = plt.Normalize(y_min, y_max) | ||||
|                 for n in range(N): | ||||
|                     if segm_mask[n] > 0: | ||||
|                         color = self.color_map(norm(tracks[0, n, 1])) | ||||
|                         color = np.array(color[:3])[None] * 255 | ||||
|                         vector_colors[:, n] = np.repeat(color, T, axis=0) | ||||
|  | ||||
|             else: | ||||
|                 # color changes with segm class | ||||
|                 segm_mask = segm_mask.cpu() | ||||
|                 color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32) | ||||
|                 color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0 | ||||
|                 color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0 | ||||
|                 vector_colors = np.repeat(color[None], T, axis=0) | ||||
|  | ||||
|         #  draw tracks | ||||
|         if self.tracks_leave_trace != 0: | ||||
|             for t in range(1, T): | ||||
|                 first_ind = ( | ||||
|                     max(0, t - self.tracks_leave_trace) | ||||
|                     if self.tracks_leave_trace >= 0 | ||||
|                     else 0 | ||||
|                 ) | ||||
|                 curr_tracks = tracks[first_ind : t + 1] | ||||
|                 curr_colors = vector_colors[first_ind : t + 1] | ||||
|                 if compensate_for_camera_motion: | ||||
|                     diff = ( | ||||
|                         tracks[first_ind : t + 1, segm_mask <= 0] | ||||
|                         - tracks[t : t + 1, segm_mask <= 0] | ||||
|                     ).mean(1)[:, None] | ||||
|  | ||||
|                     curr_tracks = curr_tracks - diff | ||||
|                     curr_tracks = curr_tracks[:, segm_mask > 0] | ||||
|                     curr_colors = curr_colors[:, segm_mask > 0] | ||||
|  | ||||
|                 res_video[t] = self._draw_pred_tracks( | ||||
|                     res_video[t], | ||||
|                     curr_tracks, | ||||
|                     curr_colors, | ||||
|                 ) | ||||
|                 if gt_tracks is not None: | ||||
|                     res_video[t] = self._draw_gt_tracks( | ||||
|                         res_video[t], gt_tracks[first_ind : t + 1] | ||||
|                     ) | ||||
|  | ||||
|         #  draw points | ||||
|         for t in range(T): | ||||
|             for i in range(N): | ||||
|                 coord = (tracks[t, i, 0], tracks[t, i, 1]) | ||||
|                 if coord[0] != 0 and coord[1] != 0: | ||||
|                     if not compensate_for_camera_motion or ( | ||||
|                         compensate_for_camera_motion and segm_mask[i] > 0 | ||||
|                     ): | ||||
|                         cv2.circle( | ||||
|                             res_video[t], | ||||
|                             coord, | ||||
|                             int(self.linewidth * 2), | ||||
|                             vector_colors[t, i].tolist(), | ||||
|                             -1, | ||||
|                         ) | ||||
|  | ||||
|         #  construct the final rgb sequence | ||||
|         if self.show_first_frame > 0: | ||||
|             res_video = [res_video[0]] * self.show_first_frame + res_video[1:] | ||||
|         return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte() | ||||
|  | ||||
|     def _draw_pred_tracks( | ||||
|         self, | ||||
|         rgb: np.ndarray,  # H x W x 3 | ||||
|         tracks: np.ndarray,  # T x N x 2 | ||||
|         vector_colors: np.ndarray, | ||||
|         alpha: float = 0.5, | ||||
|     ): | ||||
|         T, N, _ = tracks.shape | ||||
|  | ||||
|         for s in range(T - 1): | ||||
|             vector_color = vector_colors[s] | ||||
|             original = rgb.copy() | ||||
|             alpha = (s / T) ** 2 | ||||
|             for i in range(N): | ||||
|                 coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1])) | ||||
|                 coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1])) | ||||
|                 if coord_y[0] != 0 and coord_y[1] != 0: | ||||
|                     cv2.line( | ||||
|                         rgb, | ||||
|                         coord_y, | ||||
|                         coord_x, | ||||
|                         vector_color[i].tolist(), | ||||
|                         self.linewidth, | ||||
|                         cv2.LINE_AA, | ||||
|                     ) | ||||
|             if self.tracks_leave_trace > 0: | ||||
|                 rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0) | ||||
|         return rgb | ||||
|  | ||||
|     def _draw_gt_tracks( | ||||
|         self, | ||||
|         rgb: np.ndarray,  # H x W x 3 | ||||
|         gt_tracks: np.ndarray,  # T x N x 2 | ||||
|     ): | ||||
|         T, N, _ = gt_tracks.shape | ||||
|         color = np.array((211.0, 0.0, 0.0)) | ||||
|  | ||||
|         for t in range(T): | ||||
|             for i in range(N): | ||||
|                 gt_track = gt_tracks[t][i] | ||||
|                 #  draw a red cross | ||||
|                 if gt_track[0] > 0 and gt_track[1] > 0: | ||||
|                     length = self.linewidth * 3 | ||||
|                     coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length) | ||||
|                     coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length) | ||||
|                     cv2.line( | ||||
|                         rgb, | ||||
|                         coord_y, | ||||
|                         coord_x, | ||||
|                         color, | ||||
|                         self.linewidth, | ||||
|                         cv2.LINE_AA, | ||||
|                     ) | ||||
|                     coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length) | ||||
|                     coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length) | ||||
|                     cv2.line( | ||||
|                         rgb, | ||||
|                         coord_y, | ||||
|                         coord_x, | ||||
|                         color, | ||||
|                         self.linewidth, | ||||
|                         cv2.LINE_AA, | ||||
|                     ) | ||||
|         return rgb | ||||