add torch.hub support
This commit is contained in:
		| @@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker | ||||
| def build_cotracker( | ||||
|     checkpoint: str, | ||||
| ): | ||||
|     if checkpoint is None: | ||||
|         return build_cotracker_stride_4_wind_8() | ||||
|     model_name = checkpoint.split("/")[-1].split(".")[0] | ||||
|     if model_name == "cotracker_stride_4_wind_8": | ||||
|         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint) | ||||
|   | ||||
| @@ -25,7 +25,6 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|         model = build_cotracker(checkpoint) | ||||
|  | ||||
|         self.model = model | ||||
|         self.model.to("cuda") | ||||
|         self.model.eval() | ||||
|  | ||||
|     @torch.no_grad() | ||||
| @@ -72,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|         grid_width = W // grid_step | ||||
|         grid_height = H // grid_step | ||||
|         tracks = visibilities = None | ||||
|         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda") | ||||
|         grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device) | ||||
|         grid_pts[0, :, 0] = grid_query_frame | ||||
|         for offset in tqdm(range(grid_step * grid_step)): | ||||
|             ox = offset % grid_step | ||||
| @@ -107,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|         assert B == 1 | ||||
|  | ||||
|         video = video.reshape(B * T, C, H, W) | ||||
|         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda() | ||||
|         video = video.reshape( | ||||
|             B, T, 3, self.interp_shape[0], self.interp_shape[1] | ||||
|         ).cuda() | ||||
|         video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear") | ||||
|         video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) | ||||
|  | ||||
|         if queries is not None: | ||||
|             queries = queries.clone() | ||||
|   | ||||
							
								
								
									
										5
									
								
								demo.py
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								demo.py
									
									
									
									
									
								
							| @@ -55,6 +55,11 @@ if __name__ == "__main__": | ||||
|     segm_mask = torch.from_numpy(segm_mask)[None, None] | ||||
|  | ||||
|     model = CoTrackerPredictor(checkpoint=args.checkpoint) | ||||
|     if torch.cuda.is_available(): | ||||
|         model = model.cuda() | ||||
|         video = video.cuda() | ||||
|     else: | ||||
|         print("CUDA is not available!") | ||||
|  | ||||
|     pred_tracks, pred_visibility = model( | ||||
|         video, | ||||
|   | ||||
							
								
								
									
										32
									
								
								hubconf.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								hubconf.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
|  | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
| import torch | ||||
|  | ||||
| dependencies = ["torch", "einops", "timm", "tqdm"] | ||||
|  | ||||
| _COTRACKER_URL = ( | ||||
|     "https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth" | ||||
| ) | ||||
|  | ||||
|  | ||||
| def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs): | ||||
|     from cotracker.predictor import CoTrackerPredictor | ||||
|  | ||||
|     predictor = CoTrackerPredictor(checkpoint=None) | ||||
|     if pretrained: | ||||
|         state_dict = torch.hub.load_state_dict_from_url( | ||||
|             _COTRACKER_URL, map_location="cpu" | ||||
|         ) | ||||
|         predictor.model.load_state_dict(state_dict) | ||||
|     return predictor | ||||
|  | ||||
|  | ||||
| def cotracker_w8(*, pretrained: bool = True, **kwargs): | ||||
|     """ | ||||
|     CoTracker model with stride 4 and window length 8. (The main model from the paper) | ||||
|     """ | ||||
|     return _make_cotracker_predictor(pretrained=pretrained, **kwargs) | ||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
		Reference in New Issue
	
	Block a user