Allows MPS inference. Fix visualization args
This commit is contained in:
@@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import (
|
||||
|
||||
class CoTrackerPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None
|
||||
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
|
||||
):
|
||||
super().__init__()
|
||||
self.interp_shape = (384, 512)
|
||||
|
Reference in New Issue
Block a user