create tensors on device
This commit is contained in:
		| @@ -34,9 +34,9 @@ class CorrBlock: | |||||||
|         out_pyramid = [] |         out_pyramid = [] | ||||||
|         for i in range(self.num_levels): |         for i in range(self.num_levels): | ||||||
|             corr = self.corr_pyramid[i] |             corr = self.corr_pyramid[i] | ||||||
|             dx = torch.linspace(-r, r, 2*r+1) |             dx = torch.linspace(-r, r, 2*r+1, device=coords.device) | ||||||
|             dy = torch.linspace(-r, r, 2*r+1) |             dy = torch.linspace(-r, r, 2*r+1, device=coords.device) | ||||||
|             delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) |             delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) | ||||||
|  |  | ||||||
|             centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i |             centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i | ||||||
|             delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) |             delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) | ||||||
|   | |||||||
| @@ -63,8 +63,8 @@ class RAFT(nn.Module): | |||||||
|     def initialize_flow(self, img): |     def initialize_flow(self, img): | ||||||
|         """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" |         """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" | ||||||
|         N, C, H, W = img.shape |         N, C, H, W = img.shape | ||||||
|         coords0 = coords_grid(N, H//8, W//8).to(img.device) |         coords0 = coords_grid(N, H//8, W//8, device=img.device) | ||||||
|         coords1 = coords_grid(N, H//8, W//8).to(img.device) |         coords1 = coords_grid(N, H//8, W//8, device=img.device) | ||||||
|  |  | ||||||
|         # optical flow computed as difference: flow = coords1 - coords0 |         # optical flow computed as difference: flow = coords1 - coords0 | ||||||
|         return coords0, coords1 |         return coords0, coords1 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user