Repository URL to install this package:
|
Version:
2.2.0 ▾
|
// Copyright (c) OpenMMLab. All rights reserved
#ifndef MODULATED_DEFORM_CONV_PYTORCH_H
#define MODULATED_DEFORM_CONV_PYTORCH_H
#include <torch/extension.h>
using namespace at;
void modulated_deform_conv_forward(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, const int group,
const int deformable_group, const bool with_bias);
void modulated_deform_conv_backward(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias);
#endif // MODULATED_DEFORM_CONV_PYTORCH_H