Merge pull request #14 from JunkyByte/main
minor fixes / mps default device when available / occlusion visualization
@@ -133,7 +133,7 @@ class CoTrackerPredictor(torch.nn.Module):
             )
 
         if add_support_grid:
-            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
+            grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device)
             grid_pts = torch.cat(
                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
             )
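Note: the hunk above passes the input video's device through to get_points_on_a_grid, so the support grid is allocated on the same device (CPU, CUDA, or MPS) as the video and the following torch.cat cannot hit a device mismatch. Below is a minimal standalone sketch of the same pattern; grid_points is a hypothetical helper for illustration, not the library's function.

import torch

def grid_points(size: int, shape, device="cpu"):
    # Build a regular size x size grid of (x, y) points directly on `device`,
    # mirroring the idea behind passing device=video.device in the diff above.
    h, w = shape
    ys = torch.linspace(0, h - 1, size, device=device)
    xs = torch.linspace(0, w - 1, size, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack([grid_x, grid_y], dim=-1).reshape(1, size * size, 2)

video = torch.zeros(1, 8, 3, 384, 512)                       # dummy (B, T, C, H, W) clip
pts = grid_points(10, video.shape[3:], device=video.device)   # grid lives where the video lives
print(pts.shape)  # torch.Size([1, 100, 2])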
@@ -62,6 +62,7 @@ class Visualizer:
         self,
         video: torch.Tensor,  # (B,T,C,H,W)
         tracks: torch.Tensor,  # (B,T,N,2)
+        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
         gt_tracks: torch.Tensor = None,  # (B,T,N,2)
         segm_mask: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
@@ -93,6 +94,7 @@ class Visualizer:
         res_video = self.draw_tracks_on_video(
             video=video,
             tracks=tracks,
+            visibility=visibility,
             segm_mask=segm_mask,
             gt_tracks=gt_tracks,
             query_frame=query_frame,
@@ -126,6 +128,7 @@ class Visualizer:
         self,
         video: torch.Tensor,
         tracks: torch.Tensor,
+        visibility: torch.Tensor = None,
         segm_mask: torch.Tensor = None,
         gt_tracks=None,
         query_frame: int = 0,
@@ -227,11 +230,13 @@ class Visualizer:
                     if not compensate_for_camera_motion or (
                         compensate_for_camera_motion and segm_mask[i] > 0
                     ):
+
                         cv2.circle(
                             res_video[t],
                             coord,
                             int(self.linewidth * 2),
                             vector_colors[t, i].tolist(),
+                            thickness=-1 if visibility[0, t, i] else 2
                             -1,
                         )
 
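Note: with the new visibility argument threaded through visualize and draw_tracks_on_video, each track point can be rendered according to whether the model predicts it as visible or occluded in a frame: cv2.circle draws a filled disc for a negative thickness and a ring for a positive one. A small sketch of that idea; the names (draw_point, frame, visible) are illustrative, not the library's API.

import cv2
import numpy as np

def draw_point(frame, center, visible, color, radius=6):
    # Filled disc for visible points, thin ring for occluded ones.
    thickness = -1 if visible else 2
    cv2.circle(frame, (int(center[0]), int(center[1])), radius, color, thickness)
    return frame

frame = np.zeros((240, 320, 3), dtype=np.uint8)
draw_point(frame, (100, 120), visible=True, color=(0, 255, 0))   # filled: point is visible
draw_point(frame, (200, 120), visible=False, color=(0, 0, 255))  # ring: point is occluded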
demo.py (12 changed lines)
@@ -14,6 +14,9 @@ from PIL import Image
 from cotracker.utils.visualizer import Visualizer, read_video_from_path
 from cotracker.predictor import CoTrackerPredictor
 
+DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
+                  'mps' if torch.backends.mps.is_available() else
+                  'cpu')
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -55,11 +58,8 @@ if __name__ == "__main__":
     segm_mask = torch.from_numpy(segm_mask)[None, None]
 
     model = CoTrackerPredictor(checkpoint=args.checkpoint)
-    if torch.cuda.is_available():
-        model = model.cuda()
-        video = video.cuda()
-    else:
-        print("CUDA is not available!")
+    model = model.to(DEFAULT_DEVICE)
+    video = video.to(DEFAULT_DEVICE)
 
     pred_tracks, pred_visibility = model(
         video,
@@ -73,4 +73,4 @@ if __name__ == "__main__":
     # save a video with predicted tracks
     seq_name = args.video_path.split("/")[-1]
     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
-    vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame)
+    vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
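Note: demo.py now picks the best available backend once, moves both the model and the video tensor to it, and passes the predicted visibility into the visualizer. A minimal standalone sketch of the same device-fallback pattern, assuming a PyTorch build recent enough to expose torch.backends.mps; the tensor below is a dummy, not the demo's input.

import torch

DEFAULT_DEVICE = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

video = torch.zeros(1, 8, 3, 64, 64)   # dummy (B, T, C, H, W) clip
video = video.to(DEFAULT_DEVICE)       # tensors need an explicit .to(); nn.Module.to() moves parameters in place
print(DEFAULT_DEVICE, video.device)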