54 lines
1.3 KiB
C++
54 lines
1.3 KiB
C++
|
#include <torch/extension.h>
|
||
|
#include <vector>
|
||
|
|
||
|
// CUDA forward declarations
|
||
|
// Forward pass of the correlation op. Only declared here; the definition is
// expected to be supplied by a separate translation unit at link time.
// Takes two feature maps, a coordinate tensor and an integer radius, and
// returns a list of tensors.
// NOTE(review): tensor shapes/dtypes are not visible from this file — confirm
// against the kernel implementation.
std::vector<torch::Tensor> corr_cuda_forward(torch::Tensor fmap1,
                                             torch::Tensor fmap2,
                                             torch::Tensor coords,
                                             int radius);
|
||
|
|
||
|
// Backward pass of the correlation op. Only declared here; the definition is
// expected to be supplied by a separate translation unit at link time.
// Takes the same inputs as the forward pass plus `corr_grad`, and returns a
// list of tensors.
// NOTE(review): presumably `corr_grad` is the gradient w.r.t. the forward
// output and the result holds input gradients — confirm against the kernel.
std::vector<torch::Tensor> corr_cuda_backward(torch::Tensor fmap1,
                                              torch::Tensor fmap2,
                                              torch::Tensor coords,
                                              torch::Tensor corr_grad,
                                              int radius);
|
||
|
|
||
|
// C++ interface
|
||
|
// Input validation helpers. Each one reports failure through TORCH_CHECK,
// using the stringified argument expression as the start of the message.

// Fails unless `x` is a CUDA tensor. Queries the tensor directly rather than
// going through the deprecated `type()` accessor.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")

// Fails unless `x` is contiguous.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")

// Runs both checks on `x`. Wrapped in do/while(0) so the expansion is a single
// statement: `if (cond) CHECK_INPUT(t); else ...` then guards both checks
// instead of only the first one.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
|
||
|
|
||
|
// Python-facing forward entry point.
// Validates, in argument order, that every tensor input passes CHECK_INPUT
// (CUDA tensor and contiguous), then hands the unchanged arguments to
// corr_cuda_forward and returns whatever it produces.
std::vector<torch::Tensor> corr_forward(torch::Tensor fmap1,
                                        torch::Tensor fmap2,
                                        torch::Tensor coords,
                                        int radius) {
  // The first failing check determines the reported error, so keep this order.
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);

  auto outputs = corr_cuda_forward(fmap1, fmap2, coords, radius);
  return outputs;
}
|
||
|
|
||
|
|
||
|
// Python-facing backward entry point.
// Validates, in argument order, that every tensor input (including
// `corr_grad`) passes CHECK_INPUT (CUDA tensor and contiguous), then hands the
// unchanged arguments to corr_cuda_backward and returns whatever it produces.
std::vector<torch::Tensor> corr_backward(torch::Tensor fmap1,
                                         torch::Tensor fmap2,
                                         torch::Tensor coords,
                                         torch::Tensor corr_grad,
                                         int radius) {
  // The first failing check determines the reported error, so keep this order.
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);
  CHECK_INPUT(corr_grad);

  auto grads = corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
  return grads;
}
|
||
|
|
||
|
|
||
|
// Python module definition. The module name is supplied by the build through
// TORCH_EXTENSION_NAME. Exposes the two checked wrappers as `forward` and
// `backward`, each with a short docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &corr_forward, "CORR forward")
      .def("backward", &corr_backward, "CORR backward");
}
|