Initial commit
cotracker/evaluation/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cotracker/evaluation/configs/eval_badja.yaml (new file, 6 lines)
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: badja

cotracker/evaluation/configs/eval_fastcapture.yaml (new file, 6 lines)
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: fastcapture

cotracker/evaluation/configs/eval_tapvid_davis_first.yaml (new file, 6 lines)
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_first

cotracker/evaluation/configs/eval_tapvid_davis_strided.yaml (new file, 6 lines)
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_strided

cotracker/evaluation/configs/eval_tapvid_kinetics_first.yaml (new file, 6 lines)
@@ -0,0 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_kinetics_first

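Each of these configs pulls in the DefaultConfig dataclass (registered as default_config_eval in cotracker/evaluation/evaluate.py below) through its defaults list and then overrides exp_dir and dataset_name. A minimal sketch of the merge semantics using OmegaConf directly, with illustrative values standing in for the registered node:

from omegaconf import OmegaConf

# The dataclass defaults stand in for default_config_eval here.
base = OmegaConf.create({"exp_dir": "./outputs", "dataset_name": "badja"})
# The per-benchmark yaml then overrides a couple of fields.
override = OmegaConf.create(
    {"exp_dir": "./outputs/cotracker", "dataset_name": "tapvid_kinetics_first"}
)
cfg = OmegaConf.merge(base, override)
print(cfg.dataset_name)  # tapvid_kinetics_first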
							
								
								
									
										5
									
								
								cotracker/evaluation/core/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								cotracker/evaluation/core/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
							
								
								
									
cotracker/evaluation/core/eval_utils.py (new file, 144 lines)
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from typing import Mapping


def compute_tapvid_metrics(
    query_points: np.ndarray,
    gt_occluded: np.ndarray,
    gt_tracks: np.ndarray,
    pred_occluded: np.ndarray,
    pred_tracks: np.ndarray,
    query_mode: str,
) -> Mapping[str, np.ndarray]:
    """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)

    See the TAP-Vid paper for details on the metric computation.  All inputs are
    given in raster coordinates.  The first three arguments should be the direct
    outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
    The paper metrics assume these are scaled relative to 256x256 images.
    pred_occluded and pred_tracks are your algorithm's predictions.

    This function takes a batch of inputs and computes metrics separately for
    each video.  The metrics for the full benchmark are a simple mean of the
    metrics across the full set of videos.  These numbers are between 0 and 1,
    but the paper multiplies them by 100 to ease reading.

    Args:
       query_points: The query points, in the format [t, y, x].  Its size is
         [b, n, 3], where b is the batch size and n is the number of queries.
       gt_occluded: A boolean array of shape [b, n, t], where t is the number
         of frames.  True indicates that the point is occluded.
       gt_tracks: The target points, of shape [b, n, t, 2].  Each point is
         in the format [x, y].
       pred_occluded: A boolean array of predicted occlusions, in the same
         format as gt_occluded.
       pred_tracks: An array of track predictions from your algorithm, in the
         same format as gt_tracks.
       query_mode: Either 'first' or 'strided', depending on how queries are
         sampled.  If 'first', we assume the prior knowledge that all points
         before the query point are occluded, and these are removed from the
         evaluation.

    Returns:
        A dict with the following keys:
        occlusion_accuracy: Accuracy at predicting occlusion.
        pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
          predicted to be within the given pixel threshold, ignoring occlusion
          prediction.
        jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
          threshold.
        average_pts_within_thresh: average across pts_within_{x}.
        average_jaccard: average across jaccard_{x}.
    """

    metrics = {}

    # Don't evaluate the query point.  Numpy doesn't have one_hot, so we
    # replicate it by indexing into an identity matrix.
    one_hot_eye = np.eye(gt_tracks.shape[2])
    query_frame = query_points[..., 0]
    query_frame = np.round(query_frame).astype(np.int32)
    evaluation_points = one_hot_eye[query_frame] == 0

    # If we're using the first point on the track as a query, don't evaluate the
    # other points.
    if query_mode == "first":
        for i in range(gt_occluded.shape[0]):
            index = np.where(gt_occluded[i] == 0)[0][0]
            evaluation_points[i, :index] = False
    elif query_mode != "strided":
        raise ValueError("Unknown query mode " + query_mode)

    # Occlusion accuracy is simply how often the predicted occlusion equals the
    # ground truth.
    occ_acc = np.sum(
        np.equal(pred_occluded, gt_occluded) & evaluation_points,
        axis=(1, 2),
    ) / np.sum(evaluation_points)
    metrics["occlusion_accuracy"] = occ_acc

    # Next, convert the predictions and ground truth positions into pixel
    # coordinates.
    visible = np.logical_not(gt_occluded)
    pred_visible = np.logical_not(pred_occluded)
    all_frac_within = []
    all_jaccard = []
    for thresh in [1, 2, 4, 8, 16]:
        # True positives are points that are within the threshold and where both
        # the prediction and the ground truth are listed as visible.
        within_dist = np.sum(
            np.square(pred_tracks - gt_tracks),
            axis=-1,
        ) < np.square(thresh)
        is_correct = np.logical_and(within_dist, visible)

        # Compute the frac_within_threshold, which is the fraction of points
        # within the threshold among points that are visible in the ground truth,
        # ignoring whether they're predicted to be visible.
        count_correct = np.sum(
            is_correct & evaluation_points,
            axis=(1, 2),
        )
        count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
        frac_correct = count_correct / count_visible_points
        metrics["pts_within_" + str(thresh)] = frac_correct
        all_frac_within.append(frac_correct)

        true_positives = np.sum(
            is_correct & pred_visible & evaluation_points, axis=(1, 2)
        )

        # The denominator of the Jaccard metric is the true positives plus
        # false positives plus false negatives.  However, note that true positives
        # plus false negatives is simply the number of points in the ground truth,
        # which is easier to compute than trying to compute all three quantities.
        # Thus we just add the number of points in the ground truth to the number
        # of false positives.
        #
        # False positives are simply points that are predicted to be visible,
        # but the ground truth is not visible or too far from the prediction.
        gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
        false_positives = (~visible) & pred_visible
        false_positives = false_positives | ((~within_dist) & pred_visible)
        false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
        jaccard = true_positives / (gt_positives + false_positives)
        metrics["jaccard_" + str(thresh)] = jaccard
        all_jaccard.append(jaccard)

    metrics["average_jaccard"] = np.mean(
        np.stack(all_jaccard, axis=1),
        axis=1,
    )
    metrics["average_pts_within_thresh"] = np.mean(
        np.stack(all_frac_within, axis=1),
        axis=1,
    )
    return metrics
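A quick sanity check of the expected shapes and the best-case result (a minimal sketch, not part of the commit; it assumes numpy is installed and the module path above is importable):

import numpy as np
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

b, n, t = 1, 4, 8  # one video, four query points, eight frames
rng = np.random.default_rng(0)

query_points = np.zeros((b, n, 3))             # [t, y, x]; all queried at frame 0
gt_occluded = np.zeros((b, n, t), dtype=bool)  # every point visible in every frame
gt_tracks = rng.uniform(0, 256, (b, n, t, 2))  # [x, y] in 256x256 raster coordinates

# Feeding the ground truth back as the prediction should give perfect scores:
# occlusion_accuracy == 1.0 and average_jaccard == 1.0.
metrics = compute_tapvid_metrics(
    query_points, gt_occluded, gt_tracks, gt_occluded, gt_tracks, query_mode="first"
)
print({k: v.tolist() for k, v in metrics.items()})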
							
								
								
									
cotracker/evaluation/core/evaluator.py (new file, 252 lines)
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
import os
from typing import Optional
import torch
from tqdm import tqdm
import numpy as np

from torch.utils.tensorboard import SummaryWriter
from cotracker.datasets.utils import dataclass_to_cuda_
from cotracker.utils.visualizer import Visualizer
from cotracker.models.core.model_utils import reduce_masked_mean
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

import logging


class Evaluator:
    """
    A class defining the CoTracker evaluator.
    """

    def __init__(self, exp_dir) -> None:
        # Visualization
        self.exp_dir = exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
        self.visualize_dir = os.path.join(exp_dir, "visualisations")

    def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
        if isinstance(pred_trajectory, tuple):
            pred_trajectory, pred_visibility = pred_trajectory
        else:
            pred_visibility = None
        if dataset_name == "badja":
            sample.segmentation = (sample.segmentation > 0).float()
            *_, N, _ = sample.trajectory.shape
            accs = []
            accs_3px = []
            for s1 in range(1, sample.video.shape[1]):  # target frame
                for n in range(N):
                    vis = sample.visibility[0, s1, n]
                    if vis > 0:
                        coord_e = pred_trajectory[0, s1, n]  # 2
                        coord_g = sample.trajectory[0, s1, n]  # 2
                        dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
                        area = torch.sum(sample.segmentation[0, s1])
                        # BADJA's segment-based accuracy: a prediction counts as
                        # correct if it lands within 0.2 * sqrt(segmentation area)
                        # of the ground truth.
                        thr = 0.2 * torch.sqrt(area)
                        accs.append((dist < thr).float())
                        accs_3px.append((dist < 3.0).float())

            res = torch.mean(torch.stack(accs)) * 100.0
            res_3px = torch.mean(torch.stack(accs_3px)) * 100.0
            metrics[sample.seq_name[0]] = res.item()
            metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item()
            print(metrics)
            print(
                "avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k])
            )
            print(
                "avg acc 3px",
                np.mean([v for k, v in metrics.items() if "accuracy" in k]),
            )
        elif dataset_name == "fastcapture" or ("kubric" in dataset_name):
            *_, N, _ = sample.trajectory.shape
            accs = []
            for s1 in range(1, sample.video.shape[1]):  # target frame
                for n in range(N):
                    vis = sample.visibility[0, s1, n]
                    if vis > 0:
                        coord_e = pred_trajectory[0, s1, n]  # 2
                        coord_g = sample.trajectory[0, s1, n]  # 2
                        dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
                        thr = 3
                        correct = (dist < thr).float()
                        accs.append(correct)

            res = torch.mean(torch.stack(accs)) * 100.0
            metrics[sample.seq_name[0] + "_accuracy"] = res.item()
            print(metrics)
            print("avg", np.mean([v for v in metrics.values()]))
        elif "tapvid" in dataset_name:
            B, T, N, D = sample.trajectory.shape
            traj = sample.trajectory.clone()
            # Threshold for binarizing soft visibility scores.
            thr = 0.9

            if pred_visibility is None:
                logging.warning("visibility is NONE")
                pred_visibility = torch.zeros_like(sample.visibility)

            if pred_visibility.dtype != torch.bool:
                pred_visibility = pred_visibility > thr

            query_points = sample.query_points.clone().cpu().numpy()

            pred_visibility = pred_visibility[:, :, :N]
            pred_trajectory = pred_trajectory[:, :, :N]

            gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
            gt_occluded = (
                torch.logical_not(sample.visibility.clone().permute(0, 2, 1))
                .cpu()
                .numpy()
            )

            pred_occluded = (
                torch.logical_not(pred_visibility.clone().permute(0, 2, 1))
                .cpu()
                .numpy()
            )
            pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()

            out_metrics = compute_tapvid_metrics(
                query_points,
                gt_occluded,
                gt_tracks,
                pred_occluded,
                pred_tracks,
                query_mode="strided" if "strided" in dataset_name else "first",
            )

            metrics[sample.seq_name[0]] = out_metrics
            for metric_name in out_metrics.keys():
                if "avg" not in metrics:
                    metrics["avg"] = {}
                metrics["avg"][metric_name] = np.mean(
                    [v[metric_name] for k, v in metrics.items() if k != "avg"]
                )

            logging.info(f"Metrics: {out_metrics}")
            logging.info(f"avg: {metrics['avg']}")
            print("metrics", out_metrics)
            print("avg", metrics["avg"])
        else:
            rgbs = sample.video
            trajs_g = sample.trajectory
            valids = sample.valid
            vis_g = sample.visibility

            B, S, C, H, W = rgbs.shape
            assert C == 3
            B, S, N, D = trajs_g.shape

            assert torch.sum(valids) == B * S * N

            vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1)

            ate = torch.norm(pred_trajectory - trajs_g, dim=-1)  # B, S, N

            metrics["things_all"] = reduce_masked_mean(ate, valids).item()
            metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item()
            metrics["things_occ"] = reduce_masked_mean(
                ate, valids * (1.0 - vis_g)
            ).item()
    @torch.no_grad()
    def evaluate_sequence(
        self,
        model,
        test_dataloader: torch.utils.data.DataLoader,
        dataset_name: str,
        train_mode=False,
        writer: Optional[SummaryWriter] = None,
        step: Optional[int] = 0,
    ):
        metrics = {}

        vis = Visualizer(
            save_dir=self.exp_dir,
            fps=7,
        )

        for ind, sample in enumerate(tqdm(test_dataloader)):
            if isinstance(sample, tuple):
                sample, gotit = sample
                if not all(gotit):
                    print("batch is None")
                    continue
            dataclass_to_cuda_(sample)

            if (
                not train_mode
                and hasattr(model, "sequence_len")
                and (sample.visibility[:, : model.sequence_len].sum() == 0)
            ):
                print(f"skipping batch {ind}")
                continue

            if "tapvid" in dataset_name:
                # TAP-Vid queries come as [t, y, x]; the model expects [t, x, y].
                queries = sample.query_points.clone().float()

                queries = torch.stack(
                    [
                        queries[:, :, 0],
                        queries[:, :, 2],
                        queries[:, :, 1],
                    ],
                    dim=2,
                )
            else:
                queries = torch.cat(
                    [
                        torch.zeros_like(sample.trajectory[:, 0, :, :1]),
                        sample.trajectory[:, 0],
                    ],
                    dim=2,
                )

            pred_tracks = model(sample.video, queries)
            if "strided" in dataset_name:
                # Run the model a second time on the reversed video so that
                # frames before each query point are also tracked.
                inv_video = sample.video.flip(1).clone()
                inv_queries = queries.clone()
                inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1

                pred_trj, pred_vsb = pred_tracks
                inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)

                inv_pred_trj = inv_pred_trj.flip(1)
                inv_pred_vsb = inv_pred_vsb.flip(1)

                mask = pred_trj == 0

                pred_trj[mask] = inv_pred_trj[mask]
                pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]

                pred_tracks = pred_trj, pred_vsb

            if dataset_name == "badja" or dataset_name == "fastcapture":
                seq_name = sample.seq_name[0]
            else:
                seq_name = str(ind)

            vis.visualize(
                sample.video,
                pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
                filename=dataset_name + "_" + seq_name,
                writer=writer,
                step=step,
            )

            self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
        return metrics
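The strided branch above merges a forward and a backward pass: wherever the forward prediction is still zero (frames before the query point), the flipped backward prediction is substituted. A minimal sketch of that merge on toy tensors (illustration only, with made-up values):

import torch

# Toy tracks of shape [B, T, N, 2]; zeros mark frames the forward pass
# left unfilled (frames before the query point).
pred_trj = torch.tensor([[[[0.0, 0.0]], [[5.0, 5.0]]]])      # B=1, T=2, N=1
inv_pred_trj = torch.tensor([[[[3.0, 3.0]], [[9.0, 9.0]]]])  # backward pass, flipped back

mask = pred_trj == 0                 # True wherever the forward pass is empty
pred_trj[mask] = inv_pred_trj[mask]  # fill those entries from the backward pass
print(pred_trj)                      # tensor([[[[3., 3.]], [[5., 5.]]]])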
							
								
								
									
cotracker/evaluation/evaluate.py (new file, 179 lines)
@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import time
from dataclasses import dataclass, field

import hydra
import numpy as np

import torch
from omegaconf import OmegaConf

from cotracker.datasets.badja_dataset import BadjaDataset
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.utils import collate_fn

from cotracker.models.evaluation_predictor import EvaluationPredictor

from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.models.build_cotracker import build_cotracker


@dataclass(eq=False)
class DefaultConfig:
    # Directory where all outputs of the experiment will be saved.
    exp_dir: str = "./outputs"

    # Name of the dataset to be used for the evaluation.
    dataset_name: str = "badja"
    # The root directory of the dataset.
    dataset_root: str = "./"

    # Path to the pre-trained model checkpoint to be used for the evaluation.
    # The default value is the path to a specific CoTracker model checkpoint.
    # Other available options are commented.
    checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth"
    # cotracker_stride_4_wind_12
    # cotracker_stride_8_wind_16

    # EvaluationPredictor parameters
    # The size (N) of the support grid used in the predictor.
    # The total number of points is (N*N).
    grid_size: int = 6
    # The size (N) of the local support grid.
    local_grid_size: int = 6
    # A flag indicating whether to evaluate one ground truth point at a time.
    single_point: bool = True
    # The number of iterative updates for each sliding window.
    n_iters: int = 6

    seed: int = 0
    gpu_idx: int = 0

    # Override hydra's working directory to the current working dir,
    # and disable storing the .hydra logs:
    hydra: dict = field(
        default_factory=lambda: {
            "run": {"dir": "."},
            "output_subdir": None,
        }
    )

def run_eval(cfg: DefaultConfig):
    """
    Evaluates CoTracker on a specified benchmark dataset based on the provided configuration.

    Args:
        cfg (DefaultConfig): An instance of the DefaultConfig class, which includes:
            - exp_dir (str): The directory path for the experiment.
            - dataset_name (str): The name of the dataset to be used.
            - dataset_root (str): The root directory of the dataset.
            - checkpoint (str): The path to the CoTracker model's checkpoint.
            - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
            - n_iters (int): The number of iterative updates for each sliding window.
            - seed (int): The seed for setting the random state for reproducibility.
            - gpu_idx (int): The index of the GPU to be used.
    """
    # Create the experiment directory if it doesn't exist.
    os.makedirs(cfg.exp_dir, exist_ok=True)

    # Save the experiment configuration to a .yaml file in the experiment directory.
    cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
    with open(cfg_file, "w") as f:
        OmegaConf.save(config=cfg, f=f)

    evaluator = Evaluator(cfg.exp_dir)
    cotracker_model = build_cotracker(cfg.checkpoint)

    # Create the EvaluationPredictor object.
    predictor = EvaluationPredictor(
        cotracker_model,
        grid_size=cfg.grid_size,
        local_grid_size=cfg.local_grid_size,
        single_point=cfg.single_point,
        n_iters=cfg.n_iters,
    )

    # Set the random seeds.
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    # Construct the specified dataset.
    curr_collate_fn = collate_fn
    if cfg.dataset_name == "badja":
        test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA"))
    elif cfg.dataset_name == "fastcapture":
        test_dataset = FastCaptureDataset(
            data_root=os.path.join(cfg.dataset_root, "fastcapture"),
            max_seq_len=100,
            max_num_points=20,
        )
    elif "tapvid" in cfg.dataset_name:
        dataset_type = cfg.dataset_name.split("_")[1]
        # Note: these components must stay relative; a leading slash would make
        # os.path.join discard cfg.dataset_root entirely.
        if dataset_type == "davis":
            data_root = os.path.join(cfg.dataset_root, "tapvid_davis/tapvid_davis.pkl")
        elif dataset_type == "kinetics":
            data_root = os.path.join(
                cfg.dataset_root, "kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
            )
        test_dataset = TapVidDataset(
            dataset_type=dataset_type,
            data_root=data_root,
            queried_first="strided" not in cfg.dataset_name,
        )

    # Create the DataLoader object.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=14,
        collate_fn=curr_collate_fn,
    )

    # Time and run the evaluation.
    start = time.time()
    evaluate_result = evaluator.evaluate_sequence(
        predictor,
        test_dataloader,
        dataset_name=cfg.dataset_name,
    )
    end = time.time()
    print("evaluation time:", end - start)

    # Save the evaluation results to a .json file.
    if "tapvid" not in cfg.dataset_name:
        print("evaluate_result", evaluate_result)
    else:
        evaluate_result = evaluate_result["avg"]
    result_file = os.path.join(cfg.exp_dir, "result_eval_.json")
    evaluate_result["time"] = end - start
    print(f"Dumping eval results to {result_file}.")
    with open(result_file, "w") as f:
        json.dump(evaluate_result, f)


cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config_eval", node=DefaultConfig)


@hydra.main(config_path="./configs/", config_name="default_config_eval")
def evaluate(cfg: DefaultConfig) -> None:
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
    run_eval(cfg)


if __name__ == "__main__":
    evaluate()
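Since the configs above are picked up from ./configs/ by the hydra.main decorator, an evaluation run would typically be launched by naming one of them on the command line; the dataset path below is a placeholder, not part of the commit:

python ./cotracker/evaluation/evaluate.py --config-name eval_tapvid_davis_first dataset_root=/path/to/datasets exp_dir=./outputs/davis_first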