added cuda extension for efficient implementation
This commit is contained in:
57
core/corr.py
57
core/corr.py
@@ -2,6 +2,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from utils.utils import bilinear_sampler, coords_grid
|
||||
|
||||
try:
|
||||
import alt_cuda_corr
|
||||
except:
|
||||
# alt_cuda_corr is not compiled
|
||||
pass
|
||||
|
||||
|
||||
class CorrBlock:
|
||||
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
||||
@@ -43,7 +49,6 @@ class CorrBlock:
|
||||
out = torch.cat(out_pyramid, dim=-1)
|
||||
return out.permute(0, 3, 1, 2).contiguous().float()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def corr(fmap1, fmap2):
|
||||
batch, dim, ht, wd = fmap1.shape
|
||||
@@ -54,3 +59,53 @@ class CorrBlock:
|
||||
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
||||
return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
||||
|
||||
class CorrLayer(torch.autograd.Function):
    """Autograd function that delegates correlation lookup to the compiled extension.

    The original body called ``correlation_cudaz``, a name that is never
    imported in this module; the extension imported at the top of the file is
    ``alt_cuda_corr``, so both passes are routed through it.

    NOTE(review): this assumes ``alt_cuda_corr`` exposes ``forward`` and
    ``backward`` with the argument order used below -- confirm against the
    extension's bindings.
    """

    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, r):
        """Run the extension's forward kernel and stash inputs for backward.

        Args:
            ctx: autograd context.
            fmap1: first feature map tensor.
            fmap2: second feature map tensor.
            coords: lookup coordinates tensor.
            r: lookup radius (non-differentiable, stored on ``ctx``).

        Returns:
            The single tensor returned by the extension's ``forward``.
        """
        # Every tensor is made contiguous before it is handed to the extension.
        fmap1 = fmap1.contiguous()
        fmap2 = fmap2.contiguous()
        coords = coords.contiguous()
        ctx.save_for_backward(fmap1, fmap2, coords)
        ctx.r = r
        # The extension returns a one-element sequence, hence the unpacking.
        corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, ctx.r)
        return corr

    @staticmethod
    def backward(ctx, grad_corr):
        """Run the extension's backward kernel.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_corr: gradient with respect to the forward output.

        Returns:
            Gradients for ``fmap1``, ``fmap2`` and ``coords``; ``None`` for ``r``.
        """
        fmap1, fmap2, coords = ctx.saved_tensors
        grad_corr = grad_corr.contiguous()
        fmap1_grad, fmap2_grad, coords_grad = \
            alt_cuda_corr.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
        # One return value per forward() input; the radius gets no gradient.
        return fmap1_grad, fmap2_grad, coords_grad, None
|
||||
|
||||
|
||||
class AlternateCorrBlock:
    """Correlation lookup that uses the compiled ``alt_cuda_corr`` extension.

    A pyramid of average-pooled ``(fmap1, fmap2)`` pairs is built once; each
    call queries the extension once per pyramid level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Build the feature pyramid.

        Args:
            fmap1: first feature map; channels are read from dimension 1.
            fmap2: second feature map.
            num_levels: number of pyramid levels queried in ``__call__``.
            radius: lookup radius forwarded to the extension.
        """
        self.num_levels = num_levels
        self.radius = radius

        # Level 0 is the unpooled pair; each further level halves the
        # resolution. num_levels + 1 pairs are stored (as in the original),
        # although __call__ only reads the first num_levels of them.
        self.pyramid = [(fmap1, fmap2)]
        for _ in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Look up correlation features at ``coords`` on every pyramid level.

        Args:
            coords: coordinate tensor with the two coordinate channels in
                dimension 1 (it is permuted to channel-last below).

        Returns:
            Tensor reshaped to ``(B, -1, H, W)``, scaled by ``1 / sqrt(dim)``
            where ``dim`` is the channel count of ``fmap1``.
        """
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]
        r = self.radius

        # The full-resolution fmap1 is used at every level, so permute it
        # once. It is made contiguous because permute() only returns a view.
        fmap1 = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()

        corr_list = []
        for i in range(self.num_levels):
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            # Coordinates are rescaled to the resolution of level i.
            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()

            # `alt_cuda_corr` is a module and cannot be called directly; use
            # its forward entry point, which returns a one-element sequence.
            # NOTE(review): calling the kernel directly bypasses autograd --
            # confirm this path is inference-only.
            corr, = alt_cuda_corr.forward(fmap1, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        # sqrt(dim) equals the former hard-coded 16.0 when dim == 256 and
        # matches the normalisation used by CorrBlock.corr for other widths.
        return corr / (float(dim) ** 0.5)
|
||||
|
10
core/raft.py
10
core/raft.py
@@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||
|
||||
from update import BasicUpdateBlock, SmallUpdateBlock
|
||||
from extractor import BasicEncoder, SmallEncoder
|
||||
from corr import CorrBlock
|
||||
from corr import CorrBlock, AlternateCorrBlock
|
||||
from utils.utils import bilinear_sampler, coords_grid, upflow8
|
||||
|
||||
try:
|
||||
@@ -41,6 +41,9 @@ class RAFT(nn.Module):
|
||||
if 'dropout' not in args._get_kwargs():
|
||||
args.dropout = 0
|
||||
|
||||
if 'alternate_corr' not in args._get_kwargs():
|
||||
args.alternate_corr = False
|
||||
|
||||
# feature network, context network, and update block
|
||||
if args.small:
|
||||
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
|
||||
@@ -99,7 +102,10 @@ class RAFT(nn.Module):
|
||||
|
||||
fmap1 = fmap1.float()
|
||||
fmap2 = fmap2.float()
|
||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
if self.args.alternate_corr:
|
||||
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
else:
|
||||
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
||||
|
||||
# run the context network
|
||||
with autocast(enabled=self.args.mixed_precision):
|
||||
|
@@ -4,8 +4,10 @@ import math
|
||||
from PIL import Image
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
cv2.setNumThreads(0)
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import ColorJitter
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
@@ -1,3 +1,6 @@
|
||||
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
||||
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2018 Tom Runia
|
||||
@@ -12,21 +15,20 @@
|
||||
# Author: Tom Runia
|
||||
# Date Created: 2018-08-03
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_colorwheel():
|
||||
'''
|
||||
"""
|
||||
Generates a color wheel for optical flow visualization as presented in:
|
||||
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
||||
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
||||
According to the C++ source code of Daniel Scharstein
|
||||
According to the Matlab source code of Deqing Sun
|
||||
'''
|
||||
|
||||
Code follows the original C++ source code of Daniel Scharstein.
|
||||
Code follows the Matlab source code of Deqing Sun.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Color wheel
|
||||
"""
|
||||
|
||||
RY = 15
|
||||
YG = 6
|
||||
@@ -65,211 +67,66 @@ def make_colorwheel():
|
||||
return colorwheel
|
||||
|
||||
|
||||
def flow_compute_color(u, v, convert_to_bgr=False):
|
||||
'''
|
||||
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
||||
"""
|
||||
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
||||
|
||||
According to the C++ source code of Daniel Scharstein
|
||||
According to the Matlab source code of Deqing Sun
|
||||
:param u: np.ndarray, input horizontal flow
|
||||
:param v: np.ndarray, input vertical flow
|
||||
:param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
|
||||
:return:
|
||||
'''
|
||||
|
||||
Args:
|
||||
u (np.ndarray): Input horizontal flow of shape [H,W]
|
||||
v (np.ndarray): Input vertical flow of shape [H,W]
|
||||
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Flow visualization image of shape [H,W,3]
|
||||
"""
|
||||
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
||||
|
||||
colorwheel = make_colorwheel() # shape [55x3]
|
||||
ncols = colorwheel.shape[0]
|
||||
|
||||
rad = np.sqrt(np.square(u) + np.square(v))
|
||||
a = np.arctan2(-v, -u)/np.pi
|
||||
|
||||
fk = (a+1) / 2*(ncols-1) + 1
|
||||
fk = (a+1) / 2*(ncols-1)
|
||||
k0 = np.floor(fk).astype(np.int32)
|
||||
k1 = k0 + 1
|
||||
k1[k1 == ncols] = 1
|
||||
k1[k1 == ncols] = 0
|
||||
f = fk - k0
|
||||
|
||||
for i in range(colorwheel.shape[1]):
|
||||
|
||||
tmp = colorwheel[:,i]
|
||||
col0 = tmp[k0] / 255.0
|
||||
col1 = tmp[k1] / 255.0
|
||||
col = (1-f)*col0 + f*col1
|
||||
|
||||
idx = (rad <= 1)
|
||||
col[idx] = 1 - rad[idx] * (1-col[idx])
|
||||
col[~idx] = col[~idx] * 0.75 # out of range?
|
||||
|
||||
col[~idx] = col[~idx] * 0.75 # out of range
|
||||
# Note the 2-i => BGR instead of RGB
|
||||
ch_idx = 2-i if convert_to_bgr else i
|
||||
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
||||
|
||||
return flow_image
|
||||
|
||||
|
||||
def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
|
||||
'''
|
||||
Expects a two dimensional flow image of shape [H,W,2]
|
||||
According to the C++ source code of Daniel Scharstein
|
||||
According to the Matlab source code of Deqing Sun
|
||||
:param flow_uv: np.ndarray of shape [H,W,2]
|
||||
:param clip_flow: float, maximum clipping value for flow
|
||||
:return:
|
||||
'''
|
||||
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
||||
"""
|
||||
Expects a two dimensional flow image of shape [H,W,2].
|
||||
|
||||
Args:
|
||||
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
||||
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
||||
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Flow visualization image of shape [H,W,3]
|
||||
"""
|
||||
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
||||
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
||||
|
||||
if clip_flow is not None:
|
||||
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
||||
|
||||
u = flow_uv[:,:,0]
|
||||
v = flow_uv[:,:,1]
|
||||
|
||||
rad = np.sqrt(np.square(u) + np.square(v))
|
||||
rad_max = np.max(rad)
|
||||
|
||||
epsilon = 1e-5
|
||||
u = u / (rad_max + epsilon)
|
||||
v = v / (rad_max + epsilon)
|
||||
|
||||
return flow_compute_color(u, v, convert_to_bgr)
|
||||
|
||||
|
||||
|
||||
UNKNOWN_FLOW_THRESH = 1e7
|
||||
SMALLFLOW = 0.0
|
||||
LARGEFLOW = 1e8
|
||||
|
||||
def make_color_wheel():
    """
    Generate the color wheel according to the Middlebury color code
    :return: Color wheel, an array of shape [55, 3] holding values in 0..255
    """
    # One row per hue transition:
    # (segment length, channel held at 255, channel that ramps, ramp rises?)
    transitions = (
        (15, 0, 1, True),    # RY
        (6, 1, 0, False),    # YG
        (4, 1, 2, True),     # GC
        (11, 2, 1, False),   # CB
        (13, 2, 0, True),    # BM
        (6, 0, 2, False),    # MR
    )

    total = sum(length for length, _, _, _ in transitions)
    wheel = np.zeros([total, 3])

    start = 0
    for length, full_ch, ramp_ch, rising in transitions:
        stop = start + length
        ramp = np.floor(255 * np.arange(0, length) / length)
        wheel[start:stop, full_ch] = 255
        wheel[start:stop, ramp_ch] = ramp if rising else 255 - ramp
        start = stop

    return wheel
|
||||
|
||||
|
||||
|
||||
def compute_color(u, v):
    """
    Compute the optical flow color map
    :param u: optical flow horizontal map (2-D array; NaN entries are zeroed in place)
    :param v: optical flow vertical map (2-D array; NaN entries are zeroed in place)
    :return: optical flow in color code, float array of shape [h, w, 3]
    """
    height, width = u.shape
    out = np.zeros([height, width, 3])

    # Pixels with NaN in either component are zeroed here and blanked in the output.
    invalid = np.isnan(u) | np.isnan(v)
    u[invalid] = 0
    v[invalid] = 0

    wheel = make_color_wheel()
    n_colors = np.size(wheel, 0)

    magnitude = np.sqrt(u**2+v**2)
    angle = np.arctan2(-v, -u) / np.pi

    # Fractional, 1-based position on the wheel (hence the "- 1" when indexing).
    pos = (angle+1) / 2 * (n_colors - 1) + 1
    lower = np.floor(pos).astype(int)
    upper = lower + 1
    upper[upper == n_colors+1] = 1  # wrap past the last wheel entry
    frac = pos - lower

    inside = magnitude <= 1
    outside = np.logical_not(inside)
    keep = 1-invalid

    for ch in range(0, np.size(wheel, 1)):
        channel = wheel[:, ch]
        low_col = channel[lower-1] / 255
        high_col = channel[upper-1] / 255
        shade = (1-frac) * low_col + frac * high_col

        # Radius <= 1: blend towards white as the radius shrinks; otherwise dim.
        shade[inside] = 1-magnitude[inside]*(1-shade[inside])
        shade[outside] *= 0.75

        out[:, :, ch] = np.uint8(np.floor(255 * shade*keep))

    return out
|
||||
|
||||
# from https://github.com/gengshan-y/VCN
|
||||
def flow_to_image(flow):
    """
    Convert flow into middlebury color code image
    :param flow: optical flow map, indexed as flow[:, :, 0] (u) and flow[:, :, 1] (v)
    :return: optical flow image in middlebury color, as a uint8 array

    Components whose magnitude exceeds UNKNOWN_FLOW_THRESH are treated as
    unknown and rendered black. The input array is not modified.
    """
    u = flow[:, :, 0]
    v = flow[:, :, 1]

    unknown = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
    # u and v are views into `flow`; np.where builds fresh arrays so the
    # caller's data is not overwritten (the original zeroed it in place).
    u = np.where(unknown, 0, u)
    v = np.where(unknown, 0, v)

    # The original also tracked min/max of u and v, but never used them.
    rad = np.sqrt(u ** 2 + v ** 2)
    maxrad = max(-1, np.max(rad))

    # Normalise by the largest radius; eps guards against division by zero.
    eps = np.finfo(float).eps
    u = u/(maxrad + eps)
    v = v/(maxrad + eps)

    img = compute_color(u, v)

    # Blank all three channels of the unknown pixels.
    idx = np.repeat(unknown[:, :, np.newaxis], 3, axis=2)
    img[idx] = 0

    return np.uint8(img)
|
||||
return flow_uv_to_colors(u, v, convert_to_bgr)
|
@@ -2,7 +2,10 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from os.path import *
|
||||
import re
|
||||
|
||||
import cv2
|
||||
cv2.setNumThreads(0)
|
||||
cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
TAG_CHAR = np.array([202021.25], np.float32)
|
||||
|
||||
|
@@ -6,11 +6,14 @@ from scipy import interpolate
|
||||
|
||||
class InputPadder:
|
||||
""" Pads images such that dimensions are divisible by 8 """
|
||||
def __init__(self, dims):
|
||||
def __init__(self, dims, mode='sintel'):
|
||||
self.ht, self.wd = dims[-2:]
|
||||
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
||||
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
||||
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
||||
if mode == 'sintel':
|
||||
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
||||
else:
|
||||
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
||||
|
||||
def pad(self, *inputs):
|
||||
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
||||
@@ -42,10 +45,10 @@ def forward_interpolate(flow):
|
||||
dy = dy[valid]
|
||||
|
||||
flow_x = interpolate.griddata(
|
||||
(x1, y1), dx, (x0, y0), method='cubic', fill_value=0)
|
||||
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
|
||||
|
||||
flow_y = interpolate.griddata(
|
||||
(x1, y1), dy, (x0, y0), method='cubic', fill_value=0)
|
||||
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
|
||||
|
||||
flow = np.stack([flow_x, flow_y], axis=0)
|
||||
return torch.from_numpy(flow).float()
|
||||
@@ -68,7 +71,6 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
||||
return img
|
||||
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd):
|
||||
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
||||
coords = torch.stack(coords[::-1], dim=0).float()
|
||||
|
Reference in New Issue
Block a user