Allows MPS inference. Fix visualization args
This commit is contained in:
parent
03f3c41e07
commit
d6df5d248f
@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import (
|
|||||||
|
|
||||||
class CoTrackerPredictor(torch.nn.Module):
|
class CoTrackerPredictor(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None
|
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.interp_shape = (384, 512)
|
self.interp_shape = (384, 512)
|
||||||
|
12
demo.py
12
demo.py
@ -14,6 +14,9 @@ from PIL import Image
|
|||||||
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
||||||
from cotracker.predictor import CoTrackerPredictor
|
from cotracker.predictor import CoTrackerPredictor
|
||||||
|
|
||||||
|
DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else
|
||||||
|
'mps' if torch.backends.mps.is_available() else
|
||||||
|
'cpu')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -55,11 +58,8 @@ if __name__ == "__main__":
|
|||||||
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
segm_mask = torch.from_numpy(segm_mask)[None, None]
|
||||||
|
|
||||||
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
model = CoTrackerPredictor(checkpoint=args.checkpoint)
|
||||||
if torch.cuda.is_available():
|
model = model.to(DEFAULT_DEVICE)
|
||||||
model = model.cuda()
|
video = video.to(DEFAULT_DEVICE)
|
||||||
video = video.cuda()
|
|
||||||
else:
|
|
||||||
print("CUDA is not available!")
|
|
||||||
|
|
||||||
pred_tracks, pred_visibility = model(
|
pred_tracks, pred_visibility = model(
|
||||||
video,
|
video,
|
||||||
@ -73,4 +73,4 @@ if __name__ == "__main__":
|
|||||||
# save a video with predicted tracks
|
# save a video with predicted tracks
|
||||||
seq_name = args.video_path.split("/")[-1]
|
seq_name = args.video_path.split("/")[-1]
|
||||||
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
|
||||||
vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame)
|
vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
|
||||||
|
Loading…
Reference in New Issue
Block a user