Merge branch 'main' of github.com:JunkyByte/co-tracker
README.md: 37 changed lines
							| @@ -1,6 +1,6 @@ | ||||
| # CoTracker: It is Better to Track Together | ||||
|  | ||||
| **[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)** | ||||
| **[Meta AI Research, GenAI](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)** | ||||
|  | ||||
| [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/) | ||||
|  | ||||
| @@ -15,7 +15,7 @@ | ||||
| **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow. | ||||
|   | ||||
| CoTracker can track: | ||||
| - **Every pixel** within a video | ||||
| - **Every pixel** in a video | ||||
| - Points sampled on a regular grid on any video frame  | ||||
| - Manually selected points | ||||
|  | ||||
| @@ -26,7 +26,21 @@ Try these tracking modes for yourself with our [Colab demo](https://colab.resear | ||||
| ## Installation Instructions | ||||
| Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support. | ||||
|  | ||||
| ## Steps to Install CoTracker and its dependencies: | ||||
| ### Pretrained models via PyTorch Hub | ||||
| The easiest way to use CoTracker is to load a pretrained model from torch.hub: | ||||
| ``` | ||||
| pip install einops timm tqdm | ||||
| ``` | ||||
| ``` | ||||
| import torch | ||||
| import timm | ||||
| import einops | ||||
| import tqdm | ||||
|  | ||||
| cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8") | ||||
| ``` | ||||
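A minimal usage sketch of the hub-loaded predictor may help here. The call signature and the output shapes below mirror `demo.py` further down in this commit and are otherwise assumptions, not part of the commit itself:
```
import torch

# Assumed interface, mirroring demo.py in this commit: the predictor takes a
# float video tensor of shape (B, T, 3, H, W) plus an optional regular grid
# of query points, and returns per-point tracks and visibilities.
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")

video = torch.randn(1, 24, 3, 384, 512)  # stand-in for a real clip
if torch.cuda.is_available():
    cotracker, video = cotracker.cuda(), video.cuda()

pred_tracks, pred_visibility = cotracker(video, grid_size=10)
# pred_tracks: (B, T, N, 2) point coordinates in every frame
# pred_visibility: (B, T, N) per-point visibility estimates
```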
| Another option is to install it from this GitHub repo. That's the best way if you need to run our demo or evaluate / train CoTracker: | ||||
| ### Steps to Install CoTracker and its dependencies: | ||||
| ``` | ||||
| git clone https://github.com/facebookresearch/co-tracker | ||||
| cd co-tracker | ||||
| @@ -35,7 +49,7 @@ pip install opencv-python einops timm matplotlib moviepy flow_vis | ||||
| ``` | ||||
|  | ||||
|  | ||||
| ## Model Weights Download: | ||||
| ### Download Model Weights: | ||||
| ``` | ||||
| mkdir checkpoints | ||||
| cd checkpoints | ||||
| @@ -60,24 +74,26 @@ To reproduce the results presented in the paper, download the following datasets | ||||
|  | ||||
| And install the necessary dependencies: | ||||
| ``` | ||||
| pip install hydra-core==1.1.0 mediapy tensorboard  | ||||
| pip install hydra-core==1.1.0 mediapy  | ||||
| ``` | ||||
| Then, execute the following command to evaluate on BADJA: | ||||
| ``` | ||||
| python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path | ||||
| ``` | ||||
| By default, evaluation will be slow since it is done for one target point at a time, which ensures robustness and fairness, as described in the paper. | ||||
|  | ||||
| ## Training | ||||
| To train the CoTracker as described in our paper, you first need to generate annotations for [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset.  Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet). | ||||
|  | ||||
| Once you have the annotated dataset, you need to make sure you followed the steps for evaluation setup and install the training dependencies: | ||||
| ``` | ||||
| pip install pytorch_lightning==1.6.0 | ||||
| pip install pytorch_lightning==1.6.0 tensorboard | ||||
| ``` | ||||
|  launch training on Kubric. Our model was trained using 32 GPUs, and you can adjust the parameters to best suit your hardware setup. | ||||
| Now you can launch training on Kubric. Our model was trained for 50000 iterations on 32 GPUs (4 nodes with 8 GPUs). | ||||
| Modify *dataset_root* and *ckpt_path* accordingly before running this command: | ||||
| ``` | ||||
| python train.py --batch_size 1 --num_workers 28 \ | ||||
| --num_steps 50000 --ckpt_path ./ --model_name cotracker \ | ||||
| --num_steps 50000 --ckpt_path ./ --dataset_root ./datasets --model_name cotracker \ | ||||
| --save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \ | ||||
| --traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \ | ||||
| --save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4 | ||||
| @@ -86,13 +102,16 @@ python train.py --batch_size 1 --num_workers 28 \ | ||||
| ## License | ||||
| The majority of CoTracker is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, TAP-Vid is licensed under the Apache 2.0 license. | ||||
|  | ||||
| ## Acknowledgments | ||||
| We would like to thank [PIPs](https://github.com/aharley/pips) and [TAP-Vid](https://github.com/deepmind/tapnet) for publicly releasing their code and data. We also want to thank [Luke Melas-Kyriazi](https://lukemelas.github.io/) for proofreading the paper, [Jianyuan Wang](https://jytime.github.io/), [Roman Shapovalov](https://shapovalov.ro/) and [Adam W. Harley](https://adamharley.com/) for the insightful discussions. | ||||
|  | ||||
| ## Citing CoTracker | ||||
| If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work: | ||||
| ``` | ||||
| @article{karaev2023cotracker, | ||||
|   title={CoTracker: It is Better to Track Together}, | ||||
|   author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht}, | ||||
|   journal={arxiv}, | ||||
|   journal={arXiv:2307.07635}, | ||||
|   year={2023} | ||||
| } | ||||
| ``` | ||||
| @@ -185,7 +185,11 @@ class Evaluator: | ||||
|                 if not all(gotit): | ||||
|                     print("batch is None") | ||||
|                     continue | ||||
|             dataclass_to_cuda_(sample) | ||||
|             if torch.cuda.is_available(): | ||||
|                 dataclass_to_cuda_(sample) | ||||
|                 device = torch.device("cuda") | ||||
|             else: | ||||
|                 device = torch.device("cpu") | ||||
|  | ||||
|             if ( | ||||
|                 not train_mode | ||||
| @@ -205,7 +209,7 @@ class Evaluator: | ||||
|                         queries[:, :, 1], | ||||
|                     ], | ||||
|                     dim=2, | ||||
|                 ) | ||||
|                 ).to(device) | ||||
|             else: | ||||
|                 queries = torch.cat( | ||||
|                     [ | ||||
| @@ -213,7 +217,7 @@ class Evaluator: | ||||
|                         sample.trajectory[:, 0], | ||||
|                     ], | ||||
|                     dim=2, | ||||
|                 ) | ||||
|                 ).to(device) | ||||
|  | ||||
|             pred_tracks = model(sample.video, queries) | ||||
|             if "strided" in dataset_name: | ||||
|   | ||||
| @@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig): | ||||
|         single_point=cfg.single_point, | ||||
|         n_iters=cfg.n_iters, | ||||
|     ) | ||||
|     if torch.cuda.is_available(): | ||||
|         predictor.model = predictor.model.cuda() | ||||
|  | ||||
|     # Setting the random seeds | ||||
|     torch.manual_seed(cfg.seed) | ||||
|   | ||||
| @@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker | ||||
| def build_cotracker( | ||||
|     checkpoint: str, | ||||
| ): | ||||
|     if checkpoint is None: | ||||
|         return build_cotracker_stride_4_wind_8() | ||||
|     model_name = checkpoint.split("/")[-1].split(".")[0] | ||||
|     if model_name == "cotracker_stride_4_wind_8": | ||||
|         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint) | ||||
|   | ||||
| @@ -25,11 +25,11 @@ from cotracker.models.core.embeddings import ( | ||||
| torch.manual_seed(0) | ||||
|  | ||||
|  | ||||
| def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'): | ||||
| def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"): | ||||
|     if grid_size == 1: | ||||
|         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[ | ||||
|         return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[ | ||||
|             None, None | ||||
|         ].to(device) | ||||
|         ] | ||||
|  | ||||
|     grid_y, grid_x = meshgrid2d( | ||||
|         1, grid_size, grid_size, stack=False, norm=False, device=device | ||||
|   | ||||
| @@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module): | ||||
|         self.n_iters = n_iters | ||||
|  | ||||
|         self.model = cotracker_model | ||||
|         self.model.to("cuda") | ||||
|         self.model.eval() | ||||
|  | ||||
|     def forward(self, video, queries): | ||||
|         queries = queries.clone().cuda() | ||||
|         queries = queries.clone() | ||||
|         B, T, C, H, W = video.shape | ||||
|         B, N, D = queries.shape | ||||
|  | ||||
| @@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module): | ||||
|  | ||||
|         rgbs = video.reshape(B * T, C, H, W) | ||||
|         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear") | ||||
|         rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda() | ||||
|         rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) | ||||
|  | ||||
|         device = rgbs.device | ||||
|  | ||||
|         queries[:, :, 1] *= self.interp_shape[1] / W | ||||
|         queries[:, :, 2] *= self.interp_shape[0] / H | ||||
|  | ||||
|         if self.single_point: | ||||
|             traj_e = torch.zeros((B, T, N, 2)).cuda() | ||||
|             vis_e = torch.zeros((B, T, N)).cuda() | ||||
|             traj_e = torch.zeros((B, T, N, 2), device=device) | ||||
|             vis_e = torch.zeros((B, T, N), device=device) | ||||
|             for pind in range((N)): | ||||
|                 query = queries[:, pind : pind + 1] | ||||
|  | ||||
| @@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module): | ||||
|                 vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1] | ||||
|         else: | ||||
|             if self.grid_size > 0: | ||||
|                 xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) | ||||
|                 xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  # | ||||
|                 xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device) | ||||
|                 xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to( | ||||
|                     device | ||||
|                 )  # | ||||
|                 queries = torch.cat([queries, xy], dim=1)  # | ||||
|  | ||||
|             traj_e, __, vis_e, __ = self.model( | ||||
| @@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module): | ||||
|             query = torch.cat([query, xy_target], dim=1).to(device)  # | ||||
|  | ||||
|         if self.grid_size > 0: | ||||
|             xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:]) | ||||
|             xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  # | ||||
|             xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device) | ||||
|             xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  # | ||||
|             query = torch.cat([query, xy], dim=1).to(device)  # | ||||
|         # crop the video to start from the queried frame | ||||
|         query[0, 0, 0] = 0 | ||||
|   | ||||
| @@ -25,8 +25,6 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|         model = build_cotracker(checkpoint) | ||||
|  | ||||
|         self.model = model | ||||
|         self.device = device or 'cuda' | ||||
|         self.model.to(self.device) | ||||
|         self.model.eval() | ||||
|  | ||||
|     @torch.no_grad() | ||||
| @@ -73,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|         grid_width = W // grid_step | ||||
|         grid_height = H // grid_step | ||||
|         tracks = visibilities = None | ||||
|         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device) | ||||
|         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device) | ||||
|         grid_pts[0, :, 0] = grid_query_frame | ||||
|         for offset in tqdm(range(grid_step * grid_step)): | ||||
|             ox = offset % grid_step | ||||
| @@ -108,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|         assert B == 1 | ||||
|  | ||||
|         video = video.reshape(B * T, C, H, W) | ||||
|         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device) | ||||
|         video = video.reshape( | ||||
|             B, T, 3, self.interp_shape[0], self.interp_shape[1] | ||||
|         ).to(self.device) | ||||
|         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear") | ||||
|         video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) | ||||
|  | ||||
|         if queries is not None: | ||||
|             queries = queries.clone() | ||||
| @@ -120,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|             queries[:, :, 1] *= self.interp_shape[1] / W | ||||
|             queries[:, :, 2] *= self.interp_shape[0] / H | ||||
|         elif grid_size > 0: | ||||
|             grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device) | ||||
|             grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device) | ||||
|             if segm_mask is not None: | ||||
|                 segm_mask = F.interpolate( | ||||
|                     segm_mask, tuple(self.interp_shape), mode="nearest" | ||||
|   | ||||
| @@ -14,7 +14,6 @@ from matplotlib import cm | ||||
| import torch.nn.functional as F | ||||
| import torchvision.transforms as transforms | ||||
| from moviepy.editor import ImageSequenceClip | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
|  | ||||
| @@ -67,7 +66,7 @@ class Visualizer: | ||||
|         gt_tracks: torch.Tensor = None,  # (B,T,N,2) | ||||
|         segm_mask: torch.Tensor = None,  # (B,1,H,W) | ||||
|         filename: str = "video", | ||||
|         writer: SummaryWriter = None, | ||||
|         writer=None,  # tensorboard Summary Writer, used for visualization during training | ||||
|         step: int = 0, | ||||
|         query_frame: int = 0, | ||||
|         save_video: bool = True, | ||||
|   | ||||
							
								
								
									
demo.py: 12 changed lines
							| @@ -32,11 +32,6 @@ if __name__ == "__main__": | ||||
|         default="./checkpoints/cotracker_stride_4_wind_8.pth", | ||||
|         help="cotracker model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--device", | ||||
|         default="cuda", | ||||
|         help="Device to use for inference", | ||||
|     ) | ||||
|     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size") | ||||
|     parser.add_argument( | ||||
|         "--grid_query_frame", | ||||
| @@ -59,7 +54,12 @@ if __name__ == "__main__": | ||||
|     segm_mask = np.array(Image.open(os.path.join(args.mask_path))) | ||||
|     segm_mask = torch.from_numpy(segm_mask)[None, None] | ||||
|  | ||||
|     model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device) | ||||
|     model = CoTrackerPredictor(checkpoint=args.checkpoint) | ||||
|     if torch.cuda.is_available(): | ||||
|         model = model.cuda() | ||||
|         video = video.cuda() | ||||
|     else: | ||||
|         print("CUDA is not available!") | ||||
|  | ||||
|     pred_tracks, pred_visibility = model( | ||||
|         video, | ||||
|   | ||||
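The demo.py change above drops the `--device` argument in favor of a runtime CUDA check. As an illustration of the same pattern in a device-agnostic form (a sketch only, not code from this commit):
```
import torch

# Pick the device once from availability, then move both the model and its
# inputs there; the stand-ins below take the place of CoTrackerPredictor and
# the demo video.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Conv2d(3, 8, kernel_size=3).to(device)
frames = torch.randn(1, 3, 64, 64, device=device)
print(model(frames).device)
```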
							
								
								
									
hubconf.py: 32 changed lines (new file)
							| @@ -0,0 +1,32 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
|  | ||||
| dependencies = ["torch", "einops", "timm", "tqdm"] | ||||
|  | ||||
| _COTRACKER_URL = ( | ||||
|     "https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth" | ||||
| ) | ||||
|  | ||||
|  | ||||
| def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs): | ||||
|     from cotracker.predictor import CoTrackerPredictor | ||||
|  | ||||
|     predictor = CoTrackerPredictor(checkpoint=None) | ||||
|     if pretrained: | ||||
|         state_dict = torch.hub.load_state_dict_from_url( | ||||
|             _COTRACKER_URL, map_location="cpu" | ||||
|         ) | ||||
|         predictor.model.load_state_dict(state_dict) | ||||
|     return predictor | ||||
|  | ||||
|  | ||||
| def cotracker_w8(*, pretrained: bool = True, **kwargs): | ||||
|     """ | ||||
|     CoTracker model with stride 4 and window length 8. (The main model from the paper) | ||||
|     """ | ||||
|     return _make_cotracker_predictor(pretrained=pretrained, **kwargs) | ||||
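Once hubconf.py is merged, the new entrypoint can be discovered with the standard torch.hub helpers, for example (standard PyTorch API, shown here for context):
```
import torch

# List the entrypoints exposed by hubconf.py and show the docstring of one;
# "cotracker_w8" is expected to appear in the list.
print(torch.hub.list("facebookresearch/co-tracker"))
print(torch.hub.help("facebookresearch/co-tracker", "cotracker_w8"))
```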
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
train.py: 111 changed lines
							| @@ -36,21 +36,6 @@ from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_ | ||||
| from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss | ||||
|  | ||||
|  | ||||
| # define the handler function | ||||
| # for training on a slurm cluster | ||||
| def sig_handler(signum, frame): | ||||
|     print("caught signal", signum) | ||||
|     print(socket.gethostname(), "USR1 signal caught.") | ||||
|     # do other stuff to cleanup here | ||||
|     print("requeuing job " + os.environ["SLURM_JOB_ID"]) | ||||
|     os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"]) | ||||
|     sys.exit(-1) | ||||
|  | ||||
|  | ||||
| def term_handler(signum, frame): | ||||
|     print("bypassing sigterm", flush=True) | ||||
|  | ||||
|  | ||||
| def fetch_optimizer(args, model): | ||||
|     """Create the optimizer and learning rate scheduler""" | ||||
|     optimizer = optim.AdamW( | ||||
| @@ -153,6 +138,8 @@ def run_test_eval(evaluator, model, dataloaders, writer, step): | ||||
|             single_point=False, | ||||
|             n_iters=6, | ||||
|         ) | ||||
|         if torch.cuda.is_available(): | ||||
|             predictor.model = predictor.model.cuda() | ||||
|  | ||||
|         metrics = evaluator.evaluate_sequence( | ||||
|             model=predictor, | ||||
| @@ -302,9 +289,7 @@ class Lite(LightningLite): | ||||
|             eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture)) | ||||
|  | ||||
|         if "tapvid_davis_first" in args.eval_datasets: | ||||
|             data_root = os.path.join( | ||||
|                 args.dataset_root, "/tapvid_davis/tapvid_davis.pkl" | ||||
|             ) | ||||
|             data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl") | ||||
|             eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root) | ||||
|             eval_dataloader_tapvid_davis = torch.utils.data.DataLoader( | ||||
|                 eval_dataset, | ||||
| @@ -551,17 +536,15 @@ class Lite(LightningLite): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     signal.signal(signal.SIGUSR1, sig_handler) | ||||
|     signal.signal(signal.SIGTERM, term_handler) | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument("--model_name", default="cotracker", help="model name") | ||||
|     parser.add_argument("--restore_ckpt", help="restore checkpoint") | ||||
|     parser.add_argument("--ckpt_path", help="restore checkpoint") | ||||
|     parser.add_argument("--restore_ckpt", help="path to restore a checkpoint") | ||||
|     parser.add_argument("--ckpt_path", help="path to save checkpoints") | ||||
|     parser.add_argument( | ||||
|         "--batch_size", type=int, default=4, help="batch size used during training." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num_workers", type=int, default=6, help="left right consistency loss" | ||||
|         "--num_workers", type=int, default=6, help="number of dataloader workers" | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
| @@ -578,20 +561,34 @@ if __name__ == "__main__": | ||||
|         "--evaluate_every_n_epoch", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         help="number of flow-field updates during validation forward pass", | ||||
|         help="evaluate during training after every n epochs, after every epoch by default", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_every_n_epoch", | ||||
|         type=int, | ||||
|         default=1, | ||||
|         help="number of flow-field updates during validation forward pass", | ||||
|         help="save checkpoints during training after every n epochs, after every epoch by default", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--validate_at_start", action="store_true", help="use mixed precision" | ||||
|         "--validate_at_start", | ||||
|         action="store_true", | ||||
|         help="whether to run evaluation before training starts", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--save_freq", | ||||
|         type=int, | ||||
|         default=100, | ||||
|         help="frequency of trajectory visualization during training", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--traj_per_sample", | ||||
|         type=int, | ||||
|         default=768, | ||||
|         help="the number of trajectories to sample for training", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--dataset_root", type=str, help="path to all the datasets (train and eval)" | ||||
|     ) | ||||
|     parser.add_argument("--save_freq", type=int, default=100, help="save_freq") | ||||
|     parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq") | ||||
|     parser.add_argument("--dataset_root", type=str, help="path lo all the datasets") | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--train_iters", | ||||
| @@ -605,49 +602,75 @@ if __name__ == "__main__": | ||||
|     parser.add_argument( | ||||
|         "--eval_datasets", | ||||
|         nargs="+", | ||||
|         default=["things", "badja", "fastcapture"], | ||||
|         help="eval datasets.", | ||||
|         default=["things", "badja"], | ||||
|         help="what datasets to use for evaluation", | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--remove_space_attn", action="store_true", help="use mixed precision" | ||||
|         "--remove_space_attn", | ||||
|         action="store_true", | ||||
|         help="remove space attention from CoTracker", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--dont_use_augs", action="store_true", help="use mixed precision" | ||||
|         "--dont_use_augs", | ||||
|         action="store_true", | ||||
|         help="don't apply augmentations during training", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--sample_vis_1st_frame", action="store_true", help="use mixed precision" | ||||
|         "--sample_vis_1st_frame", | ||||
|         action="store_true", | ||||
|         help="only sample trajectories with points visible on the first frame", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--sliding_window_len", type=int, default=8, help="use mixed precision" | ||||
|         "--sliding_window_len", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="length of the CoTracker sliding window", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--updateformer_hidden_size", type=int, default=384, help="use mixed precision" | ||||
|         "--updateformer_hidden_size", | ||||
|         type=int, | ||||
|         default=384, | ||||
|         help="hidden dimension of the CoTracker transformer model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--updateformer_num_heads", type=int, default=8, help="use mixed precision" | ||||
|         "--updateformer_num_heads", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="number of heads of the CoTracker transformer model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--updateformer_space_depth", type=int, default=12, help="use mixed precision" | ||||
|         "--updateformer_space_depth", | ||||
|         type=int, | ||||
|         default=12, | ||||
|         help="number of group attention layers in the CoTracker transformer model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--updateformer_time_depth", type=int, default=12, help="use mixed precision" | ||||
|         "--updateformer_time_depth", | ||||
|         type=int, | ||||
|         default=12, | ||||
|         help="number of time attention layers in the CoTracker transformer model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--model_stride", type=int, default=8, help="use mixed precision" | ||||
|         "--model_stride", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         help="stride of the CoTracker feature network", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--crop_size", | ||||
|         type=int, | ||||
|         nargs="+", | ||||
|         default=[384, 512], | ||||
|         help="use mixed precision", | ||||
|         help="crop videos to this resolution during training", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--eval_max_seq_len", type=int, default=1000, help="use mixed precision" | ||||
|         "--eval_max_seq_len", | ||||
|         type=int, | ||||
|         default=1000, | ||||
|         help="maximum length of evaluation videos", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     logging.basicConfig( | ||||
|         level=logging.INFO, | ||||
|         format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", | ||||
| @@ -661,5 +684,5 @@ if __name__ == "__main__": | ||||
|         devices="auto", | ||||
|         accelerator="gpu", | ||||
|         precision=32, | ||||
|         num_nodes=4, | ||||
|         # num_nodes=4, | ||||
|     ).run(args) | ||||
|   | ||||