fix online demo
This commit is contained in:
		| @@ -52,25 +52,33 @@ if __name__ == "__main__": | |||||||
|  |  | ||||||
|     window_frames = [] |     window_frames = [] | ||||||
|  |  | ||||||
|     def _process_step(window_frames, is_first_step, grid_size): |     def _process_step(window_frames, is_first_step, grid_size, grid_query_frame): | ||||||
|         video_chunk = ( |         video_chunk = ( | ||||||
|             torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE) |             torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE) | ||||||
|             .float() |             .float() | ||||||
|             .permute(0, 3, 1, 2)[None] |             .permute(0, 3, 1, 2)[None] | ||||||
|         )  # (1, T, 3, H, W) |         )  # (1, T, 3, H, W) | ||||||
|         return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size) |         return model( | ||||||
|  |             video_chunk, | ||||||
|  |             is_first_step=is_first_step, | ||||||
|  |             grid_size=grid_size, | ||||||
|  |             grid_query_frame=grid_query_frame, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     # Iterating over video frames, processing one window at a time: |     # Iterating over video frames, processing one window at a time: | ||||||
|     is_first_step = True |     is_first_step = True | ||||||
|     for i, frame in enumerate( |     for i, frame in enumerate( | ||||||
|         iio.imiter( |         iio.imiter( | ||||||
|             "https://github.com/facebookresearch/co-tracker/blob/main/assets/apple.mp4", |             "./assets/apple.mp4", | ||||||
|             plugin="FFMPEG", |             plugin="FFMPEG", | ||||||
|         ) |         ) | ||||||
|     ): |     ): | ||||||
|         if i % model.step == 0 and i != 0: |         if i % model.step == 0 and i != 0: | ||||||
|             pred_tracks, pred_visibility = _process_step( |             pred_tracks, pred_visibility = _process_step( | ||||||
|                 window_frames, is_first_step, grid_size=args.grid_size |                 window_frames, | ||||||
|  |                 is_first_step, | ||||||
|  |                 grid_size=args.grid_size, | ||||||
|  |                 grid_query_frame=args.grid_query_frame, | ||||||
|             ) |             ) | ||||||
|             is_first_step = False |             is_first_step = False | ||||||
|         window_frames.append(frame) |         window_frames.append(frame) | ||||||
| @@ -79,6 +87,7 @@ if __name__ == "__main__": | |||||||
|         window_frames[-(i % model.step) - model.step - 1 :], |         window_frames[-(i % model.step) - model.step - 1 :], | ||||||
|         is_first_step, |         is_first_step, | ||||||
|         grid_size=args.grid_size, |         grid_size=args.grid_size, | ||||||
|  |         grid_query_frame=args.grid_query_frame, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     print("Tracks are computed") |     print("Tracks are computed") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user