add cpu-only mode
This commit is contained in:
@@ -185,7 +185,11 @@ class Evaluator:
|
||||
if not all(gotit):
|
||||
print("batch is None")
|
||||
continue
|
||||
dataclass_to_cuda_(sample)
|
||||
if torch.cuda.is_available():
|
||||
dataclass_to_cuda_(sample)
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
if (
|
||||
not train_mode
|
||||
@@ -205,7 +209,7 @@ class Evaluator:
|
||||
queries[:, :, 1],
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
).to(device)
|
||||
else:
|
||||
queries = torch.cat(
|
||||
[
|
||||
@@ -213,7 +217,7 @@ class Evaluator:
|
||||
sample.trajectory[:, 0],
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
).to(device)
|
||||
|
||||
pred_tracks = model(sample.video, queries)
|
||||
if "strided" in dataset_name:
|
||||
|
@@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig):
|
||||
single_point=cfg.single_point,
|
||||
n_iters=cfg.n_iters,
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
predictor.model = predictor.model.cuda()
|
||||
|
||||
# Setting the random seeds
|
||||
torch.manual_seed(cfg.seed)
|
||||
|
Reference in New Issue
Block a user