added cuda extension for efficent implementation
This commit is contained in:
parent
5b1f510d6b
commit
c86b3dc8f3
1
.gitignore
vendored
1
.gitignore
vendored
@ -5,3 +5,4 @@ datasets
|
||||
pytorch_env
|
||||
models
|
||||
build
|
||||
correlation.egg-info
|
||||
|
@ -31,6 +31,13 @@ You can demo a trained model on a sequence of frames
|
||||
python demo.py --model=models/raft-things.pth --path=demo-frames
|
||||
```
|
||||
|
||||
## (Optional) Efficent Implementation
|
||||
You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension
|
||||
```Shell
|
||||
cd alt_cuda_corr && python setup.py install && cd ..
|
||||
```
|
||||
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag.Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
|
||||
|
||||
|
||||
## Required Data
|
||||
To evaluate/train RAFT, you will need to download the required datasets.
|
||||
|
54
alt_cuda_corr/correlation.cpp
Normal file
54
alt_cuda_corr/correlation.cpp
Normal file
@ -0,0 +1,54 @@
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
// CUDA forward declarations
|
||||
std::vector<torch::Tensor> corr_cuda_forward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
int radius);
|
||||
|
||||
std::vector<torch::Tensor> corr_cuda_backward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
torch::Tensor corr_grad,
|
||||
int radius);
|
||||
|
||||
// C++ interface
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::vector<torch::Tensor> corr_forward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
int radius) {
|
||||
CHECK_INPUT(fmap1);
|
||||
CHECK_INPUT(fmap2);
|
||||
CHECK_INPUT(coords);
|
||||
|
||||
return corr_cuda_forward(fmap1, fmap2, coords, radius);
|
||||
}
|
||||
|
||||
|
||||
std::vector<torch::Tensor> corr_backward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
torch::Tensor corr_grad,
|
||||
int radius) {
|
||||
CHECK_INPUT(fmap1);
|
||||
CHECK_INPUT(fmap2);
|
||||
CHECK_INPUT(coords);
|
||||
CHECK_INPUT(corr_grad);
|
||||
|
||||
return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &corr_forward, "CORR forward");
|
||||
m.def("backward", &corr_backward, "CORR backward");
|
||||
}
|
324
alt_cuda_corr/correlation_kernel.cu
Normal file
324
alt_cuda_corr/correlation_kernel.cu
Normal file
@ -0,0 +1,324 @@
|
||||
#include <torch/extension.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <vector>
|
||||
|
||||
|
||||
#define BLOCK_H 4
|
||||
#define BLOCK_W 8
|
||||
#define BLOCK_HW BLOCK_H * BLOCK_W
|
||||
#define CHANNEL_STRIDE 32
|
||||
|
||||
|
||||
__forceinline__ __device__
|
||||
bool within_bounds(int h, int w, int H, int W) {
|
||||
return h >= 0 && h < H && w >= 0 && w < W;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void corr_forward_kernel(
|
||||
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
|
||||
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
|
||||
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
|
||||
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
|
||||
int r)
|
||||
{
|
||||
const int b = blockIdx.x;
|
||||
const int h0 = blockIdx.y * blockDim.x;
|
||||
const int w0 = blockIdx.z * blockDim.y;
|
||||
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
const int H1 = fmap1.size(1);
|
||||
const int W1 = fmap1.size(2);
|
||||
const int H2 = fmap2.size(1);
|
||||
const int W2 = fmap2.size(2);
|
||||
const int N = coords.size(1);
|
||||
const int C = fmap1.size(3);
|
||||
|
||||
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
|
||||
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
|
||||
__shared__ scalar_t x2s[BLOCK_HW];
|
||||
__shared__ scalar_t y2s[BLOCK_HW];
|
||||
|
||||
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
|
||||
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
||||
int k1 = k + tid / CHANNEL_STRIDE;
|
||||
int h1 = h0 + k1 / BLOCK_W;
|
||||
int w1 = w0 + k1 % BLOCK_W;
|
||||
int c1 = tid % CHANNEL_STRIDE;
|
||||
|
||||
auto fptr = fmap1[b][h1][w1];
|
||||
if (within_bounds(h1, w1, H1, W1))
|
||||
f1[c1][k1] = fptr[c+c1];
|
||||
else
|
||||
f1[c1][k1] = 0.0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int n=0; n<N; n++) {
|
||||
int h1 = h0 + threadIdx.x;
|
||||
int w1 = w0 + threadIdx.y;
|
||||
if (within_bounds(h1, w1, H1, W1)) {
|
||||
x2s[tid] = coords[b][n][h1][w1][0];
|
||||
y2s[tid] = coords[b][n][h1][w1][1];
|
||||
}
|
||||
|
||||
scalar_t dx = x2s[tid] - floor(x2s[tid]);
|
||||
scalar_t dy = y2s[tid] - floor(y2s[tid]);
|
||||
|
||||
int rd = 2*r + 1;
|
||||
for (int iy=0; iy<rd+1; iy++) {
|
||||
for (int ix=0; ix<rd+1; ix++) {
|
||||
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
||||
int k1 = k + tid / CHANNEL_STRIDE;
|
||||
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
|
||||
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
|
||||
int c2 = tid % CHANNEL_STRIDE;
|
||||
|
||||
auto fptr = fmap2[b][h2][w2];
|
||||
if (within_bounds(h2, w2, H2, W2))
|
||||
f2[c2][k1] = fptr[c+c2];
|
||||
else
|
||||
f2[c2][k1] = 0.0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
scalar_t s = 0.0;
|
||||
for (int k=0; k<CHANNEL_STRIDE; k++)
|
||||
s += f1[k][tid] * f2[k][tid];
|
||||
|
||||
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
|
||||
int ix_ne = H1*W1*((iy-1) + rd*ix);
|
||||
int ix_sw = H1*W1*(iy + rd*(ix-1));
|
||||
int ix_se = H1*W1*(iy + rd*ix);
|
||||
|
||||
scalar_t nw = s * (dy) * (dx);
|
||||
scalar_t ne = s * (dy) * (1-dx);
|
||||
scalar_t sw = s * (1-dy) * (dx);
|
||||
scalar_t se = s * (1-dy) * (1-dx);
|
||||
|
||||
scalar_t* corr_ptr = &corr[b][n][0][h1][w1];
|
||||
|
||||
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
|
||||
*(corr_ptr + ix_nw) += nw;
|
||||
|
||||
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
|
||||
*(corr_ptr + ix_ne) += ne;
|
||||
|
||||
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
|
||||
*(corr_ptr + ix_sw) += sw;
|
||||
|
||||
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
|
||||
*(corr_ptr + ix_se) += se;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void corr_backward_kernel(
|
||||
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
|
||||
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
|
||||
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
|
||||
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
|
||||
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
|
||||
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
|
||||
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
|
||||
int r)
|
||||
{
|
||||
|
||||
const int b = blockIdx.x;
|
||||
const int h0 = blockIdx.y * blockDim.x;
|
||||
const int w0 = blockIdx.z * blockDim.y;
|
||||
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
|
||||
|
||||
const int H1 = fmap1.size(1);
|
||||
const int W1 = fmap1.size(2);
|
||||
const int H2 = fmap2.size(1);
|
||||
const int W2 = fmap2.size(2);
|
||||
const int N = coords.size(1);
|
||||
const int C = fmap1.size(3);
|
||||
|
||||
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
|
||||
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
|
||||
|
||||
__shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
|
||||
__shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
|
||||
|
||||
__shared__ scalar_t x2s[BLOCK_HW];
|
||||
__shared__ scalar_t y2s[BLOCK_HW];
|
||||
|
||||
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
|
||||
|
||||
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
||||
int k1 = k + tid / CHANNEL_STRIDE;
|
||||
int h1 = h0 + k1 / BLOCK_W;
|
||||
int w1 = w0 + k1 % BLOCK_W;
|
||||
int c1 = tid % CHANNEL_STRIDE;
|
||||
|
||||
auto fptr = fmap1[b][h1][w1];
|
||||
if (within_bounds(h1, w1, H1, W1))
|
||||
f1[c1][k1] = fptr[c+c1];
|
||||
else
|
||||
f1[c1][k1] = 0.0;
|
||||
|
||||
f1_grad[c1][k1] = 0.0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int h1 = h0 + threadIdx.x;
|
||||
int w1 = w0 + threadIdx.y;
|
||||
|
||||
for (int n=0; n<N; n++) {
|
||||
x2s[tid] = coords[b][n][h1][w1][0];
|
||||
y2s[tid] = coords[b][n][h1][w1][1];
|
||||
|
||||
scalar_t dx = x2s[tid] - floor(x2s[tid]);
|
||||
scalar_t dy = y2s[tid] - floor(y2s[tid]);
|
||||
|
||||
int rd = 2*r + 1;
|
||||
for (int iy=0; iy<rd+1; iy++) {
|
||||
for (int ix=0; ix<rd+1; ix++) {
|
||||
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
||||
int k1 = k + tid / CHANNEL_STRIDE;
|
||||
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
|
||||
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
|
||||
int c2 = tid % CHANNEL_STRIDE;
|
||||
|
||||
auto fptr = fmap2[b][h2][w2];
|
||||
if (within_bounds(h2, w2, H2, W2))
|
||||
f2[c2][k1] = fptr[c+c2];
|
||||
else
|
||||
f2[c2][k1] = 0.0;
|
||||
|
||||
f2_grad[c2][k1] = 0.0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
|
||||
scalar_t g = 0.0;
|
||||
|
||||
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
|
||||
int ix_ne = H1*W1*((iy-1) + rd*ix);
|
||||
int ix_sw = H1*W1*(iy + rd*(ix-1));
|
||||
int ix_se = H1*W1*(iy + rd*ix);
|
||||
|
||||
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
|
||||
g += *(grad_ptr + ix_nw) * dy * dx;
|
||||
|
||||
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
|
||||
g += *(grad_ptr + ix_ne) * dy * (1-dx);
|
||||
|
||||
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
|
||||
g += *(grad_ptr + ix_sw) * (1-dy) * dx;
|
||||
|
||||
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
|
||||
g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
|
||||
|
||||
for (int k=0; k<CHANNEL_STRIDE; k++) {
|
||||
f1_grad[k][tid] += g * f2[k][tid];
|
||||
f2_grad[k][tid] += g * f1[k][tid];
|
||||
}
|
||||
|
||||
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
||||
int k1 = k + tid / CHANNEL_STRIDE;
|
||||
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
|
||||
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
|
||||
int c2 = tid % CHANNEL_STRIDE;
|
||||
|
||||
scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
|
||||
if (within_bounds(h2, w2, H2, W2))
|
||||
atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
|
||||
int k1 = k + tid / CHANNEL_STRIDE;
|
||||
int h1 = h0 + k1 / BLOCK_W;
|
||||
int w1 = w0 + k1 % BLOCK_W;
|
||||
int c1 = tid % CHANNEL_STRIDE;
|
||||
|
||||
scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
|
||||
if (within_bounds(h1, w1, H1, W1))
|
||||
fptr[c+c1] += f1_grad[c1][k1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::vector<torch::Tensor> corr_cuda_forward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
int radius)
|
||||
{
|
||||
const auto B = coords.size(0);
|
||||
const auto N = coords.size(1);
|
||||
const auto H = coords.size(2);
|
||||
const auto W = coords.size(3);
|
||||
|
||||
const auto rd = 2 * radius + 1;
|
||||
auto opts = fmap1.options();
|
||||
auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
|
||||
|
||||
const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
|
||||
const dim3 threads(BLOCK_H, BLOCK_W);
|
||||
|
||||
corr_forward_kernel<float><<<blocks, threads>>>(
|
||||
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
||||
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
||||
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
||||
corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
||||
radius);
|
||||
|
||||
return {corr};
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> corr_cuda_backward(
|
||||
torch::Tensor fmap1,
|
||||
torch::Tensor fmap2,
|
||||
torch::Tensor coords,
|
||||
torch::Tensor corr_grad,
|
||||
int radius)
|
||||
{
|
||||
const auto B = coords.size(0);
|
||||
const auto N = coords.size(1);
|
||||
|
||||
const auto H1 = fmap1.size(1);
|
||||
const auto W1 = fmap1.size(2);
|
||||
const auto H2 = fmap2.size(1);
|
||||
const auto W2 = fmap2.size(2);
|
||||
const auto C = fmap1.size(3);
|
||||
|
||||
auto opts = fmap1.options();
|
||||
auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
|
||||
auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
|
||||
auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
|
||||
|
||||
const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
|
||||
const dim3 threads(BLOCK_H, BLOCK_W);
|
||||
|
||||
|
||||
corr_backward_kernel<float><<<blocks, threads>>>(
|
||||
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
||||
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
||||
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
||||
corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
||||
fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
||||
fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
||||
coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
|
||||
radius);
|
||||
|
||||
return {fmap1_grad, fmap2_grad, coords_grad};
|
||||
}
|
15
alt_cuda_corr/setup.py
Normal file
15
alt_cuda_corr/setup.py
Normal file
@ -0,0 +1,15 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
|
||||
setup(
|
||||
name='correlation',
|
||||
ext_modules=[
|
||||
CUDAExtension('alt_cuda_corr',
|
||||
sources=['correlation.cpp', 'correlation_kernel.cu'],
|
||||
extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
||||
|
57
core/corr.py
57
core/corr.py
@ -2,6 +2,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from utils.utils import bilinear_sampler, coords_grid
|
||||
|
||||
try:
|
||||
import alt_cuda_corr
|
||||
except:
|
||||
# alt_cuda_corr is not compiled
|
||||
pass
|
||||
|
||||
|
||||
class CorrBlock:
|
||||
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
||||
@ -43,7 +49,6 @@ class CorrBlock:
|
||||
out = torch.cat(out_pyramid, dim=-1)
|
||||
return out.permute(0, 3, 1, 2).contiguous().float()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def corr(fmap1, fmap2):
|
||||
batch, dim, ht, wd = fmap1.shape
|
||||
@ -54,3 +59,53 @@ class CorrBlock:
|
||||
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
||||
return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
||||
|
||||
class CorrLayer(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, fmap1, fmap2, coords, r):
|
||||
fmap1 = fmap1.contiguous()
|
||||
fmap2 = fmap2.contiguous()
|
||||
coords = coords.contiguous()
|
||||
ctx.save_for_backward(fmap1, fmap2, coords)
|
||||
ctx.r = r
|
||||
corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
|
||||
return corr
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_corr):
|
||||
fmap1, fmap2, coords = ctx.saved_tensors
|
||||
grad_corr = grad_corr.contiguous()
|
||||
fmap1_grad, fmap2_grad, coords_grad = \
|
||||
correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
|
||||
return fmap1_grad, fmap2_grad, coords_grad, None
|
||||
|
||||
|
||||
class AlternateCorrBlock:
|
||||
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
||||
self.num_levels = num_levels
|
||||
self.radius = radius
|
||||
|
||||
self.pyramid = [(fmap1, fmap2)]
|
||||
for i in range(self.num_levels):
|
||||
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
|
||||
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
|
||||
self.pyramid.append((fmap1, fmap2))
|
||||
|
||||
def __call__(self, coords):
|
||||
|
||||
coords = coords.permute(0, 2, 3, 1)
|
||||
B, H, W, _ = coords.shape
|
||||
|
||||
corr_list = []
|
||||
for i in range(self.num_levels):
|
||||
r = self.radius
|
||||
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
|
||||
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
|
||||
|
||||
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
|
||||
corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
|
||||
corr_list.append(corr.squeeze(1))
|
||||
|
||||
corr = torch.stack(corr_list, dim=1)
|
||||
corr = corr.reshape(B, -1, H, W)
|
||||
return corr / 16.0
|
||||
|
10
core/raft.py
10
core/raft.py
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||
|
||||
from update import BasicUpdateBlock, SmallUpdateBlock
|
||||
from extractor import BasicEncoder, SmallEncoder
|
||||
from corr import CorrBlock
|
||||
from corr import CorrBlock, AlternateCorrBlock
|
||||
from utils.utils import bilinear_sampler, coords_grid, upflow8
|
||||
|
||||
try:
|
||||
@ -41,6 +41,9 @@ class RAFT(nn.Module):
|
||||
if 'dropout' not in args._get_kwargs():
|
||||
args.dropout = 0
|
||||
|
||||
if 'alternate_corr' not in args._get_kwargs():
|
||||
args.alternate_corr = False
|
||||
|
||||
# feature network, context network, and update block
|
||||
if args.small:
|
||||
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
|
||||
@ -99,7 +102,10 @@ class RAFT(nn.Module):
|
||||
|
||||
fmap1 = fmap1.float()
|
||||
fmap2 = fmap2.float()
|
||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
if self.args.alternate_corr:
|
||||
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
else:
|
||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
|
||||
# run the context network
|
||||
with autocast(enabled=self.args.mixed_precision):
|
||||
|
@ -4,8 +4,10 @@ import math
|
||||
from PIL import Image
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
cv2.setNumThreads(0)
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import ColorJitter
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
@ -1,3 +1,6 @@
|
||||
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
||||
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2018 Tom Runia
|
||||
@ -12,21 +15,20 @@
|
||||
# Author: Tom Runia
|
||||
# Date Created: 2018-08-03
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_colorwheel():
|
||||
'''
|
||||
"""
|
||||
Generates a color wheel for optical flow visualization as presented in:
|
||||
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
||||
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
||||
According to the C++ source code of Daniel Scharstein
|
||||
According to the Matlab source code of Deqing Sun
|
||||
'''
|
||||
|
||||
Code follows the original C++ source code of Daniel Scharstein.
|
||||
Code follows the the Matlab source code of Deqing Sun.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Color wheel
|
||||
"""
|
||||
|
||||
RY = 15
|
||||
YG = 6
|
||||
@ -65,211 +67,66 @@ def make_colorwheel():
|
||||
return colorwheel
|
||||
|
||||
|
||||
def flow_compute_color(u, v, convert_to_bgr=False):
|
||||
'''
|
||||
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
||||
"""
|
||||
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
||||
|
||||
According to the C++ source code of Daniel Scharstein
|
||||
According to the Matlab source code of Deqing Sun
|
||||
:param u: np.ndarray, input horizontal flow
|
||||
:param v: np.ndarray, input vertical flow
|
||||
:param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
|
||||
:return:
|
||||
'''
|
||||
|
||||
Args:
|
||||
u (np.ndarray): Input horizontal flow of shape [H,W]
|
||||
v (np.ndarray): Input vertical flow of shape [H,W]
|
||||
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Flow visualization image of shape [H,W,3]
|
||||
"""
|
||||
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
||||
|
||||
colorwheel = make_colorwheel() # shape [55x3]
|
||||
ncols = colorwheel.shape[0]
|
||||
|
||||
rad = np.sqrt(np.square(u) + np.square(v))
|
||||
a = np.arctan2(-v, -u)/np.pi
|
||||
|
||||
fk = (a+1) / 2*(ncols-1) + 1
|
||||
fk = (a+1) / 2*(ncols-1)
|
||||
k0 = np.floor(fk).astype(np.int32)
|
||||
k1 = k0 + 1
|
||||
k1[k1 == ncols] = 1
|
||||
k1[k1 == ncols] = 0
|
||||
f = fk - k0
|
||||
|
||||
for i in range(colorwheel.shape[1]):
|
||||
|
||||
tmp = colorwheel[:,i]
|
||||
col0 = tmp[k0] / 255.0
|
||||
col1 = tmp[k1] / 255.0
|
||||
col = (1-f)*col0 + f*col1
|
||||
|
||||
idx = (rad <= 1)
|
||||
col[idx] = 1 - rad[idx] * (1-col[idx])
|
||||
col[~idx] = col[~idx] * 0.75 # out of range?
|
||||
|
||||
col[~idx] = col[~idx] * 0.75 # out of range
|
||||
# Note the 2-i => BGR instead of RGB
|
||||
ch_idx = 2-i if convert_to_bgr else i
|
||||
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
||||
|
||||
return flow_image
|
||||
|
||||
|
||||
def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
|
||||
'''
|
||||
Expects a two dimensional flow image of shape [H,W,2]
|
||||
According to the C++ source code of Daniel Scharstein
|
||||
According to the Matlab source code of Deqing Sun
|
||||
:param flow_uv: np.ndarray of shape [H,W,2]
|
||||
:param clip_flow: float, maximum clipping value for flow
|
||||
:return:
|
||||
'''
|
||||
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
||||
"""
|
||||
Expects a two dimensional flow image of shape.
|
||||
|
||||
Args:
|
||||
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
||||
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
||||
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Flow visualization image of shape [H,W,3]
|
||||
"""
|
||||
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
||||
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
||||
|
||||
if clip_flow is not None:
|
||||
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
||||
|
||||
u = flow_uv[:,:,0]
|
||||
v = flow_uv[:,:,1]
|
||||
|
||||
rad = np.sqrt(np.square(u) + np.square(v))
|
||||
rad_max = np.max(rad)
|
||||
|
||||
epsilon = 1e-5
|
||||
u = u / (rad_max + epsilon)
|
||||
v = v / (rad_max + epsilon)
|
||||
|
||||
return flow_compute_color(u, v, convert_to_bgr)
|
||||
|
||||
|
||||
|
||||
UNKNOWN_FLOW_THRESH = 1e7
|
||||
SMALLFLOW = 0.0
|
||||
LARGEFLOW = 1e8
|
||||
|
||||
def make_color_wheel():
|
||||
"""
|
||||
Generate color wheel according Middlebury color code
|
||||
:return: Color wheel
|
||||
"""
|
||||
RY = 15
|
||||
YG = 6
|
||||
GC = 4
|
||||
CB = 11
|
||||
BM = 13
|
||||
MR = 6
|
||||
|
||||
ncols = RY + YG + GC + CB + BM + MR
|
||||
|
||||
colorwheel = np.zeros([ncols, 3])
|
||||
|
||||
col = 0
|
||||
|
||||
# RY
|
||||
colorwheel[0:RY, 0] = 255
|
||||
colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
|
||||
col += RY
|
||||
|
||||
# YG
|
||||
colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
|
||||
colorwheel[col:col+YG, 1] = 255
|
||||
col += YG
|
||||
|
||||
# GC
|
||||
colorwheel[col:col+GC, 1] = 255
|
||||
colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
|
||||
col += GC
|
||||
|
||||
# CB
|
||||
colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
|
||||
colorwheel[col:col+CB, 2] = 255
|
||||
col += CB
|
||||
|
||||
# BM
|
||||
colorwheel[col:col+BM, 2] = 255
|
||||
colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
|
||||
col += + BM
|
||||
|
||||
# MR
|
||||
colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
|
||||
colorwheel[col:col+MR, 0] = 255
|
||||
|
||||
return colorwheel
|
||||
|
||||
|
||||
|
||||
def compute_color(u, v):
|
||||
"""
|
||||
compute optical flow color map
|
||||
:param u: optical flow horizontal map
|
||||
:param v: optical flow vertical map
|
||||
:return: optical flow in color code
|
||||
"""
|
||||
[h, w] = u.shape
|
||||
img = np.zeros([h, w, 3])
|
||||
nanIdx = np.isnan(u) | np.isnan(v)
|
||||
u[nanIdx] = 0
|
||||
v[nanIdx] = 0
|
||||
|
||||
colorwheel = make_color_wheel()
|
||||
ncols = np.size(colorwheel, 0)
|
||||
|
||||
rad = np.sqrt(u**2+v**2)
|
||||
|
||||
a = np.arctan2(-v, -u) / np.pi
|
||||
|
||||
fk = (a+1) / 2 * (ncols - 1) + 1
|
||||
|
||||
k0 = np.floor(fk).astype(int)
|
||||
|
||||
k1 = k0 + 1
|
||||
k1[k1 == ncols+1] = 1
|
||||
f = fk - k0
|
||||
|
||||
for i in range(0, np.size(colorwheel,1)):
|
||||
tmp = colorwheel[:, i]
|
||||
col0 = tmp[k0-1] / 255
|
||||
col1 = tmp[k1-1] / 255
|
||||
col = (1-f) * col0 + f * col1
|
||||
|
||||
idx = rad <= 1
|
||||
col[idx] = 1-rad[idx]*(1-col[idx])
|
||||
notidx = np.logical_not(idx)
|
||||
|
||||
col[notidx] *= 0.75
|
||||
img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
|
||||
|
||||
return img
|
||||
|
||||
# from https://github.com/gengshan-y/VCN
|
||||
def flow_to_image(flow):
|
||||
"""
|
||||
Convert flow into middlebury color code image
|
||||
:param flow: optical flow map
|
||||
:return: optical flow image in middlebury color
|
||||
"""
|
||||
u = flow[:, :, 0]
|
||||
v = flow[:, :, 1]
|
||||
|
||||
maxu = -999.
|
||||
maxv = -999.
|
||||
minu = 999.
|
||||
minv = 999.
|
||||
|
||||
idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
|
||||
u[idxUnknow] = 0
|
||||
v[idxUnknow] = 0
|
||||
|
||||
maxu = max(maxu, np.max(u))
|
||||
minu = min(minu, np.min(u))
|
||||
|
||||
maxv = max(maxv, np.max(v))
|
||||
minv = min(minv, np.min(v))
|
||||
|
||||
rad = np.sqrt(u ** 2 + v ** 2)
|
||||
maxrad = max(-1, np.max(rad))
|
||||
|
||||
u = u/(maxrad + np.finfo(float).eps)
|
||||
v = v/(maxrad + np.finfo(float).eps)
|
||||
|
||||
img = compute_color(u, v)
|
||||
|
||||
idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
|
||||
img[idx] = 0
|
||||
|
||||
return np.uint8(img)
|
||||
return flow_uv_to_colors(u, v, convert_to_bgr)
|
@ -2,7 +2,10 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from os.path import *
|
||||
import re
|
||||
|
||||
import cv2
|
||||
cv2.setNumThreads(0)
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
TAG_CHAR = np.array([202021.25], np.float32)
|
||||
|
||||
|
@ -6,11 +6,14 @@ from scipy import interpolate
|
||||
|
||||
class InputPadder:
|
||||
""" Pads images such that dimensions are divisible by 8 """
|
||||
def __init__(self, dims):
|
||||
def __init__(self, dims, mode='sintel'):
|
||||
self.ht, self.wd = dims[-2:]
|
||||
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
||||
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
||||
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
||||
if mode == 'sintel':
|
||||
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
||||
else:
|
||||
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
||||
|
||||
def pad(self, *inputs):
|
||||
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
||||
@ -42,10 +45,10 @@ def forward_interpolate(flow):
|
||||
dy = dy[valid]
|
||||
|
||||
flow_x = interpolate.griddata(
|
||||
(x1, y1), dx, (x0, y0), method='cubic', fill_value=0)
|
||||
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
|
||||
|
||||
flow_y = interpolate.griddata(
|
||||
(x1, y1), dy, (x0, y0), method='cubic', fill_value=0)
|
||||
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
|
||||
|
||||
flow = np.stack([flow_x, flow_y], axis=0)
|
||||
return torch.from_numpy(flow).float()
|
||||
@ -68,7 +71,6 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
||||
return img
|
||||
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd):
|
||||
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
||||
coords = torch.stack(coords[::-1], dim=0).float()
|
||||
|
1
demo.py
1
demo.py
@ -74,6 +74,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--path', help="dataset for evaluation")
|
||||
parser.add_argument('--small', action='store_true', help='use small model')
|
||||
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
||||
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
|
||||
args = parser.parse_args()
|
||||
|
||||
demo(args)
|
||||
|
@ -61,7 +61,7 @@ def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
|
||||
|
||||
for test_id in range(len(test_dataset)):
|
||||
image1, image2, (frame_id, ) = test_dataset[test_id]
|
||||
padder = InputPadder(image1.shape)
|
||||
padder = InputPadder(image1.shape, mode='kitti')
|
||||
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
|
||||
|
||||
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||
@ -139,7 +139,7 @@ def validate_kitti(model, iters=24):
|
||||
image1 = image1[None].cuda()
|
||||
image2 = image2[None].cuda()
|
||||
|
||||
padder = InputPadder(image1.shape)
|
||||
padder = InputPadder(image1.shape, mode='kitti')
|
||||
image1, image2 = padder.pad(image1, image2)
|
||||
|
||||
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||
@ -172,6 +172,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--dataset', help="dataset for evaluation")
|
||||
parser.add_argument('--small', action='store_true', help='use small model')
|
||||
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
||||
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
|
||||
args = parser.parse_args()
|
||||
|
||||
model = torch.nn.DataParallel(RAFT(args))
|
||||
|
Loading…
Reference in New Issue
Block a user