added cuda extension for efficent implementation
This commit is contained in:
57
core/corr.py
57
core/corr.py
@@ -2,6 +2,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from utils.utils import bilinear_sampler, coords_grid
|
||||
|
||||
try:
|
||||
import alt_cuda_corr
|
||||
except:
|
||||
# alt_cuda_corr is not compiled
|
||||
pass
|
||||
|
||||
|
||||
class CorrBlock:
|
||||
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
||||
@@ -43,7 +49,6 @@ class CorrBlock:
|
||||
out = torch.cat(out_pyramid, dim=-1)
|
||||
return out.permute(0, 3, 1, 2).contiguous().float()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def corr(fmap1, fmap2):
|
||||
batch, dim, ht, wd = fmap1.shape
|
||||
@@ -54,3 +59,53 @@ class CorrBlock:
|
||||
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
||||
return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
||||
|
||||
class CorrLayer(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, fmap1, fmap2, coords, r):
|
||||
fmap1 = fmap1.contiguous()
|
||||
fmap2 = fmap2.contiguous()
|
||||
coords = coords.contiguous()
|
||||
ctx.save_for_backward(fmap1, fmap2, coords)
|
||||
ctx.r = r
|
||||
corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
|
||||
return corr
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_corr):
|
||||
fmap1, fmap2, coords = ctx.saved_tensors
|
||||
grad_corr = grad_corr.contiguous()
|
||||
fmap1_grad, fmap2_grad, coords_grad = \
|
||||
correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
|
||||
return fmap1_grad, fmap2_grad, coords_grad, None
|
||||
|
||||
|
||||
class AlternateCorrBlock:
|
||||
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
||||
self.num_levels = num_levels
|
||||
self.radius = radius
|
||||
|
||||
self.pyramid = [(fmap1, fmap2)]
|
||||
for i in range(self.num_levels):
|
||||
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
|
||||
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
|
||||
self.pyramid.append((fmap1, fmap2))
|
||||
|
||||
def __call__(self, coords):
|
||||
|
||||
coords = coords.permute(0, 2, 3, 1)
|
||||
B, H, W, _ = coords.shape
|
||||
|
||||
corr_list = []
|
||||
for i in range(self.num_levels):
|
||||
r = self.radius
|
||||
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
|
||||
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
|
||||
|
||||
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
|
||||
corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
|
||||
corr_list.append(corr.squeeze(1))
|
||||
|
||||
corr = torch.stack(corr_list, dim=1)
|
||||
corr = corr.reshape(B, -1, H, W)
|
||||
return corr / 16.0
|
||||
|
Reference in New Issue
Block a user