create tensors on device
This commit is contained in:
		| @@ -34,9 +34,9 @@ class CorrBlock: | ||||
|         out_pyramid = [] | ||||
|         for i in range(self.num_levels): | ||||
|             corr = self.corr_pyramid[i] | ||||
|             dx = torch.linspace(-r, r, 2*r+1) | ||||
|             dy = torch.linspace(-r, r, 2*r+1) | ||||
|             delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) | ||||
|             dx = torch.linspace(-r, r, 2*r+1, device=coords.device) | ||||
|             dy = torch.linspace(-r, r, 2*r+1, device=coords.device) | ||||
|             delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) | ||||
|  | ||||
|             centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i | ||||
|             delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) | ||||
|   | ||||
| @@ -63,8 +63,8 @@ class RAFT(nn.Module): | ||||
|     def initialize_flow(self, img): | ||||
|         """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" | ||||
|         N, C, H, W = img.shape | ||||
|         coords0 = coords_grid(N, H//8, W//8).to(img.device) | ||||
|         coords1 = coords_grid(N, H//8, W//8).to(img.device) | ||||
|         coords0 = coords_grid(N, H//8, W//8, device=img.device) | ||||
|         coords1 = coords_grid(N, H//8, W//8, device=img.device) | ||||
|  | ||||
|         # optical flow computed as difference: flow = coords1 - coords0 | ||||
|         return coords0, coords1 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user