added cuda extension for efficent implementation

This commit is contained in:
Zach Teed 2020-08-22 18:49:24 -06:00
parent 5b1f510d6b
commit c86b3dc8f3
13 changed files with 519 additions and 191 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ datasets
pytorch_env pytorch_env
models models
build build
correlation.egg-info

View File

@ -31,6 +31,13 @@ You can demo a trained model on a sequence of frames
python demo.py --model=models/raft-things.pth --path=demo-frames python demo.py --model=models/raft-things.pth --path=demo-frames
``` ```
## (Optional) Efficent Implementation
You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension
```Shell
cd alt_cuda_corr && python setup.py install && cd ..
```
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag.Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
## Required Data ## Required Data
To evaluate/train RAFT, you will need to download the required datasets. To evaluate/train RAFT, you will need to download the required datasets.

View File

@ -0,0 +1,54 @@
#include <torch/extension.h>
#include <vector>
// CUDA forward declarations
std::vector<torch::Tensor> corr_cuda_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius);
std::vector<torch::Tensor> corr_cuda_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius);
// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> corr_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius) {
CHECK_INPUT(fmap1);
CHECK_INPUT(fmap2);
CHECK_INPUT(coords);
return corr_cuda_forward(fmap1, fmap2, coords, radius);
}
std::vector<torch::Tensor> corr_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius) {
CHECK_INPUT(fmap1);
CHECK_INPUT(fmap2);
CHECK_INPUT(coords);
CHECK_INPUT(corr_grad);
return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &corr_forward, "CORR forward");
m.def("backward", &corr_backward, "CORR backward");
}

View File

@ -0,0 +1,324 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#define BLOCK_H 4
#define BLOCK_W 8
#define BLOCK_HW BLOCK_H * BLOCK_W
#define CHANNEL_STRIDE 32
__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
template <typename scalar_t>
__global__ void corr_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t x2s[BLOCK_HW];
__shared__ scalar_t y2s[BLOCK_HW];
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
auto fptr = fmap1[b][h1][w1];
if (within_bounds(h1, w1, H1, W1))
f1[c1][k1] = fptr[c+c1];
else
f1[c1][k1] = 0.0;
}
__syncthreads();
for (int n=0; n<N; n++) {
int h1 = h0 + threadIdx.x;
int w1 = w0 + threadIdx.y;
if (within_bounds(h1, w1, H1, W1)) {
x2s[tid] = coords[b][n][h1][w1][0];
y2s[tid] = coords[b][n][h1][w1][1];
}
scalar_t dx = x2s[tid] - floor(x2s[tid]);
scalar_t dy = y2s[tid] - floor(y2s[tid]);
int rd = 2*r + 1;
for (int iy=0; iy<rd+1; iy++) {
for (int ix=0; ix<rd+1; ix++) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
auto fptr = fmap2[b][h2][w2];
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fptr[c+c2];
else
f2[c2][k1] = 0.0;
}
__syncthreads();
scalar_t s = 0.0;
for (int k=0; k<CHANNEL_STRIDE; k++)
s += f1[k][tid] * f2[k][tid];
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
int ix_ne = H1*W1*((iy-1) + rd*ix);
int ix_sw = H1*W1*(iy + rd*(ix-1));
int ix_se = H1*W1*(iy + rd*ix);
scalar_t nw = s * (dy) * (dx);
scalar_t ne = s * (dy) * (1-dx);
scalar_t sw = s * (1-dy) * (dx);
scalar_t se = s * (1-dy) * (1-dx);
scalar_t* corr_ptr = &corr[b][n][0][h1][w1];
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_nw) += nw;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_ne) += ne;
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_sw) += sw;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_se) += se;
}
}
}
}
}
template <typename scalar_t>
__global__ void corr_backward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t x2s[BLOCK_HW];
__shared__ scalar_t y2s[BLOCK_HW];
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
auto fptr = fmap1[b][h1][w1];
if (within_bounds(h1, w1, H1, W1))
f1[c1][k1] = fptr[c+c1];
else
f1[c1][k1] = 0.0;
f1_grad[c1][k1] = 0.0;
}
__syncthreads();
int h1 = h0 + threadIdx.x;
int w1 = w0 + threadIdx.y;
for (int n=0; n<N; n++) {
x2s[tid] = coords[b][n][h1][w1][0];
y2s[tid] = coords[b][n][h1][w1][1];
scalar_t dx = x2s[tid] - floor(x2s[tid]);
scalar_t dy = y2s[tid] - floor(y2s[tid]);
int rd = 2*r + 1;
for (int iy=0; iy<rd+1; iy++) {
for (int ix=0; ix<rd+1; ix++) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
auto fptr = fmap2[b][h2][w2];
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fptr[c+c2];
else
f2[c2][k1] = 0.0;
f2_grad[c2][k1] = 0.0;
}
__syncthreads();
const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
scalar_t g = 0.0;
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
int ix_ne = H1*W1*((iy-1) + rd*ix);
int ix_sw = H1*W1*(iy + rd*(ix-1));
int ix_se = H1*W1*(iy + rd*ix);
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_nw) * dy * dx;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_ne) * dy * (1-dx);
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_sw) * (1-dy) * dx;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
for (int k=0; k<CHANNEL_STRIDE; k++) {
f1_grad[k][tid] += g * f2[k][tid];
f2_grad[k][tid] += g * f1[k][tid];
}
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
if (within_bounds(h2, w2, H2, W2))
atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
}
}
}
}
__syncthreads();
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
if (within_bounds(h1, w1, H1, W1))
fptr[c+c1] += f1_grad[c1][k1];
}
}
}
std::vector<torch::Tensor> corr_cuda_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H = coords.size(2);
const auto W = coords.size(3);
const auto rd = 2 * radius + 1;
auto opts = fmap1.options();
auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
corr_forward_kernel<float><<<blocks, threads>>>(
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
radius);
return {corr};
}
std::vector<torch::Tensor> corr_cuda_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H1 = fmap1.size(1);
const auto W1 = fmap1.size(2);
const auto H2 = fmap2.size(1);
const auto W2 = fmap2.size(2);
const auto C = fmap1.size(3);
auto opts = fmap1.options();
auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
corr_backward_kernel<float><<<blocks, threads>>>(
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
radius);
return {fmap1_grad, fmap2_grad, coords_grad};
}

15
alt_cuda_corr/setup.py Normal file
View File

@ -0,0 +1,15 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='correlation',
ext_modules=[
CUDAExtension('alt_cuda_corr',
sources=['correlation.cpp', 'correlation_kernel.cu'],
extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
],
cmdclass={
'build_ext': BuildExtension
})

View File

@ -2,6 +2,12 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid from utils.utils import bilinear_sampler, coords_grid
try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass
class CorrBlock: class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4): def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
@ -43,7 +49,6 @@ class CorrBlock:
out = torch.cat(out_pyramid, dim=-1) out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float() return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod @staticmethod
def corr(fmap1, fmap2): def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape batch, dim, ht, wd = fmap1.shape
@ -54,3 +59,53 @@ class CorrBlock:
corr = corr.view(batch, ht, wd, 1, ht, wd) corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float()) return corr / torch.sqrt(torch.tensor(dim).float())
class CorrLayer(torch.autograd.Function):
@staticmethod
def forward(ctx, fmap1, fmap2, coords, r):
fmap1 = fmap1.contiguous()
fmap2 = fmap2.contiguous()
coords = coords.contiguous()
ctx.save_for_backward(fmap1, fmap2, coords)
ctx.r = r
corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
return corr
@staticmethod
def backward(ctx, grad_corr):
fmap1, fmap2, coords = ctx.saved_tensors
grad_corr = grad_corr.contiguous()
fmap1_grad, fmap2_grad, coords_grad = \
correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
return fmap1_grad, fmap2_grad, coords_grad, None
class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.pyramid = [(fmap1, fmap2)]
for i in range(self.num_levels):
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
self.pyramid.append((fmap1, fmap2))
def __call__(self, coords):
coords = coords.permute(0, 2, 3, 1)
B, H, W, _ = coords.shape
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / 16.0

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from update import BasicUpdateBlock, SmallUpdateBlock from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock from corr import CorrBlock, AlternateCorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8 from utils.utils import bilinear_sampler, coords_grid, upflow8
try: try:
@ -41,6 +41,9 @@ class RAFT(nn.Module):
if 'dropout' not in args._get_kwargs(): if 'dropout' not in args._get_kwargs():
args.dropout = 0 args.dropout = 0
if 'alternate_corr' not in args._get_kwargs():
args.alternate_corr = False
# feature network, context network, and update block # feature network, context network, and update block
if args.small: if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
@ -99,7 +102,10 @@ class RAFT(nn.Module):
fmap1 = fmap1.float() fmap1 = fmap1.float()
fmap2 = fmap2.float() fmap2 = fmap2.float()
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) if self.args.alternate_corr:
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network # run the context network
with autocast(enabled=self.args.mixed_precision): with autocast(enabled=self.args.mixed_precision):

View File

@ -4,8 +4,10 @@ import math
from PIL import Image from PIL import Image
import cv2 import cv2
import torch cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
import torch
from torchvision.transforms import ColorJitter from torchvision.transforms import ColorJitter
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -1,3 +1,6 @@
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
# MIT License # MIT License
# #
# Copyright (c) 2018 Tom Runia # Copyright (c) 2018 Tom Runia
@ -12,21 +15,20 @@
# Author: Tom Runia # Author: Tom Runia
# Date Created: 2018-08-03 # Date Created: 2018-08-03
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np import numpy as np
def make_colorwheel(): def make_colorwheel():
''' """
Generates a color wheel for optical flow visualization as presented in: Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun Code follows the original C++ source code of Daniel Scharstein.
''' Code follows the the Matlab source code of Deqing Sun.
Returns:
np.ndarray: Color wheel
"""
RY = 15 RY = 15
YG = 6 YG = 6
@ -65,211 +67,66 @@ def make_colorwheel():
return colorwheel return colorwheel
def flow_compute_color(u, v, convert_to_bgr=False): def flow_uv_to_colors(u, v, convert_to_bgr=False):
''' """
Applies the flow color wheel to (possibly clipped) flow components u and v. Applies the flow color wheel to (possibly clipped) flow components u and v.
According to the C++ source code of Daniel Scharstein According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun According to the Matlab source code of Deqing Sun
:param u: np.ndarray, input horizontal flow
:param v: np.ndarray, input vertical flow
:param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
:return:
'''
Args:
u (np.ndarray): Input horizontal flow of shape [H,W]
v (np.ndarray): Input vertical flow of shape [H,W]
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
colorwheel = make_colorwheel() # shape [55x3] colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0] ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v)) rad = np.sqrt(np.square(u) + np.square(v))
a = np.arctan2(-v, -u)/np.pi a = np.arctan2(-v, -u)/np.pi
fk = (a+1) / 2*(ncols-1)
fk = (a+1) / 2*(ncols-1) + 1
k0 = np.floor(fk).astype(np.int32) k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1 k1 = k0 + 1
k1[k1 == ncols] = 1 k1[k1 == ncols] = 0
f = fk - k0 f = fk - k0
for i in range(colorwheel.shape[1]): for i in range(colorwheel.shape[1]):
tmp = colorwheel[:,i] tmp = colorwheel[:,i]
col0 = tmp[k0] / 255.0 col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0 col1 = tmp[k1] / 255.0
col = (1-f)*col0 + f*col1 col = (1-f)*col0 + f*col1
idx = (rad <= 1) idx = (rad <= 1)
col[idx] = 1 - rad[idx] * (1-col[idx]) col[idx] = 1 - rad[idx] * (1-col[idx])
col[~idx] = col[~idx] * 0.75 # out of range? col[~idx] = col[~idx] * 0.75 # out of range
# Note the 2-i => BGR instead of RGB # Note the 2-i => BGR instead of RGB
ch_idx = 2-i if convert_to_bgr else i ch_idx = 2-i if convert_to_bgr else i
flow_image[:,:,ch_idx] = np.floor(255 * col) flow_image[:,:,ch_idx] = np.floor(255 * col)
return flow_image return flow_image
def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
''' """
Expects a two dimensional flow image of shape [H,W,2] Expects a two dimensional flow image of shape.
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
:param flow_uv: np.ndarray of shape [H,W,2]
:param clip_flow: float, maximum clipping value for flow
:return:
'''
Args:
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
assert flow_uv.ndim == 3, 'input flow must have three dimensions' assert flow_uv.ndim == 3, 'input flow must have three dimensions'
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
if clip_flow is not None: if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow) flow_uv = np.clip(flow_uv, 0, clip_flow)
u = flow_uv[:,:,0] u = flow_uv[:,:,0]
v = flow_uv[:,:,1] v = flow_uv[:,:,1]
rad = np.sqrt(np.square(u) + np.square(v)) rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad) rad_max = np.max(rad)
epsilon = 1e-5 epsilon = 1e-5
u = u / (rad_max + epsilon) u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon) v = v / (rad_max + epsilon)
return flow_uv_to_colors(u, v, convert_to_bgr)
return flow_compute_color(u, v, convert_to_bgr)
UNKNOWN_FLOW_THRESH = 1e7
SMALLFLOW = 0.0
LARGEFLOW = 1e8
def make_color_wheel():
"""
Generate color wheel according Middlebury color code
:return: Color wheel
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros([ncols, 3])
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
col += RY
# YG
colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
colorwheel[col:col+YG, 1] = 255
col += YG
# GC
colorwheel[col:col+GC, 1] = 255
colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
col += GC
# CB
colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
colorwheel[col:col+CB, 2] = 255
col += CB
# BM
colorwheel[col:col+BM, 2] = 255
colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
col += + BM
# MR
colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
colorwheel[col:col+MR, 0] = 255
return colorwheel
def compute_color(u, v):
"""
compute optical flow color map
:param u: optical flow horizontal map
:param v: optical flow vertical map
:return: optical flow in color code
"""
[h, w] = u.shape
img = np.zeros([h, w, 3])
nanIdx = np.isnan(u) | np.isnan(v)
u[nanIdx] = 0
v[nanIdx] = 0
colorwheel = make_color_wheel()
ncols = np.size(colorwheel, 0)
rad = np.sqrt(u**2+v**2)
a = np.arctan2(-v, -u) / np.pi
fk = (a+1) / 2 * (ncols - 1) + 1
k0 = np.floor(fk).astype(int)
k1 = k0 + 1
k1[k1 == ncols+1] = 1
f = fk - k0
for i in range(0, np.size(colorwheel,1)):
tmp = colorwheel[:, i]
col0 = tmp[k0-1] / 255
col1 = tmp[k1-1] / 255
col = (1-f) * col0 + f * col1
idx = rad <= 1
col[idx] = 1-rad[idx]*(1-col[idx])
notidx = np.logical_not(idx)
col[notidx] *= 0.75
img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
return img
# from https://github.com/gengshan-y/VCN
def flow_to_image(flow):
"""
Convert flow into middlebury color code image
:param flow: optical flow map
:return: optical flow image in middlebury color
"""
u = flow[:, :, 0]
v = flow[:, :, 1]
maxu = -999.
maxv = -999.
minu = 999.
minv = 999.
idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
u[idxUnknow] = 0
v[idxUnknow] = 0
maxu = max(maxu, np.max(u))
minu = min(minu, np.min(u))
maxv = max(maxv, np.max(v))
minv = min(minv, np.min(v))
rad = np.sqrt(u ** 2 + v ** 2)
maxrad = max(-1, np.max(rad))
u = u/(maxrad + np.finfo(float).eps)
v = v/(maxrad + np.finfo(float).eps)
img = compute_color(u, v)
idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
img[idx] = 0
return np.uint8(img)

View File

@ -2,7 +2,10 @@ import numpy as np
from PIL import Image from PIL import Image
from os.path import * from os.path import *
import re import re
import cv2 import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
TAG_CHAR = np.array([202021.25], np.float32) TAG_CHAR = np.array([202021.25], np.float32)

View File

@ -6,11 +6,14 @@ from scipy import interpolate
class InputPadder: class InputPadder:
""" Pads images such that dimensions are divisible by 8 """ """ Pads images such that dimensions are divisible by 8 """
def __init__(self, dims): def __init__(self, dims, mode='sintel'):
self.ht, self.wd = dims[-2:] self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] if mode == 'sintel':
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
else:
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
def pad(self, *inputs): def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs] return [F.pad(x, self._pad, mode='replicate') for x in inputs]
@ -42,10 +45,10 @@ def forward_interpolate(flow):
dy = dy[valid] dy = dy[valid]
flow_x = interpolate.griddata( flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method='cubic', fill_value=0) (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
flow_y = interpolate.griddata( flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method='cubic', fill_value=0) (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
flow = np.stack([flow_x, flow_y], axis=0) flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float() return torch.from_numpy(flow).float()
@ -68,7 +71,6 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
return img return img
def coords_grid(batch, ht, wd): def coords_grid(batch, ht, wd):
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
coords = torch.stack(coords[::-1], dim=0).float() coords = torch.stack(coords[::-1], dim=0).float()

View File

@ -74,6 +74,7 @@ if __name__ == '__main__':
parser.add_argument('--path', help="dataset for evaluation") parser.add_argument('--path', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model') parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
args = parser.parse_args() args = parser.parse_args()
demo(args) demo(args)

View File

@ -61,7 +61,7 @@ def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
for test_id in range(len(test_dataset)): for test_id in range(len(test_dataset)):
image1, image2, (frame_id, ) = test_dataset[test_id] image1, image2, (frame_id, ) = test_dataset[test_id]
padder = InputPadder(image1.shape) padder = InputPadder(image1.shape, mode='kitti')
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
_, flow_pr = model(image1, image2, iters=iters, test_mode=True) _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
@ -139,7 +139,7 @@ def validate_kitti(model, iters=24):
image1 = image1[None].cuda() image1 = image1[None].cuda()
image2 = image2[None].cuda() image2 = image2[None].cuda()
padder = InputPadder(image1.shape) padder = InputPadder(image1.shape, mode='kitti')
image1, image2 = padder.pad(image1, image2) image1, image2 = padder.pad(image1, image2)
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True) flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
@ -172,6 +172,7 @@ if __name__ == '__main__':
parser.add_argument('--dataset', help="dataset for evaluation") parser.add_argument('--dataset', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model') parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
args = parser.parse_args() args = parser.parse_args()
model = torch.nn.DataParallel(RAFT(args)) model = torch.nn.DataParallel(RAFT(args))