added cuda extension for efficient implementation
This commit is contained in:
parent
5b1f510d6b
commit
c86b3dc8f3
1
.gitignore
vendored
1
.gitignore
vendored
@ -5,3 +5,4 @@ datasets
|
|||||||
pytorch_env
|
pytorch_env
|
||||||
models
|
models
|
||||||
build
|
build
|
||||||
|
correlation.egg-info
|
||||||
|
@ -31,6 +31,13 @@ You can demo a trained model on a sequence of frames
|
|||||||
python demo.py --model=models/raft-things.pth --path=demo-frames
|
python demo.py --model=models/raft-things.pth --path=demo-frames
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## (Optional) Efficient Implementation
|
||||||
|
You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension
|
||||||
|
```Shell
|
||||||
|
cd alt_cuda_corr && python setup.py install && cd ..
|
||||||
|
```
|
||||||
|
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
|
||||||
|
|
||||||
|
|
||||||
## Required Data
|
## Required Data
|
||||||
To evaluate/train RAFT, you will need to download the required datasets.
|
To evaluate/train RAFT, you will need to download the required datasets.
|
||||||
|
54
alt_cuda_corr/correlation.cpp
Normal file
54
alt_cuda_corr/correlation.cpp
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// CUDA forward declarations
//
// These are the entry points the C++ wrappers below dispatch to.

// Returns a single-element vector holding the correlation volume.
std::vector<torch::Tensor> corr_cuda_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    int radius);

// Returns {fmap1_grad, fmap2_grad, coords_grad}, in that order.
std::vector<torch::Tensor> corr_cuda_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor corr_grad,
    int radius);
|
||||||
|
|
||||||
|
// C++ interface
|
||||||
|
// Input validation helpers. Each raises a c10::Error (surfaced to Python as a
// RuntimeError) whose message names the offending argument via #x.
//
// - Tensor::is_cuda() replaces the deprecated x.type().is_cuda().
// - The macro argument is parenthesized so expressions can be passed safely.
// - CHECK_INPUT is wrapped in do { } while (0) so it behaves as a single
//   statement (e.g. inside an unbraced if/else).
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
|
||||||
|
|
||||||
|
// Python-facing forward entry point (bound as "forward" in the pybind module).
// Validates that every tensor argument is a contiguous CUDA tensor, then
// forwards the arguments unchanged to corr_cuda_forward and returns its result.
std::vector<torch::Tensor> corr_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    int radius) {
  // Each check throws with a message naming the offending argument.
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);

  return corr_cuda_forward(fmap1, fmap2, coords, radius);
}
|
||||||
|
|
||||||
|
|
||||||
|
// Python-facing backward entry point (bound as "backward" in the pybind module).
// Validates that every tensor argument, including the incoming gradient, is a
// contiguous CUDA tensor, then forwards the arguments unchanged to
// corr_cuda_backward and returns its result.
std::vector<torch::Tensor> corr_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor corr_grad,
    int radius) {
  // Each check throws with a message naming the offending argument.
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);
  CHECK_INPUT(corr_grad);

  return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}
|
||||||
|
|
||||||
|
|
||||||
|
// Python bindings: exposes the two validated wrappers as module-level
// functions `forward` and `backward`. The module name is supplied by the
// build through TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &corr_forward, "CORR forward");
  m.def("backward", &corr_backward, "CORR backward");
}
|
324
alt_cuda_corr/correlation_kernel.cu
Normal file
324
alt_cuda_corr/correlation_kernel.cu
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
|
||||||
|
// Tile configuration shared by the forward and backward kernels.
//   BLOCK_H x BLOCK_W : pixels handled by one thread block (one thread per pixel)
//   BLOCK_HW          : number of pixels / threads per block
//   CHANNEL_STRIDE    : feature channels staged in shared memory per pass
//
// BLOCK_HW is parenthesized so it expands safely inside larger expressions
// (e.g. `x % BLOCK_HW` or `x / BLOCK_HW`); every existing use evaluates to the
// same value as before.
#define BLOCK_H 4
#define BLOCK_W 8
#define BLOCK_HW (BLOCK_H * BLOCK_W)
#define CHANNEL_STRIDE 32
|
||||||
|
|
||||||
|
|
||||||
|
// True when (h, w) addresses a pixel inside an H x W image,
// i.e. 0 <= h < H and 0 <= w < W.
__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
  const bool row_ok = (0 <= h) && (h < H);
  const bool col_ok = (0 <= w) && (w < W);
  return row_ok && col_ok;
}
|
||||||
|
|
||||||
|
// Forward pass of the on-demand correlation lookup.
//
// Launch layout (see corr_cuda_forward): grid = (B, ceil(H1/BLOCK_H),
// ceil(W1/BLOCK_W)), block = (BLOCK_H, BLOCK_W). threadIdx.x indexes rows and
// threadIdx.y indexes columns of the tile; each thread owns one fmap1 pixel.
// The shared arrays are sized for exactly BLOCK_HW threads per block.
//
// Tensor layouts as indexed below:
//   fmap1  [B, H1, W1, C]      fmap2 [B, H2, W2, C]
//   coords [B, N, H1, W1, 2]   (last dim read as x at index 0, y at index 1)
//   corr   [B, N, (2r+1)^2, H1, W1]   accumulated into; must start zeroed
//
// For every CHANNEL_STRIDE-wide slice of channels the block stages its fmap1
// tile in shared memory, then for each of the (2r+2)^2 integer offsets around
// floor(coords) stages the matching fmap2 features, takes the per-pixel dot
// product and distributes it to up to four output cells with bilinear weights
// derived from the fractional part of the coordinate.
//
// Changes relative to the previous revision:
//   * barriers added wherever one thread reads shared memory written by
//     another (x2s/y2s after the coordinate load, f2/f1 before they are
//     overwritten) -- implicit warp-synchronous execution is not guaranteed;
//   * threads whose pixel lies outside the image store a defined coordinate
//     (0) instead of leaving x2s/y2s uninitialized;
//   * channel indices are bounds-checked so C need not be a multiple of
//     CHANNEL_STRIDE (missing channels are treated as zeros).
template <typename scalar_t>
__global__ void corr_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
    int r)
{
  const int b = blockIdx.x;
  const int h0 = blockIdx.y * blockDim.x;
  const int w0 = blockIdx.z * blockDim.y;
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  const int rd = 2*r + 1;
  const int plane = H1*W1;  // element stride between consecutive corr channels

  // Pixel owned by this thread; may fall outside the image in edge tiles.
  const int h1 = h0 + threadIdx.x;
  const int w1 = w0 + threadIdx.y;
  const bool valid = within_bounds(h1, w1, H1, W1);

  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  for (int c=0; c<C; c+=CHANNEL_STRIDE) {
    // Stage the fmap1 tile: consecutive threads read consecutive channels of
    // one pixel.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      const int k1 = k + tid / CHANNEL_STRIDE;
      const int ph = h0 + k1 / BLOCK_W;
      const int pw = w0 + k1 % BLOCK_W;
      const int c1 = tid % CHANNEL_STRIDE;

      if (within_bounds(ph, pw, H1, W1) && c + c1 < C)
        f1[c1][k1] = fmap1[b][ph][pw][c+c1];
      else
        f1[c1][k1] = 0;
    }

    __syncthreads();

    for (int n=0; n<N; n++) {
      if (valid) {
        x2s[tid] = coords[b][n][h1][w1][0];
        y2s[tid] = coords[b][n][h1][w1][1];
      } else {
        // Defined placeholder; results for these pixels are never written.
        x2s[tid] = 0;
        y2s[tid] = 0;
      }

      // Other threads read x2s/y2s[k1] while staging fmap2 below.
      __syncthreads();

      const scalar_t dx = x2s[tid] - floor(x2s[tid]);
      const scalar_t dy = y2s[tid] - floor(y2s[tid]);

      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          // Stage fmap2 features at integer offset (iy-r, ix-r) from each
          // pixel's floored coordinate; out-of-image samples are zero.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            const int k1 = k + tid / CHANNEL_STRIDE;
            const int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            const int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            const int c2 = tid % CHANNEL_STRIDE;

            if (within_bounds(h2, w2, H2, W2) && c + c2 < C)
              f2[c2][k1] = fmap2[b][h2][w2][c+c2];
            else
              f2[c2][k1] = 0;
          }

          __syncthreads();

          scalar_t s = 0;
          for (int k=0; k<CHANNEL_STRIDE; k++)
            s += f1[k][tid] * f2[k][tid];

          // Offsets (in elements) of the four output cells this sample
          // contributes to, relative to corr[b][n][0][h1][w1].
          const int ix_nw = plane*((iy-1) + rd*(ix-1));
          const int ix_ne = plane*((iy-1) + rd*ix);
          const int ix_sw = plane*(iy + rd*(ix-1));
          const int ix_se = plane*(iy + rd*ix);

          const scalar_t nw = s * (dy) * (dx);
          const scalar_t ne = s * (dy) * (1-dx);
          const scalar_t sw = s * (1-dy) * (dx);
          const scalar_t se = s * (1-dy) * (1-dx);

          if (valid) {
            // Each thread only touches its own (h1, w1) column of corr, so no
            // atomics are needed.
            scalar_t* corr_ptr = &corr[b][n][0][h1][w1];

            if (iy > 0 && ix > 0)
              corr_ptr[ix_nw] += nw;

            if (iy > 0 && ix < rd)
              corr_ptr[ix_ne] += ne;

            if (iy < rd && ix > 0)
              corr_ptr[ix_sw] += sw;

            if (iy < rd && ix < rd)
              corr_ptr[ix_se] += se;
          }

          // All threads must be done reading f2 (and, on the last offset,
          // f1 / x2s / y2s) before any thread overwrites them.
          __syncthreads();
        }
      }
    }
  }
}
|
||||||
|
|
||||||
|
|
||||||
|
// Backward pass of the on-demand correlation lookup.
//
// Launch layout (see corr_cuda_backward): grid = (B, ceil(H1/BLOCK_H),
// ceil(W1/BLOCK_W)), block = (BLOCK_H, BLOCK_W); one thread per fmap1 pixel.
// The shared arrays are sized for exactly BLOCK_HW threads per block.
//
// Tensor layouts as indexed below:
//   fmap1, fmap1_grad [B, H1, W1, C]     fmap2, fmap2_grad [B, H2, W2, C]
//   coords            [B, N, H1, W1, 2]  (x at index 0, y at index 1)
//   corr_grad         [B, N, (2r+1)^2, H1, W1]
//
// fmap1_grad is accumulated per tile in shared memory and added to global
// memory once per channel slice (each block owns its own fmap1 pixels).
// fmap2_grad is scattered with atomicAdd because different pixels/blocks can
// sample the same fmap2 location. Both outputs must start zeroed.
//
// NOTE: coords_grad is accepted but never written by this kernel, so it keeps
// whatever value the caller initialised it with.
//
// Changes relative to the previous revision:
//   * coords is only read for pixels inside the image (the old code read it
//     unconditionally, i.e. out of bounds in edge tiles);
//   * barriers added between the per-pixel f2_grad accumulation and the
//     channel-parallel scatter that reads other threads' columns, and before
//     shared tiles / coordinates are overwritten;
//   * channel indices are bounds-checked so C need not be a multiple of
//     CHANNEL_STRIDE.
template <typename scalar_t>
__global__ void corr_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
    int r)
{
  const int b = blockIdx.x;
  const int h0 = blockIdx.y * blockDim.x;
  const int w0 = blockIdx.z * blockDim.y;
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  const int rd = 2*r + 1;
  const int plane = H1*W1;  // element stride between consecutive corr channels

  // Pixel owned by this thread; may fall outside the image in edge tiles.
  const int h1 = h0 + threadIdx.x;
  const int w1 = w0 + threadIdx.y;
  const bool valid = within_bounds(h1, w1, H1, W1);

  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];

  __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];

  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  for (int c=0; c<C; c+=CHANNEL_STRIDE) {

    // Stage the fmap1 tile and clear its gradient accumulator.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      const int k1 = k + tid / CHANNEL_STRIDE;
      const int ph = h0 + k1 / BLOCK_W;
      const int pw = w0 + k1 % BLOCK_W;
      const int c1 = tid % CHANNEL_STRIDE;

      if (within_bounds(ph, pw, H1, W1) && c + c1 < C)
        f1[c1][k1] = fmap1[b][ph][pw][c+c1];
      else
        f1[c1][k1] = 0;

      f1_grad[c1][k1] = 0;
    }

    __syncthreads();

    for (int n=0; n<N; n++) {
      if (valid) {
        x2s[tid] = coords[b][n][h1][w1][0];
        y2s[tid] = coords[b][n][h1][w1][1];
      } else {
        // Defined placeholder; these pixels contribute a zero gradient.
        x2s[tid] = 0;
        y2s[tid] = 0;
      }

      // Other threads read x2s/y2s[k1] while staging / scattering below.
      __syncthreads();

      const scalar_t dx = x2s[tid] - floor(x2s[tid]);
      const scalar_t dy = y2s[tid] - floor(y2s[tid]);

      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          // Stage fmap2 features for this offset and clear their gradient
          // accumulator.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            const int k1 = k + tid / CHANNEL_STRIDE;
            const int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            const int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            const int c2 = tid % CHANNEL_STRIDE;

            if (within_bounds(h2, w2, H2, W2) && c + c2 < C)
              f2[c2][k1] = fmap2[b][h2][w2][c+c2];
            else
              f2[c2][k1] = 0;

            f2_grad[c2][k1] = 0;
          }

          __syncthreads();

          // Gradient w.r.t. the dot product at this offset: the bilinear
          // combination of the (up to) four output cells it fed in forward.
          scalar_t g = 0;

          const int ix_nw = plane*((iy-1) + rd*(ix-1));
          const int ix_ne = plane*((iy-1) + rd*ix);
          const int ix_sw = plane*(iy + rd*(ix-1));
          const int ix_se = plane*(iy + rd*ix);

          if (valid) {
            const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];

            if (iy > 0 && ix > 0)
              g += grad_ptr[ix_nw] * dy * dx;

            if (iy > 0 && ix < rd)
              g += grad_ptr[ix_ne] * dy * (1-dx);

            if (iy < rd && ix > 0)
              g += grad_ptr[ix_sw] * (1-dy) * dx;

            if (iy < rd && ix < rd)
              g += grad_ptr[ix_se] * (1-dy) * (1-dx);
          }

          // Per-pixel accumulation: thread `tid` owns column `tid`.
          for (int k=0; k<CHANNEL_STRIDE; k++) {
            f1_grad[k][tid] += g * f2[k][tid];
            f2_grad[k][tid] += g * f1[k][tid];
          }

          // The scatter below reads f2_grad columns written by other threads.
          __syncthreads();

          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            const int k1 = k + tid / CHANNEL_STRIDE;
            const int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            const int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            const int c2 = tid % CHANNEL_STRIDE;

            if (within_bounds(h2, w2, H2, W2) && c + c2 < C)
              atomicAdd(&fmap2_grad[b][h2][w2][c+c2], f2_grad[c2][k1]);
          }

          // Scatter must finish before f2 / f2_grad (next offset) or
          // x2s / y2s (next n) are overwritten.
          __syncthreads();
        }
      }
    }

    __syncthreads();

    // Flush the accumulated fmap1 gradient for this channel slice. Each block
    // owns its fmap1 pixels exclusively, so a plain += is sufficient.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      const int k1 = k + tid / CHANNEL_STRIDE;
      const int ph = h0 + k1 / BLOCK_W;
      const int pw = w0 + k1 % BLOCK_W;
      const int c1 = tid % CHANNEL_STRIDE;

      if (within_bounds(ph, pw, H1, W1) && c + c1 < C)
        fmap1_grad[b][ph][pw][c+c1] += f1_grad[c1][k1];
    }
  }
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Host launcher for corr_forward_kernel.
//
// Arguments (float32 CUDA tensors; the kernel is instantiated for float only,
// and packed_accessor32<float,...> rejects other dtypes):
//   fmap1  [B, H, W, C]
//   fmap2  [B, H2, W2, C]
//   coords [B, N, H, W, 2]
//   radius lookup radius r; the output has (2r+1)^2 channels
//
// Returns {corr} with corr of shape [B, N, (2r+1)^2, H, W], created with
// fmap1's options.
//
// Shapes are validated up front because the kernel indexes coords/corr with
// fmap1's spatial size. Empty inputs return immediately (a zero-sized grid is
// an invalid launch configuration), and launch failures are reported instead
// of being silently dropped.
std::vector<torch::Tensor> corr_cuda_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    int radius)
{
  TORCH_CHECK(radius >= 0, "radius must be non-negative, got ", radius);
  TORCH_CHECK(fmap1.dim() == 4 && fmap2.dim() == 4,
              "fmap1 and fmap2 must have shape [B, H, W, C]");
  TORCH_CHECK(coords.dim() == 5 && coords.size(4) == 2,
              "coords must have shape [B, N, H, W, 2]");
  TORCH_CHECK(fmap1.size(0) == coords.size(0) && fmap2.size(0) == coords.size(0),
              "fmap1, fmap2 and coords must share the batch dimension");
  TORCH_CHECK(fmap1.size(1) == coords.size(2) && fmap1.size(2) == coords.size(3),
              "coords spatial size must match fmap1");
  TORCH_CHECK(fmap1.size(3) == fmap2.size(3),
              "fmap1 and fmap2 must have the same number of channels");

  const auto B = coords.size(0);
  const auto N = coords.size(1);
  const auto H = coords.size(2);
  const auto W = coords.size(3);

  const int64_t rd = 2 * static_cast<int64_t>(radius) + 1;
  auto opts = fmap1.options();
  // The kernel accumulates with +=, so the output must start at zero.
  auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);

  // Nothing to compute; also avoids launching with a zero-sized grid.
  if (corr.numel() == 0)
    return {corr};

  // One block per BLOCK_H x BLOCK_W tile of fmap1, one thread per pixel.
  const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);

  corr_forward_kernel<float><<<blocks, threads>>>(
      fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
      fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
      corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
      radius);

  // Kernel launches do not return a status; query it explicitly.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(launch_status == cudaSuccess,
              "corr_forward_kernel launch failed: ",
              cudaGetErrorString(launch_status));

  return {corr};
}
|
||||||
|
|
||||||
|
// Host launcher for corr_backward_kernel.
//
// Arguments (float32 CUDA tensors; the kernel is instantiated for float only,
// and packed_accessor32<float,...> rejects other dtypes):
//   fmap1     [B, H1, W1, C]
//   fmap2     [B, H2, W2, C]
//   coords    [B, N, H1, W1, 2]
//   corr_grad [B, N, (2r+1)^2, H1, W1]
//   radius    lookup radius r used in the forward pass
//
// Returns {fmap1_grad, fmap2_grad, coords_grad}, each zero-initialised with
// fmap1's options. The kernel does not write coords_grad, so it is returned
// as all zeros.
//
// Shapes are validated up front because the kernel indexes coords/corr_grad
// with fmap1's spatial size. Empty inputs return immediately (a zero-sized
// grid is an invalid launch configuration), and launch failures are reported.
std::vector<torch::Tensor> corr_cuda_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor corr_grad,
    int radius)
{
  TORCH_CHECK(radius >= 0, "radius must be non-negative, got ", radius);
  TORCH_CHECK(fmap1.dim() == 4 && fmap2.dim() == 4,
              "fmap1 and fmap2 must have shape [B, H, W, C]");
  TORCH_CHECK(coords.dim() == 5 && coords.size(4) == 2,
              "coords must have shape [B, N, H, W, 2]");
  TORCH_CHECK(fmap1.size(0) == coords.size(0) && fmap2.size(0) == coords.size(0),
              "fmap1, fmap2 and coords must share the batch dimension");
  TORCH_CHECK(fmap1.size(1) == coords.size(2) && fmap1.size(2) == coords.size(3),
              "coords spatial size must match fmap1");
  TORCH_CHECK(fmap1.size(3) == fmap2.size(3),
              "fmap1 and fmap2 must have the same number of channels");

  const auto B = coords.size(0);
  const auto N = coords.size(1);

  const auto H1 = fmap1.size(1);
  const auto W1 = fmap1.size(2);
  const auto H2 = fmap2.size(1);
  const auto W2 = fmap2.size(2);
  const auto C = fmap1.size(3);

  const int64_t rd = 2 * static_cast<int64_t>(radius) + 1;
  TORCH_CHECK(corr_grad.dim() == 5 &&
              corr_grad.size(0) == B && corr_grad.size(1) == N &&
              corr_grad.size(2) == rd*rd &&
              corr_grad.size(3) == H1 && corr_grad.size(4) == W1,
              "corr_grad must have shape [B, N, (2*radius+1)^2, H, W]");

  auto opts = fmap1.options();
  // The kernel accumulates into these, so they must start at zero.
  auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
  auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
  auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);

  // Nothing to compute; also avoids launching with a zero-sized grid.
  if (B == 0 || H1 == 0 || W1 == 0)
    return {fmap1_grad, fmap2_grad, coords_grad};

  // One block per BLOCK_H x BLOCK_W tile of fmap1, one thread per pixel.
  const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);

  corr_backward_kernel<float><<<blocks, threads>>>(
      fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
      fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
      corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
      fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
      fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
      coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
      radius);

  // Kernel launches do not return a status; query it explicitly.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(launch_status == cudaSuccess,
              "corr_backward_kernel launch failed: ",
              cudaGetErrorString(launch_status));

  return {fmap1_grad, fmap2_grad, coords_grad};
}
|
15
alt_cuda_corr/setup.py
Normal file
15
alt_cuda_corr/setup.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from setuptools import setup
|
||||||
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
|
|
||||||
|
# Build configuration for the optional CUDA correlation extension.
# The distribution is registered as 'correlation'; the importable extension
# module is 'alt_cuda_corr'.
correlation_extension = CUDAExtension(
    'alt_cuda_corr',
    sources=['correlation.cpp', 'correlation_kernel.cu'],
    extra_compile_args={'cxx': [], 'nvcc': ['-O3']},
)

setup(
    name='correlation',
    ext_modules=[correlation_extension],
    cmdclass={'build_ext': BuildExtension},
)
|
||||||
|
|
57
core/corr.py
57
core/corr.py
@ -2,6 +2,12 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from utils.utils import bilinear_sampler, coords_grid
|
from utils.utils import bilinear_sampler, coords_grid
|
||||||
|
|
||||||
|
try:
    import alt_cuda_corr
except ImportError:
    # The optional CUDA extension is not compiled/installed. Keep the name
    # bound so callers can test for availability instead of hitting a
    # NameError; only import failures are swallowed, not arbitrary errors.
    alt_cuda_corr = None
|
||||||
|
|
||||||
|
|
||||||
class CorrBlock:
|
class CorrBlock:
|
||||||
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
||||||
@ -43,7 +49,6 @@ class CorrBlock:
|
|||||||
out = torch.cat(out_pyramid, dim=-1)
|
out = torch.cat(out_pyramid, dim=-1)
|
||||||
return out.permute(0, 3, 1, 2).contiguous().float()
|
return out.permute(0, 3, 1, 2).contiguous().float()
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def corr(fmap1, fmap2):
|
def corr(fmap1, fmap2):
|
||||||
batch, dim, ht, wd = fmap1.shape
|
batch, dim, ht, wd = fmap1.shape
|
||||||
@ -54,3 +59,53 @@ class CorrBlock:
|
|||||||
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
||||||
return corr / torch.sqrt(torch.tensor(dim).float())
|
return corr / torch.sqrt(torch.tensor(dim).float())
|
||||||
|
|
||||||
|
|
||||||
|
class CorrLayer(torch.autograd.Function):
    """Autograd wrapper around the compiled ``alt_cuda_corr`` extension.

    ``forward`` calls ``alt_cuda_corr.forward`` and ``backward`` calls
    ``alt_cuda_corr.backward``. (The previous revision referenced an
    undefined module name, ``correlation_cudaz``, and raised ``NameError``.)
    """

    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, r):
        """Compute the correlation lookup.

        Args:
            ctx: autograd context; stores the inputs and ``r`` for backward.
            fmap1, fmap2, coords: tensors handed to the extension. They are
                made contiguous first because the extension rejects
                non-contiguous inputs.
            r: lookup radius (non-tensor argument).

        Returns:
            The single tensor returned by ``alt_cuda_corr.forward``.
        """
        fmap1 = fmap1.contiguous()
        fmap2 = fmap2.contiguous()
        coords = coords.contiguous()
        ctx.save_for_backward(fmap1, fmap2, coords)
        ctx.r = r
        # The extension returns a one-element sequence; unpack it.
        corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, ctx.r)
        return corr

    @staticmethod
    def backward(ctx, grad_corr):
        """Propagate ``grad_corr`` to the three tensor inputs.

        Returns:
            ``(fmap1_grad, fmap2_grad, coords_grad, None)`` -- the trailing
            ``None`` is the (non-differentiable) radius argument.
        """
        fmap1, fmap2, coords = ctx.saved_tensors
        grad_corr = grad_corr.contiguous()
        fmap1_grad, fmap2_grad, coords_grad = \
            alt_cuda_corr.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
        return fmap1_grad, fmap2_grad, coords_grad, None
|
||||||
|
|
||||||
|
|
||||||
|
class AlternateCorrBlock:
    """Correlation lookup that computes values on demand via ``alt_cuda_corr``.

    Only a pyramid of average-pooled feature maps is stored; correlation
    values are computed per call by the CUDA extension.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Build the feature pyramid.

        Args:
            fmap1, fmap2: feature maps; ``__call__`` treats them as
                channels-first (it permutes with ``(0, 2, 3, 1)``).
            num_levels: number of pyramid levels queried in ``__call__``.
            radius: lookup radius passed to the extension.
        """
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Look up correlation features around ``coords`` at every level.

        Args:
            coords: tensor permuted with ``(0, 2, 3, 1)`` and then unpacked as
                ``(B, H, W, _)``.

        Returns:
            Tensor reshaped to ``(B, -1, H, W)``, divided by ``sqrt(dim)``
            where ``dim`` is the channel count of the level-0 feature map
            (the same normalisation as ``CorrBlock.corr``; equal to the
            previously hard-coded 16.0 when ``dim == 256``).
        """
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # permute() yields non-contiguous views; the extension requires
            # contiguous inputs.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            # The extension module itself is not callable; use its `forward`
            # entry point, which returns a one-element sequence.
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / (dim ** 0.5)
|
||||||
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from update import BasicUpdateBlock, SmallUpdateBlock
|
from update import BasicUpdateBlock, SmallUpdateBlock
|
||||||
from extractor import BasicEncoder, SmallEncoder
|
from extractor import BasicEncoder, SmallEncoder
|
||||||
from corr import CorrBlock
|
from corr import CorrBlock, AlternateCorrBlock
|
||||||
from utils.utils import bilinear_sampler, coords_grid, upflow8
|
from utils.utils import bilinear_sampler, coords_grid, upflow8
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -41,6 +41,9 @@ class RAFT(nn.Module):
|
|||||||
if 'dropout' not in args._get_kwargs():
|
if 'dropout' not in args._get_kwargs():
|
||||||
args.dropout = 0
|
args.dropout = 0
|
||||||
|
|
||||||
|
if 'alternate_corr' not in args._get_kwargs():
|
||||||
|
args.alternate_corr = False
|
||||||
|
|
||||||
# feature network, context network, and update block
|
# feature network, context network, and update block
|
||||||
if args.small:
|
if args.small:
|
||||||
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
|
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
|
||||||
@ -99,6 +102,9 @@ class RAFT(nn.Module):
|
|||||||
|
|
||||||
fmap1 = fmap1.float()
|
fmap1 = fmap1.float()
|
||||||
fmap2 = fmap2.float()
|
fmap2 = fmap2.float()
|
||||||
|
if self.args.alternate_corr:
|
||||||
|
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)  # class name as imported from corr
|
||||||
|
else:
|
||||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||||
|
|
||||||
# run the context network
|
# run the context network
|
||||||
|
@ -4,8 +4,10 @@ import math
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
cv2.setNumThreads(0)
|
||||||
|
cv2.ocl.setUseOpenCL(False)
|
||||||
|
|
||||||
|
import torch
|
||||||
from torchvision.transforms import ColorJitter
|
from torchvision.transforms import ColorJitter
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
||||||
|
|
||||||
|
|
||||||
# MIT License
|
# MIT License
|
||||||
#
|
#
|
||||||
# Copyright (c) 2018 Tom Runia
|
# Copyright (c) 2018 Tom Runia
|
||||||
@ -12,21 +15,20 @@
|
|||||||
# Author: Tom Runia
|
# Author: Tom Runia
|
||||||
# Date Created: 2018-08-03
|
# Date Created: 2018-08-03
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def make_colorwheel():
|
def make_colorwheel():
|
||||||
'''
|
"""
|
||||||
Generates a color wheel for optical flow visualization as presented in:
|
Generates a color wheel for optical flow visualization as presented in:
|
||||||
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
||||||
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
||||||
According to the C++ source code of Daniel Scharstein
|
|
||||||
According to the Matlab source code of Deqing Sun
|
Code follows the original C++ source code of Daniel Scharstein.
|
||||||
'''
|
Code follows the the Matlab source code of Deqing Sun.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Color wheel
|
||||||
|
"""
|
||||||
|
|
||||||
RY = 15
|
RY = 15
|
||||||
YG = 6
|
YG = 6
|
||||||
@ -65,211 +67,66 @@ def make_colorwheel():
|
|||||||
return colorwheel
|
return colorwheel
|
||||||
|
|
||||||
|
|
||||||
def flow_compute_color(u, v, convert_to_bgr=False):
|
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
||||||
'''
|
"""
|
||||||
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
||||||
|
|
||||||
According to the C++ source code of Daniel Scharstein
|
According to the C++ source code of Daniel Scharstein
|
||||||
According to the Matlab source code of Deqing Sun
|
According to the Matlab source code of Deqing Sun
|
||||||
:param u: np.ndarray, input horizontal flow
|
|
||||||
:param v: np.ndarray, input vertical flow
|
|
||||||
:param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
|
|
||||||
:return:
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
u (np.ndarray): Input horizontal flow of shape [H,W]
|
||||||
|
v (np.ndarray): Input vertical flow of shape [H,W]
|
||||||
|
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Flow visualization image of shape [H,W,3]
|
||||||
|
"""
|
||||||
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
||||||
|
|
||||||
colorwheel = make_colorwheel() # shape [55x3]
|
colorwheel = make_colorwheel() # shape [55x3]
|
||||||
ncols = colorwheel.shape[0]
|
ncols = colorwheel.shape[0]
|
||||||
|
|
||||||
rad = np.sqrt(np.square(u) + np.square(v))
|
rad = np.sqrt(np.square(u) + np.square(v))
|
||||||
a = np.arctan2(-v, -u)/np.pi
|
a = np.arctan2(-v, -u)/np.pi
|
||||||
|
fk = (a+1) / 2*(ncols-1)
|
||||||
fk = (a+1) / 2*(ncols-1) + 1
|
|
||||||
k0 = np.floor(fk).astype(np.int32)
|
k0 = np.floor(fk).astype(np.int32)
|
||||||
k1 = k0 + 1
|
k1 = k0 + 1
|
||||||
k1[k1 == ncols] = 1
|
k1[k1 == ncols] = 0
|
||||||
f = fk - k0
|
f = fk - k0
|
||||||
|
|
||||||
for i in range(colorwheel.shape[1]):
|
for i in range(colorwheel.shape[1]):
|
||||||
|
|
||||||
tmp = colorwheel[:,i]
|
tmp = colorwheel[:,i]
|
||||||
col0 = tmp[k0] / 255.0
|
col0 = tmp[k0] / 255.0
|
||||||
col1 = tmp[k1] / 255.0
|
col1 = tmp[k1] / 255.0
|
||||||
col = (1-f)*col0 + f*col1
|
col = (1-f)*col0 + f*col1
|
||||||
|
|
||||||
idx = (rad <= 1)
|
idx = (rad <= 1)
|
||||||
col[idx] = 1 - rad[idx] * (1-col[idx])
|
col[idx] = 1 - rad[idx] * (1-col[idx])
|
||||||
col[~idx] = col[~idx] * 0.75 # out of range?
|
col[~idx] = col[~idx] * 0.75 # out of range
|
||||||
|
|
||||||
# Note the 2-i => BGR instead of RGB
|
# Note the 2-i => BGR instead of RGB
|
||||||
ch_idx = 2-i if convert_to_bgr else i
|
ch_idx = 2-i if convert_to_bgr else i
|
||||||
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
||||||
|
|
||||||
return flow_image
|
return flow_image
|
||||||
|
|
||||||
|
|
||||||
def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
|
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
||||||
'''
|
"""
|
||||||
Expects a two dimensional flow image of shape [H,W,2]
|
Expects a two dimensional flow image of shape.
|
||||||
According to the C++ source code of Daniel Scharstein
|
|
||||||
According to the Matlab source code of Deqing Sun
|
|
||||||
:param flow_uv: np.ndarray of shape [H,W,2]
|
|
||||||
:param clip_flow: float, maximum clipping value for flow
|
|
||||||
:return:
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
||||||
|
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
||||||
|
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Flow visualization image of shape [H,W,3]
|
||||||
|
"""
|
||||||
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
||||||
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
||||||
|
|
||||||
if clip_flow is not None:
|
if clip_flow is not None:
|
||||||
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
||||||
|
|
||||||
u = flow_uv[:,:,0]
|
u = flow_uv[:,:,0]
|
||||||
v = flow_uv[:,:,1]
|
v = flow_uv[:,:,1]
|
||||||
|
|
||||||
rad = np.sqrt(np.square(u) + np.square(v))
|
rad = np.sqrt(np.square(u) + np.square(v))
|
||||||
rad_max = np.max(rad)
|
rad_max = np.max(rad)
|
||||||
|
|
||||||
epsilon = 1e-5
|
epsilon = 1e-5
|
||||||
u = u / (rad_max + epsilon)
|
u = u / (rad_max + epsilon)
|
||||||
v = v / (rad_max + epsilon)
|
v = v / (rad_max + epsilon)
|
||||||
|
return flow_uv_to_colors(u, v, convert_to_bgr)
|
||||||
return flow_compute_color(u, v, convert_to_bgr)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
UNKNOWN_FLOW_THRESH = 1e7
|
|
||||||
SMALLFLOW = 0.0
|
|
||||||
LARGEFLOW = 1e8
|
|
||||||
|
|
||||||
def make_color_wheel():
|
|
||||||
"""
|
|
||||||
Generate color wheel according Middlebury color code
|
|
||||||
:return: Color wheel
|
|
||||||
"""
|
|
||||||
RY = 15
|
|
||||||
YG = 6
|
|
||||||
GC = 4
|
|
||||||
CB = 11
|
|
||||||
BM = 13
|
|
||||||
MR = 6
|
|
||||||
|
|
||||||
ncols = RY + YG + GC + CB + BM + MR
|
|
||||||
|
|
||||||
colorwheel = np.zeros([ncols, 3])
|
|
||||||
|
|
||||||
col = 0
|
|
||||||
|
|
||||||
# RY
|
|
||||||
colorwheel[0:RY, 0] = 255
|
|
||||||
colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
|
|
||||||
col += RY
|
|
||||||
|
|
||||||
# YG
|
|
||||||
colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
|
|
||||||
colorwheel[col:col+YG, 1] = 255
|
|
||||||
col += YG
|
|
||||||
|
|
||||||
# GC
|
|
||||||
colorwheel[col:col+GC, 1] = 255
|
|
||||||
colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
|
|
||||||
col += GC
|
|
||||||
|
|
||||||
# CB
|
|
||||||
colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
|
|
||||||
colorwheel[col:col+CB, 2] = 255
|
|
||||||
col += CB
|
|
||||||
|
|
||||||
# BM
|
|
||||||
colorwheel[col:col+BM, 2] = 255
|
|
||||||
colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
|
|
||||||
col += + BM
|
|
||||||
|
|
||||||
# MR
|
|
||||||
colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
|
|
||||||
colorwheel[col:col+MR, 0] = 255
|
|
||||||
|
|
||||||
return colorwheel
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def compute_color(u, v):
|
|
||||||
"""
|
|
||||||
compute optical flow color map
|
|
||||||
:param u: optical flow horizontal map
|
|
||||||
:param v: optical flow vertical map
|
|
||||||
:return: optical flow in color code
|
|
||||||
"""
|
|
||||||
[h, w] = u.shape
|
|
||||||
img = np.zeros([h, w, 3])
|
|
||||||
nanIdx = np.isnan(u) | np.isnan(v)
|
|
||||||
u[nanIdx] = 0
|
|
||||||
v[nanIdx] = 0
|
|
||||||
|
|
||||||
colorwheel = make_color_wheel()
|
|
||||||
ncols = np.size(colorwheel, 0)
|
|
||||||
|
|
||||||
rad = np.sqrt(u**2+v**2)
|
|
||||||
|
|
||||||
a = np.arctan2(-v, -u) / np.pi
|
|
||||||
|
|
||||||
fk = (a+1) / 2 * (ncols - 1) + 1
|
|
||||||
|
|
||||||
k0 = np.floor(fk).astype(int)
|
|
||||||
|
|
||||||
k1 = k0 + 1
|
|
||||||
k1[k1 == ncols+1] = 1
|
|
||||||
f = fk - k0
|
|
||||||
|
|
||||||
for i in range(0, np.size(colorwheel,1)):
|
|
||||||
tmp = colorwheel[:, i]
|
|
||||||
col0 = tmp[k0-1] / 255
|
|
||||||
col1 = tmp[k1-1] / 255
|
|
||||||
col = (1-f) * col0 + f * col1
|
|
||||||
|
|
||||||
idx = rad <= 1
|
|
||||||
col[idx] = 1-rad[idx]*(1-col[idx])
|
|
||||||
notidx = np.logical_not(idx)
|
|
||||||
|
|
||||||
col[notidx] *= 0.75
|
|
||||||
img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
# from https://github.com/gengshan-y/VCN
|
|
||||||
def flow_to_image(flow):
|
|
||||||
"""
|
|
||||||
Convert flow into middlebury color code image
|
|
||||||
:param flow: optical flow map
|
|
||||||
:return: optical flow image in middlebury color
|
|
||||||
"""
|
|
||||||
u = flow[:, :, 0]
|
|
||||||
v = flow[:, :, 1]
|
|
||||||
|
|
||||||
maxu = -999.
|
|
||||||
maxv = -999.
|
|
||||||
minu = 999.
|
|
||||||
minv = 999.
|
|
||||||
|
|
||||||
idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
|
|
||||||
u[idxUnknow] = 0
|
|
||||||
v[idxUnknow] = 0
|
|
||||||
|
|
||||||
maxu = max(maxu, np.max(u))
|
|
||||||
minu = min(minu, np.min(u))
|
|
||||||
|
|
||||||
maxv = max(maxv, np.max(v))
|
|
||||||
minv = min(minv, np.min(v))
|
|
||||||
|
|
||||||
rad = np.sqrt(u ** 2 + v ** 2)
|
|
||||||
maxrad = max(-1, np.max(rad))
|
|
||||||
|
|
||||||
u = u/(maxrad + np.finfo(float).eps)
|
|
||||||
v = v/(maxrad + np.finfo(float).eps)
|
|
||||||
|
|
||||||
img = compute_color(u, v)
|
|
||||||
|
|
||||||
idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
|
|
||||||
img[idx] = 0
|
|
||||||
|
|
||||||
return np.uint8(img)
|
|
@ -2,7 +2,10 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from os.path import *
|
from os.path import *
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
cv2.setNumThreads(0)
|
||||||
|
cv2.ocl.setUseOpenCL(False)
|
||||||
|
|
||||||
TAG_CHAR = np.array([202021.25], np.float32)
|
TAG_CHAR = np.array([202021.25], np.float32)
|
||||||
|
|
||||||
|
@ -6,10 +6,13 @@ from scipy import interpolate
|
|||||||
|
|
||||||
class InputPadder:
|
class InputPadder:
|
||||||
""" Pads images such that dimensions are divisible by 8 """
|
""" Pads images such that dimensions are divisible by 8 """
|
||||||
def __init__(self, dims):
|
def __init__(self, dims, mode='sintel'):
|
||||||
self.ht, self.wd = dims[-2:]
|
self.ht, self.wd = dims[-2:]
|
||||||
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
||||||
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
||||||
|
if mode == 'sintel':
|
||||||
|
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
||||||
|
else:
|
||||||
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
||||||
|
|
||||||
def pad(self, *inputs):
|
def pad(self, *inputs):
|
||||||
@ -42,10 +45,10 @@ def forward_interpolate(flow):
|
|||||||
dy = dy[valid]
|
dy = dy[valid]
|
||||||
|
|
||||||
flow_x = interpolate.griddata(
|
flow_x = interpolate.griddata(
|
||||||
(x1, y1), dx, (x0, y0), method='cubic', fill_value=0)
|
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
|
||||||
|
|
||||||
flow_y = interpolate.griddata(
|
flow_y = interpolate.griddata(
|
||||||
(x1, y1), dy, (x0, y0), method='cubic', fill_value=0)
|
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
|
||||||
|
|
||||||
flow = np.stack([flow_x, flow_y], axis=0)
|
flow = np.stack([flow_x, flow_y], axis=0)
|
||||||
return torch.from_numpy(flow).float()
|
return torch.from_numpy(flow).float()
|
||||||
@ -68,7 +71,6 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
|||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def coords_grid(batch, ht, wd):
|
def coords_grid(batch, ht, wd):
|
||||||
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
||||||
coords = torch.stack(coords[::-1], dim=0).float()
|
coords = torch.stack(coords[::-1], dim=0).float()
|
||||||
|
1
demo.py
1
demo.py
@ -74,6 +74,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--path', help="dataset for evaluation")
|
parser.add_argument('--path', help="dataset for evaluation")
|
||||||
parser.add_argument('--small', action='store_true', help='use small model')
|
parser.add_argument('--small', action='store_true', help='use small model')
|
||||||
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
||||||
|
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
demo(args)
|
demo(args)
|
||||||
|
@ -61,7 +61,7 @@ def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
|
|||||||
|
|
||||||
for test_id in range(len(test_dataset)):
|
for test_id in range(len(test_dataset)):
|
||||||
image1, image2, (frame_id, ) = test_dataset[test_id]
|
image1, image2, (frame_id, ) = test_dataset[test_id]
|
||||||
padder = InputPadder(image1.shape)
|
padder = InputPadder(image1.shape, mode='kitti')
|
||||||
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
|
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
|
||||||
|
|
||||||
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
@ -139,7 +139,7 @@ def validate_kitti(model, iters=24):
|
|||||||
image1 = image1[None].cuda()
|
image1 = image1[None].cuda()
|
||||||
image2 = image2[None].cuda()
|
image2 = image2[None].cuda()
|
||||||
|
|
||||||
padder = InputPadder(image1.shape)
|
padder = InputPadder(image1.shape, mode='kitti')
|
||||||
image1, image2 = padder.pad(image1, image2)
|
image1, image2 = padder.pad(image1, image2)
|
||||||
|
|
||||||
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
|
||||||
@ -172,6 +172,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--dataset', help="dataset for evaluation")
|
parser.add_argument('--dataset', help="dataset for evaluation")
|
||||||
parser.add_argument('--small', action='store_true', help='use small model')
|
parser.add_argument('--small', action='store_true', help='use small model')
|
||||||
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
||||||
|
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model = torch.nn.DataParallel(RAFT(args))
|
model = torch.nn.DataParallel(RAFT(args))
|
||||||
|
Loading…
Reference in New Issue
Block a user