add time measure codes and update resolution
This commit is contained in:
		
							
								
								
									
										45
									
								
								demo1.py
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								demo1.py
									
									
									
									
									
								
							| @@ -3,14 +3,25 @@ import torch | |||||||
|  |  | ||||||
| from base64 import b64encode | from base64 import b64encode | ||||||
| from cotracker.utils.visualizer import Visualizer, read_video_from_path | from cotracker.utils.visualizer import Visualizer, read_video_from_path | ||||||
|  | import numpy as np | ||||||
|  | from PIL import Image | ||||||
|  | import time | ||||||
|  |  | ||||||
|  | device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu') | ||||||
|  |  | ||||||
|  | start_time = time.time() | ||||||
|  | print(f'Using device: {device}') | ||||||
|  | print(f'start loading video') | ||||||
| video = read_video_from_path('./assets/F1_shorts.mp4') | video = read_video_from_path('./assets/F1_shorts.mp4') | ||||||
|  | print(f'video shape: {video.shape}') | ||||||
|  | # video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float().to(device) | ||||||
| video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float() | video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float() | ||||||
|  | end_time = time.time() | ||||||
|  | print(f'video shape after permute: {video.shape}') | ||||||
|  | print("Load video Time taken: {:.2f} seconds".format(end_time - start_time)) | ||||||
|  |  | ||||||
| from cotracker.predictor import CoTrackerPredictor | from cotracker.predictor import CoTrackerPredictor | ||||||
|  |  | ||||||
| device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu') |  | ||||||
|  |  | ||||||
| model = CoTrackerPredictor( | model = CoTrackerPredictor( | ||||||
|     checkpoint=os.path.join( |     checkpoint=os.path.join( | ||||||
| @@ -27,25 +38,45 @@ grid_query_frame=20 | |||||||
|  |  | ||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
| # video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device) | # video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device) | ||||||
|  | interp_size = (720, 1280) | ||||||
|  | video_interp = F.interpolate(video[0], [interp_size[0], interp_size[1]], mode="bilinear")[None].to(device) | ||||||
|  | print(f'video_interp shape: {video_interp.shape}') | ||||||
|  |  | ||||||
| import time |  | ||||||
| start_time = time.time() | start_time = time.time() | ||||||
| # pred_tracks, pred_visibility = model(video_interp,  | # pred_tracks, pred_visibility = model(video_interp,  | ||||||
| pred_tracks, pred_visibility = model(video,  | input_mask='./assets/F1_mask.png' | ||||||
|                                      grid_query_frame=grid_query_frame, backward_tracking=True) | segm_mask = Image.open(input_mask) | ||||||
|  | interp_size = (interp_size[1], interp_size[0]) | ||||||
|  | segm_mask = segm_mask.resize(interp_size, Image.BILINEAR) | ||||||
|  | segm_mask = np.array(Image.open(input_mask)) | ||||||
|  | segm_mask = torch.tensor(segm_mask).to(device) | ||||||
|  | # pred_tracks, pred_visibility = model(video,  | ||||||
|  | pred_tracks, pred_visibility = model(video_interp,  | ||||||
|  |                                      grid_query_frame=grid_query_frame, backward_tracking=True, | ||||||
|  |                                      segm_mask=segm_mask ) | ||||||
| end_time = time.time()  | end_time = time.time()  | ||||||
|  |  | ||||||
| print("Time taken: {:.2f} seconds".format(end_time - start_time)) | print("Time taken: {:.2f} seconds".format(end_time - start_time)) | ||||||
|  |  | ||||||
|  | start_time = time.time() | ||||||
|  | print(f'start visualizing') | ||||||
| vis = Visualizer( | vis = Visualizer( | ||||||
|     save_dir='./videos', |     save_dir='./videos', | ||||||
|     pad_value=20, |     pad_value=20, | ||||||
|     linewidth=1, |     linewidth=1, | ||||||
|     mode='optical_flow' |     mode='optical_flow' | ||||||
| ) | ) | ||||||
|  | print(f'vis initialized') | ||||||
|  | end_time = time.time() | ||||||
|  | print("Time taken: {:.2f} seconds".format(end_time - start_time)) | ||||||
|  | start_time = time.time() | ||||||
|  | print(f'start visualize') | ||||||
| vis.visualize( | vis.visualize( | ||||||
|     # video=video_interp, |     video=video_interp, | ||||||
|     video=video, |     # video=video, | ||||||
|     tracks=pred_tracks, |     tracks=pred_tracks, | ||||||
|     visibility=pred_visibility, |     visibility=pred_visibility, | ||||||
|     filename='dense'); |     filename='dense2'); | ||||||
|  | print(f'done') | ||||||
|  | end_time = time.time() | ||||||
|  | print("Time taken: {:.2f} seconds".format(end_time - start_time)) | ||||||
		Reference in New Issue
	
	Block a user