Allows MPS inference. Fix visualization args

This commit is contained in:
JunkyByte
2023-07-25 16:28:48 +02:00
parent 03f3c41e07
commit d6df5d248f
2 changed files with 7 additions and 7 deletions

View File

@@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import (
class CoTrackerPredictor(torch.nn.Module):
def __init__(
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
):
super().__init__()
self.interp_shape = (384, 512)