demo fixes
This commit is contained in:
		
										
											Binary file not shown.
										
									
								
							| Before Width: | Height: | Size: 14 KiB After Width: | Height: | Size: 14 KiB | 
| @@ -55,7 +55,9 @@ class CoTrackerPredictor(torch.nn.Module): | |||||||
|  |  | ||||||
|         return tracks, visibilities |         return tracks, visibilities | ||||||
|  |  | ||||||
|     def _compute_dense_tracks(self, video, grid_query_frame, grid_size=30, backward_tracking=False): |     def _compute_dense_tracks( | ||||||
|  |         self, video, grid_query_frame, grid_size=150, backward_tracking=False | ||||||
|  |     ): | ||||||
|         *_, H, W = video.shape |         *_, H, W = video.shape | ||||||
|         grid_step = W // grid_size |         grid_step = W // grid_size | ||||||
|         grid_width = W // grid_step |         grid_width = W // grid_step | ||||||
| @@ -172,8 +174,9 @@ class CoTrackerPredictor(torch.nn.Module): | |||||||
|  |  | ||||||
|         inv_tracks = inv_tracks.flip(1) |         inv_tracks = inv_tracks.flip(1) | ||||||
|         inv_visibilities = inv_visibilities.flip(1) |         inv_visibilities = inv_visibilities.flip(1) | ||||||
|  |         arange = torch.arange(video.shape[1], device=queries.device)[None, :, None] | ||||||
|  |  | ||||||
|         mask = tracks == 0 |         mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2) | ||||||
|  |  | ||||||
|         tracks[mask] = inv_tracks[mask] |         tracks[mask] = inv_tracks[mask] | ||||||
|         visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] |         visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] | ||||||
|   | |||||||
| @@ -226,7 +226,7 @@ class Visualizer: | |||||||
|  |  | ||||||
|         #  draw tracks |         #  draw tracks | ||||||
|         if self.tracks_leave_trace != 0: |         if self.tracks_leave_trace != 0: | ||||||
|             for t in range(1, T): |             for t in range(query_frame + 1, T): | ||||||
|                 first_ind = ( |                 first_ind = ( | ||||||
|                     max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0 |                     max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0 | ||||||
|                 ) |                 ) | ||||||
| @@ -251,7 +251,7 @@ class Visualizer: | |||||||
|                     res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1]) |                     res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1]) | ||||||
|  |  | ||||||
|         #  draw points |         #  draw points | ||||||
|         for t in range(T): |         for t in range(query_frame, T): | ||||||
|             img = Image.fromarray(np.uint8(res_video[t])) |             img = Image.fromarray(np.uint8(res_video[t])) | ||||||
|             for i in range(N): |             for i in range(N): | ||||||
|                 coord = (tracks[t, i, 0], tracks[t, i, 1]) |                 coord = (tracks[t, i, 0], tracks[t, i, 1]) | ||||||
|   | |||||||
							
								
								
									
										9
									
								
								demo.py
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								demo.py
									
									
									
									
									
								
							| @@ -72,7 +72,7 @@ if __name__ == "__main__": | |||||||
|         model = torch.hub.load("facebookresearch/co-tracker", "cotracker2") |         model = torch.hub.load("facebookresearch/co-tracker", "cotracker2") | ||||||
|     model = model.to(DEFAULT_DEVICE) |     model = model.to(DEFAULT_DEVICE) | ||||||
|     video = video.to(DEFAULT_DEVICE) |     video = video.to(DEFAULT_DEVICE) | ||||||
|  |     # video = video[:, :20] | ||||||
|     pred_tracks, pred_visibility = model( |     pred_tracks, pred_visibility = model( | ||||||
|         video, |         video, | ||||||
|         grid_size=args.grid_size, |         grid_size=args.grid_size, | ||||||
| @@ -85,4 +85,9 @@ if __name__ == "__main__": | |||||||
|     # save a video with predicted tracks |     # save a video with predicted tracks | ||||||
|     seq_name = args.video_path.split("/")[-1] |     seq_name = args.video_path.split("/")[-1] | ||||||
|     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) |     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) | ||||||
|     vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame) |     vis.visualize( | ||||||
|  |         video, | ||||||
|  |         pred_tracks, | ||||||
|  |         pred_visibility, | ||||||
|  |         query_frame=0 if args.backward_tracking else args.grid_query_frame, | ||||||
|  |     ) | ||||||
|   | |||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
		Reference in New Issue
	
	Block a user