create tensors on device

This commit is contained in:
magehrig
2021-09-16 16:34:37 +02:00
parent 224320502d
commit e6e53c4e23
2 changed files with 5 additions and 5 deletions

View File

@@ -63,8 +63,8 @@ class RAFT(nn.Module):
def initialize_flow(self, img):
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
N, C, H, W = img.shape
coords0 = coords_grid(N, H//8, W//8).to(img.device)
coords1 = coords_grid(N, H//8, W//8).to(img.device)
coords0 = coords_grid(N, H//8, W//8, device=img.device)
coords1 = coords_grid(N, H//8, W//8, device=img.device)
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1