Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
mmcv / ops / csrc / pytorch / modulated_deform_conv.cpp
Size: Mime:
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#ifdef MMCV_WITH_DIOPI
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_mmcv.h>
#include <torch/csrc/utils/pybind.h>

#include "csrc_dipu/diopirt/diopirt_impl.h"
#include "csrc_dipu/runtime/device/deviceapis.h"
#include "csrc_dipu/utils/helpfunc.hpp"

using dipu::VENDOR_TYPE;
using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
#endif

// Device-dispatching entry point for the modulated deformable im2col step.
//
// Every argument is handed, unchanged and in the same order, to
// DISPATCH_DEVICE_IMPL under the key `modulated_deformable_im2col_impl`.
// NOTE(review): the macro is not defined in this file (presumably it comes
// from pytorch_device_registry.hpp and routes to the implementation
// registered for the tensors' device) -- confirm against that header.
//
// Callers in this file pass one sample at a time (batch_size == 1) and read
// the result from `data_col` afterwards; they allocate that buffer as
// {channels * kernel_h * kernel_w, height_col * width_col}.
void modulated_deformable_im2col_impl(
    const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
    const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int deformable_group, Tensor data_col) {
  DISPATCH_DEVICE_IMPL(modulated_deformable_im2col_impl, data_im, data_offset,
                       data_mask, batch_size, channels, height_im, width_im,
                       height_col, width_col, kernel_h, kernel_w, pad_h, pad_w,
                       stride_h, stride_w, dilation_h, dilation_w,
                       deformable_group, data_col);
}

// Device-dispatching entry point for the modulated deformable col2im step.
//
// Every argument is handed, unchanged and in the same order, to
// DISPATCH_DEVICE_IMPL under the key `modulated_deformable_col2im_impl`.
// NOTE(review): the macro is not defined in this file (presumably it routes
// to the implementation registered for the tensors' device) -- confirm.
//
// The backward pass in this file calls it once per sample (batch_size == 1)
// with the column buffer as `data_col` and the per-sample slice of the input
// gradient as `grad_im`.
void modulated_deformable_col2im_impl(
    const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
    const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int deformable_group, Tensor grad_im) {
  DISPATCH_DEVICE_IMPL(modulated_deformable_col2im_impl, data_col, data_offset,
                       data_mask, batch_size, channels, height_im, width_im,
                       height_col, width_col, kernel_h, kernel_w, pad_h, pad_w,
                       stride_h, stride_w, dilation_h, dilation_w,
                       deformable_group, grad_im);
}

// Device-dispatching entry point for the coordinate part of the modulated
// deformable col2im step.
//
// Every argument is handed, unchanged and in the same order, to
// DISPATCH_DEVICE_IMPL under the key `modulated_deformable_col2im_coord_impl`.
// NOTE(review): the macro is not defined in this file (presumably it routes
// to the implementation registered for the tensors' device) -- confirm.
//
// The backward pass in this file calls it once per sample (batch_size == 1)
// and passes the per-sample slices of the offset and mask gradients as
// `grad_offset` and `grad_mask`.
void modulated_deformable_col2im_coord_impl(
    const Tensor data_col, const Tensor data_im, const Tensor data_offset,
    const Tensor data_mask, const int batch_size, const int channels,
    const int height_im, const int width_im, const int height_col,
    const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group,
    Tensor grad_offset, Tensor grad_mask) {
  DISPATCH_DEVICE_IMPL(modulated_deformable_col2im_coord_impl, data_col,
                       data_im, data_offset, data_mask, batch_size, channels,
                       height_im, width_im, height_col, width_col, kernel_h,
                       kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
                       dilation_w, deformable_group, grad_offset, grad_mask);
}

// Forward pass of modulated deformable convolution built on the
// device-dispatched im2col kernel.
//
// For every sample the input is unfolded into a column buffer by
// modulated_deformable_im2col_impl and multiplied, group by group, with the
// flattened weights; the product is accumulated into `output` in place.
//
// Arguments (tensors are taken by value, i.e. as handles):
//   input            indexed as (batch, channels, height, width).
//   weight           indexed as (channels_out, channels / group, kernel_h,
//                    kernel_w).
//   bias             viewed as {1, bias.size(0), 1, 1}; only read when
//                    `with_bias` is true.
//   ones             not read by the forward computation.
//   offset, mask     indexed per sample and forwarded to the im2col kernel.
//   output           viewed as {batch, channels_out, height_out, width_out},
//                    zeroed, then written in place.
//   columns          the incoming handle is discarded; a fresh zero-filled
//                    scratch buffer is allocated locally.
//   kernel_h/_w      must equal weight.size(2) / weight.size(3).
//   stride, pad, dilation   geometry used to derive height_out / width_out.
//   group            number of convolution groups.
//   deformable_group forwarded to the im2col kernel.
//   with_bias        whether `bias` is added to the result.
//
// Raises through AT_ERROR when the kernel size arguments disagree with
// `weight`, or when input channels != weight.size(1) * group.
void modulated_deform_conv_forward_fallthrough(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w, const int group,
    const int deformable_group, const bool with_bias) {
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  // AT_ERROR concatenates its arguments (it is not printf-style), so the
  // values are passed as separate pieces: requested kernel size first, then
  // the size found in `weight`.
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (", kernel_h, " x ",
             kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (", channels,
             " vs ", channels_kernel * group, ").");

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  // NOTE: `ones` is a by-value handle, so this only rebinds the local
  // variable; the forward computation below never reads it.
  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out}).zero_();
  // resize temporary columns
  columns =
      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
                input.options());

  // Expose the group dimension: (batch, group, channels_out / group, H, W).
  output = output.view({output.size(0), group, output.size(1) / group,
                        output.size(2), output.size(3)});

  for (int b = 0; b < batch; b++) {
    modulated_deformable_im2col_impl(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    // divide into groups
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      // output[b][g] += weight[g] (flattened) x columns[g]
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }

    // Restore the ungrouped shapes expected by the next im2col call.
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  // Merge the group dimension back: (batch, channels_out, H, W).
  output = output.view({output.size(0), output.size(1) * output.size(2),
                        output.size(3), output.size(4)});

  if (with_bias) {
    output += bias.view({1, bias.size(0), 1, 1});
  }
}

// Backward pass of modulated deformable convolution built on the
// device-dispatched im2col / col2im kernels.
//
// For every sample it
//   1. forms the column gradient as weight^T x grad_output (per group),
//   2. hands it to modulated_deformable_col2im_coord_impl to fill
//      grad_offset[b] / grad_mask[b],
//   3. hands it to modulated_deformable_col2im_impl to fill grad_input[b],
//   4. re-runs im2col on the input and accumulates grad_weight (and
//      grad_bias when `with_bias`) across samples and groups.
//
// Arguments (tensors are taken by value, i.e. as handles):
//   input            indexed as (batch, channels, height, width).
//   weight           indexed as (channels_out, channels / group, kernel_h,
//                    kernel_w).
//   bias             not read here.
//   ones             used as a column of ones to sum grad_output over the
//                    spatial positions; rebuilt locally unless it is 2-D with
//                    exactly height_out * width_out elements.
//   offset, mask     indexed per sample and forwarded to the kernels.
//   columns          the incoming handle is discarded; a fresh zero-filled
//                    scratch buffer is allocated locally.
//   grad_input, grad_offset, grad_mask   written per sample by the kernels.
//   grad_weight, grad_bias               accumulated in place via addmm_.
//   grad_output      viewed as (batch, group, channels_out / group, H, W).
//   kernel_h/_w      must equal weight.size(2) / weight.size(3).
//   stride, pad, dilation   geometry used to derive height_out / width_out.
//   group, deformable_group, with_bias   as in the forward pass.
//
// Raises through AT_ERROR when the kernel size arguments disagree with
// `weight`, or when input channels != weight.size(1) * group.
void modulated_deform_conv_backward_fallthrough(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);
  // AT_ERROR concatenates its arguments (it is not printf-style), so the
  // values are passed as separate pieces: requested kernel size first, then
  // the size found in `weight`.
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (", kernel_h, " x ",
             kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (", channels,
             " vs ", channels_kernel * group, ").");

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  // `ones` is flattened to {-1, 1} and multiplied with a
  // {channels_out / group, height_out * width_out} matrix below, so it must
  // hold exactly height_out * width_out elements. Rebuild it on any mismatch
  // (not only when it is too small), otherwise the product is ill-shaped.
  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) != height_out * width_out) {
    ones = at::ones({height_out, width_out}, input.options());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
                      input.options());

  // Expose the group dimension: (batch, group, channels_out / group, H, W).
  grad_output =
      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
                        grad_output.size(2), grad_output.size(3)});

  for (int b = 0; b < batch; b++) {
    // divide into groups
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      // beta = 0 overwrites columns[g] with weight[g]^T x grad_output[b][g].
      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
    }

    // Restore the ungrouped shapes expected by the kernels.
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_impl(
        columns, input[b], offset[b], mask[b], 1, channels, height, width,
        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
        grad_mask[b]);
    // gradient w.r.t. input data
    modulated_deformable_col2im_impl(
        columns, offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, grad_input[b]);

    // gradient w.r.t. weight, dWeight should accumulate across the batch and
    // group
    modulated_deformable_im2col_impl(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
                                    grad_weight.size(1), grad_weight.size(2),
                                    grad_weight.size(3)});
    if (with_bias)
      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});

    for (int g = 0; g < group; g++) {
      // grad_weight[g] += grad_output[b][g] x columns[g]^T
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
              .view_as(grad_weight[g]);
      if (with_bias) {
        // grad_bias[g] += row sums of grad_output[b][g] (product with ones).
        grad_bias[g] =
            grad_bias[g]
                .view({-1, 1})
                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
                .view(-1);
      }
    }

    // Restore the ungrouped shapes for the next sample.
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
                                    grad_weight.size(2), grad_weight.size(3),
                                    grad_weight.size(4)});
    if (with_bias)
      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }
  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
                                  grad_output.size(2), grad_output.size(3),
                                  grad_output.size(4)});
}

#ifdef MMCV_WITH_DIOPI
// DIOPI-backed forward pass.
//
// Control flow:
//   * input reported as diopi_host -> run the fallthrough implementation on
//     the tensors as given;
//   * otherwise, if the address of diopiModulatedDeformConvMmcv is non-null,
//     call it (with the Python GIL released when the vendor string is "NPU")
//     and return on diopiSuccess;
//   * in every remaining case, log a warning, run the fallthrough
//     implementation on CPU copies and copy the CPU result into `output`.
void modulated_deform_conv_forward_diopi(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w, const int group,
    const int deformable_group, const bool with_bias) {
  auto input_handle = toDiopiTensorHandle(input);
  diopiDevice_t input_device;
  diopiGetTensorDevice(input_handle, &input_device);
  if (input_device == diopi_host) {
    // Host tensors never go through DIOPI.
    modulated_deform_conv_forward_fallthrough(
        input, weight, bias, ones, offset, mask, output, columns, kernel_h,
        kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
        group, deformable_group, with_bias);
    return;
  }

  diopiContext context(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t context_handle = &context;
  auto weight_handle = toDiopiTensorHandle(weight);
  auto bias_handle = toDiopiTensorHandle(bias);
  auto ones_handle = toDiopiTensorHandle(ones);
  auto offset_handle = toDiopiTensorHandle(offset);
  auto mask_handle = toDiopiTensorHandle(mask);
  auto output_handle = toDiopiTensorHandle(output);
  auto columns_handle = toDiopiTensorHandle(columns);

  // Single definition of the vendor call; yields true on diopiSuccess.
  const auto run_vendor_kernel = [&]() -> bool {
    auto status = diopiModulatedDeformConvMmcv(
        context_handle, output_handle, columns_handle, ones_handle,
        input_handle, weight_handle, bias_handle, offset_handle, mask_handle,
        kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h,
        dilation_w, group, deformable_group, with_bias);
    return status == diopiSuccess;
  };

  if (reinterpret_cast<void*>(diopiModulatedDeformConvMmcv) != nullptr) {
    bool succeeded = false;
    if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
      // The GIL is released only for the duration of the NPU call.
      pybind11::gil_scoped_release no_gil;
      succeeded = run_vendor_kernel();
    } else {
      succeeded = run_vendor_kernel();
    }
    if (succeeded) return;
  }

  LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward";
  auto host_input = input.cpu();
  auto host_weight = weight.cpu();
  auto host_bias = bias.cpu();
  auto host_ones = ones.cpu();
  auto host_offset = offset.cpu();
  auto host_mask = mask.cpu();
  auto host_output = output.cpu();
  auto host_columns = columns.cpu();
  modulated_deform_conv_forward_fallthrough(
      host_input, host_weight, host_bias, host_ones, host_offset, host_mask,
      host_output, host_columns, kernel_h, kernel_w, stride_h, stride_w, pad_h,
      pad_w, dilation_h, dilation_w, group, deformable_group, with_bias);
  // Bring the CPU result back into the caller's tensor.
  output.copy_(host_output);
}

// DIOPI-backed backward pass.
//
// Control flow:
//   * input reported as diopi_host -> run the fallthrough implementation on
//     the tensors as given;
//   * otherwise, if the address of diopiModulatedDeformConvBackwardMmcv is
//     non-null, call it (with the Python GIL released when the vendor string
//     is "NPU") and return on diopiSuccess;
//   * in every remaining case, log a warning, run the fallthrough
//     implementation on CPU copies and copy the five CPU gradients
//     (input, weight, bias, offset, mask) back into the caller's tensors.
void modulated_deform_conv_backward_diopi(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  auto input_p = toDiopiTensorHandle(input);
  diopiDevice_t device;
  diopiGetTensorDevice(input_p, &device);
  if (device == diopi_host) {
    // Host tensors never go through DIOPI.
    modulated_deform_conv_backward_fallthrough(
        input, weight, bias, ones, offset, mask, columns, grad_input,
        grad_weight, grad_bias, grad_offset, grad_mask, grad_output, kernel_h,
        kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
        group, deformable_group, with_bias);
    return;
  }
  diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t ch = &ctx;
  auto weight_p = toDiopiTensorHandle(weight);
  auto bias_p = toDiopiTensorHandle(bias);
  auto ones_p = toDiopiTensorHandle(ones);
  auto offset_p = toDiopiTensorHandle(offset);
  auto mask_p = toDiopiTensorHandle(mask);
  auto columns_p = toDiopiTensorHandle(columns);
  auto grad_input_p = toDiopiTensorHandle(grad_input);
  auto grad_weight_p = toDiopiTensorHandle(grad_weight);
  auto grad_bias_p = toDiopiTensorHandle(grad_bias);
  auto grad_offset_p = toDiopiTensorHandle(grad_offset);
  auto grad_mask_p = toDiopiTensorHandle(grad_mask);
  auto grad_output_p = toDiopiTensorHandle(grad_output);

  if (reinterpret_cast<void*>(diopiModulatedDeformConvBackwardMmcv) !=
      nullptr) {
    if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
      // The GIL is released only for the duration of the NPU call.
      pybind11::gil_scoped_release no_gil;
      auto ret = diopiModulatedDeformConvBackwardMmcv(
          ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
          grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
          columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w,
          pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
          with_bias);
      if (ret == diopiSuccess) return;
    } else {
      auto ret = diopiModulatedDeformConvBackwardMmcv(
          ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
          grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
          columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w,
          pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
          with_bias);
      if (ret == diopiSuccess) return;
    }
  }
  // Name the backward op here so the log identifies the path that fell back.
  LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_backward";
  auto input_cpu = input.cpu();
  auto weight_cpu = weight.cpu();
  auto bias_cpu = bias.cpu();
  auto ones_cpu = ones.cpu();
  auto offset_cpu = offset.cpu();
  auto mask_cpu = mask.cpu();
  auto columns_cpu = columns.cpu();
  auto grad_input_cpu = grad_input.cpu();
  auto grad_weight_cpu = grad_weight.cpu();
  auto grad_bias_cpu = grad_bias.cpu();
  auto grad_offset_cpu = grad_offset.cpu();
  auto grad_mask_cpu = grad_mask.cpu();
  auto grad_output_cpu = grad_output.cpu();
  modulated_deform_conv_backward_fallthrough(
      input_cpu, weight_cpu, bias_cpu, ones_cpu, offset_cpu, mask_cpu,
      columns_cpu, grad_input_cpu, grad_weight_cpu, grad_bias_cpu,
      grad_offset_cpu, grad_mask_cpu, grad_output_cpu, kernel_h, kernel_w,
      stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
      deformable_group, with_bias);
  // Bring the CPU gradients back into the caller's tensors.
  grad_input.copy_(grad_input_cpu);
  grad_weight.copy_(grad_weight_cpu);
  grad_bias.copy_(grad_bias_cpu);
  grad_offset.copy_(grad_offset_cpu);
  grad_mask.copy_(grad_mask_cpu);
  return;
}
#endif

// Public forward entry point: forwards every argument to the DIOPI variant
// when MMCV_WITH_DIOPI is defined, and to the fallthrough variant otherwise.
void modulated_deform_conv_forward(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w, const int group,
    const int deformable_group, const bool with_bias) {
  // The backend is fixed at compile time; there is a single call site below.
#ifdef MMCV_WITH_DIOPI
  const auto forward_backend = &modulated_deform_conv_forward_diopi;
#else
  const auto forward_backend = &modulated_deform_conv_forward_fallthrough;
#endif
  forward_backend(input, weight, bias, ones, offset, mask, output, columns,
                  kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
                  dilation_h, dilation_w, group, deformable_group, with_bias);
}

// Public backward entry point: forwards every argument to the DIOPI variant
// when MMCV_WITH_DIOPI is defined, and to the fallthrough variant otherwise.
void modulated_deform_conv_backward(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  // The backend is fixed at compile time; there is a single call site below.
#ifdef MMCV_WITH_DIOPI
  const auto backward_backend = &modulated_deform_conv_backward_diopi;
#else
  const auto backward_backend = &modulated_deform_conv_backward_fallthrough;
#endif
  backward_backend(input, weight, bias, ones, offset, mask, columns,
                   grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
                   grad_output, kernel_h, kernel_w, stride_h, stride_w, pad_h,
                   pad_w, dilation_h, dilation_w, group, deformable_group,
                   with_bias);
}