Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages.

Repository URL to install this package:

Details    
mmcv / ops / csrc / pytorch / npu / group_points_npu.cpp
Size: Mime:
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void group_points_forward_npu(int b, int c, int n, int npoints, int nsample,
                              const Tensor points, const Tensor idx,
                              Tensor out) {
  // b, c, n, and npoints do not need to be passed into gatherv2,
  // b, c, n, and npoints are calculated inside the operator
  // gatherv2 operator in ascend needs to set axis to 0, batch_dims is 0
  c10::SmallVector<int64_t, N> axis = {0};
  int64_t batch_dims = 0;

  auto index = at::arange(0, b);
  index = index.to(points.device());
  index = index.view({-1, 1, 1});
  index = at::mul(index, n);
  at::Tensor indices = at::add(index, idx);
  indices = indices.view({-1});

  at::Tensor trans_features = points.transpose(1, 2);
  at::Tensor features = trans_features.contiguous();
  features = features.view({b * n, c});

  OpCommand cmd;
  cmd.Name("GatherV2")
      .Input(features)
      .Input(indices)
      .Input(axis)
      .Output(out)
      .Attr("batch_dims", batch_dims)
      .Run();

  at::Tensor output =
      out.view({b, npoints, nsample, c}).transpose(1, 3).transpose(2, 3);
  at::Tensor res = output.contiguous();
  out.copy_(res);
}

// Declaration only: group_points_forward_impl is defined outside this file.
// Its parameter list mirrors group_points_forward_npu above.
void group_points_forward_impl(int b, int c, int n, int npoints, int nsample,
                               const Tensor points, const Tensor idx,
                               Tensor out);

// Pairs group_points_forward_impl with its NPU implementation,
// group_points_forward_npu.
// NOTE(review): REGISTER_NPU_IMPL is not defined in this file (presumably it
// comes from pytorch_npu_helper.hpp) — confirm its exact registration
// semantics there.
REGISTER_NPU_IMPL(group_points_forward_impl, group_points_forward_npu);