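// NOTE: package-index page header captured together with this file; it is not
// part of the original source:
//   "Repository URL to install this package" / "Version: 2.2.0"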
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef VOXELIZATION_CUDA_KERNEL_CUH
#define VOXELIZATION_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
// Reduction mode identifiers with fixed integer values
// (SUM = 0, MEAN = 1, MAX = 2).
// None of the kernels in this header reference reduce_t; presumably it is
// consumed by code that includes this header -- TODO confirm.
typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;
// Computes the integer voxel cell of every point and stores it in `coors`.
//
// For point `index`, the first three values of its row in `points` are read
// as x, y and z. Each one is shifted by the matching coors_*_min, divided by
// the matching voxel_* size and floored. If all three cells lie in
// [0, grid_*), slots 0..2 of the point's row in `coors` receive {z, y, x}
// (note the reversed order). Otherwise slots 0..2 are all set to -1, so a
// rejected row never keeps stale values from the output buffer.
//
// Parameters:
//   points        input rows; row `index` starts at points + index * num_features.
//   coors         output rows; row `index` starts at coors + index * NDim.
//   voxel_x/y/z   cell size along each axis (divisor).
//   coors_*_min   lower bound subtracted from each coordinate.
//   coors_*_max   not used by this kernel.
//   grid_x/y/z    number of cells along each axis (exclusive upper bound).
//   num_points    number of rows to process.
//   num_features  row stride of `points`; at least 3 values are read per row.
//   NDim          row stride of `coors`; slots 0..2 are written per row.
//
// CUDA_1D_KERNEL_LOOP comes from the included helper header and supplies
// `index`; each iteration handles one point.
template <typename T, typename T_int>
__global__ void dynamic_voxelize_kernel(
    const T* points, T_int* coors, const float voxel_x, const float voxel_y,
    const float voxel_z, const float coors_x_min, const float coors_y_min,
    const float coors_z_min, const float coors_x_max, const float coors_y_max,
    const float coors_z_max, const int grid_x, const int grid_y,
    const int grid_z, const int num_points, const int num_features,
    const int NDim) {
  CUDA_1D_KERNEL_LOOP(index, num_points) {
    auto points_offset = points + index * num_features;
    auto coors_offset = coors + index * NDim;

    // The lower bound is tested on the float value as `cell >= 0.0f` rather
    // than on the converted int. A NaN coordinate fails that comparison and is
    // rejected; converting NaN to int first would yield 0 on the device and
    // the point would silently be accepted as cell 0.
    // The `&&` short-circuit guarantees the int conversion is only evaluated
    // for non-NaN, non-negative values.
    const float cell_x = floorf((points_offset[0] - coors_x_min) / voxel_x);
    bool inside = cell_x >= 0.0f && static_cast<int>(cell_x) < grid_x;

    // Later axes are only loaded while the point is still a candidate, which
    // keeps the early-out behaviour of skipping unnecessary reads.
    float cell_y = 0.0f;
    if (inside) {
      cell_y = floorf((points_offset[1] - coors_y_min) / voxel_y);
      inside = cell_y >= 0.0f && static_cast<int>(cell_y) < grid_y;
    }

    float cell_z = 0.0f;
    if (inside) {
      cell_z = floorf((points_offset[2] - coors_z_min) / voxel_z);
      inside = cell_z >= 0.0f && static_cast<int>(cell_z) < grid_z;
    }

    if (inside) {
      // Stored in z, y, x order.
      coors_offset[0] = static_cast<int>(cell_z);
      coors_offset[1] = static_cast<int>(cell_y);
      coors_offset[2] = static_cast<int>(cell_x);
    } else {
      // Invalidate the full triple regardless of which axis failed.
      coors_offset[0] = -1;
      coors_offset[1] = -1;
      coors_offset[2] = -1;
    }
  }
}
// Scatters point features into the `voxels` buffer, one scalar per loop
// iteration.
//
// Iteration `thread_idx` handles feature (thread_idx % num_features) of point
// (thread_idx / num_features). The value points[thread_idx] is copied to
//   voxels[(voxel * max_points + slot) * num_features + feature]
// where `slot` = point_to_voxelidx[point] and `voxel` =
// coor_to_voxelidx[point]. Points whose slot or voxel index is negative are
// skipped.
//
// Parameters:
//   nthreads           number of scalars to visit (loop bound).
//   points             source feature values, indexed directly by thread_idx.
//   point_to_voxelidx  per-point slot inside its voxel; negative = skip.
//   coor_to_voxelidx   per-point voxel index; negative = skip.
//   voxels             destination buffer.
//   max_points         slots per voxel (stride between voxels, in rows).
//   num_features       scalars per point.
//   num_points, NDim   not used by this kernel.
//
// CUDA_1D_KERNEL_LOOP comes from the included helper header.
template <typename T, typename T_int>
__global__ void assign_point_to_voxel(const int nthreads, const T* points,
                                      T_int* point_to_voxelidx,
                                      T_int* coor_to_voxelidx, T* voxels,
                                      const int max_points,
                                      const int num_features,
                                      const int num_points, const int NDim) {
  CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    const int point_id = thread_idx / num_features;
    const int slot_in_voxel = point_to_voxelidx[point_id];
    const int voxel_id = coor_to_voxelidx[point_id];

    // Nothing to copy for points without a slot or without a voxel.
    if (slot_in_voxel < 0 || voxel_id < 0) continue;

    const int feature_id = thread_idx % num_features;

    // Start of the voxel, then advance to this point's row inside it.
    T* dst_row = voxels + voxel_id * max_points * num_features;
    dst_row += slot_in_voxel * num_features;

    dst_row[feature_id] = points[thread_idx];
  }
}
// Writes the coordinates of each voxel into `voxel_coors`, one coordinate
// component per loop iteration.
//
// Iteration `thread_idx` handles component (thread_idx % NDim) of point
// (thread_idx / NDim). Only a point whose slot (point_to_voxelidx) is exactly
// 0 and whose voxel index (coor_to_voxelidx) is non-negative writes; it copies
// coor[thread_idx] to voxel_coors[voxel * NDim + component].
//
// Parameters:
//   nthreads           number of coordinate components to visit (loop bound).
//   coor               per-point coordinates, indexed directly by thread_idx.
//   point_to_voxelidx  per-point slot inside its voxel.
//   coor_to_voxelidx   per-point voxel index; negative = skip.
//   voxel_coors        destination, NDim components per voxel.
//   num_points         not used by this kernel.
//   NDim               components per coordinate row.
//
// CUDA_1D_KERNEL_LOOP comes from the included helper header.
template <typename T, typename T_int>
__global__ void assign_voxel_coors(const int nthreads, T_int* coor,
                                   T_int* point_to_voxelidx,
                                   T_int* coor_to_voxelidx, T_int* voxel_coors,
                                   const int num_points, const int NDim) {
  CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    const int point_id = thread_idx / NDim;
    const int slot_in_voxel = point_to_voxelidx[point_id];
    const int voxel_id = coor_to_voxelidx[point_id];

    // Only the slot-0 point of an assigned voxel publishes its coordinates.
    if (slot_in_voxel != 0 || voxel_id < 0) continue;

    const int component = thread_idx % NDim;
    T_int* dst_row = voxel_coors + voxel_id * NDim;
    dst_row[component] = coor[thread_idx];
  }
}
// For every valid point, counts how many earlier points share its first three
// coordinate components and records the result.
//
// For point `index` (skipped when coor[index * NDim] == -1):
//   * `num` = number of points i < index with the same components 0..2,
//     counted until it reaches max_points;
//   * point_to_pointidx[index] = index of the first such earlier point, or
//     `index` itself when there is none;
//   * point_to_voxelidx[index] = num, but only when num < max_points.
//     Otherwise the entry is left untouched -- presumably the caller pre-fills
//     it with -1; verify against the launch site.
//
// Parameters:
//   coor               per-point coordinates, NDim values per row; only
//                      components 0..2 are compared.
//   point_to_voxelidx  output: position of the point among equal-coordinate
//                      points.
//   point_to_pointidx  output: index of the first point with the same
//                      coordinates.
//   max_points         cap on the counted duplicates.
//   max_voxels         not used by this kernel.
//   num_points         number of rows in `coor` (loop bound).
//   NDim               row stride of `coor`.
//
// The inner scan visits all earlier points, so the work per point grows with
// its index. CUDA_1D_KERNEL_LOOP comes from the included helper header.
template <typename T_int>
__global__ void point_to_voxelidx_kernel(const T_int* coor,
                                         T_int* point_to_voxelidx,
                                         T_int* point_to_pointidx,
                                         const int max_points,
                                         const int max_voxels,
                                         const int num_points, const int NDim) {
  CUDA_1D_KERNEL_LOOP(index, num_points) {
    auto coor_offset = coor + index * NDim;
    // skip invalid points
    if (coor_offset[0] == -1) continue;

    const int coor_x = coor_offset[0];
    const int coor_y = coor_offset[1];
    const int coor_z = coor_offset[2];

    int num = 0;
    // only look at the points that come before this one
    for (int i = 0; i < index; ++i) {
      auto prev_coor = coor + i * NDim;
      if (prev_coor[0] == -1) continue;

      const bool same_coor = (prev_coor[0] == coor_x) &&
                             (prev_coor[1] == coor_y) &&
                             (prev_coor[2] == coor_z);
      if (!same_coor) continue;

      ++num;
      if (num == 1) {
        // remember the first earlier point with the same coordinates
        point_to_pointidx[index] = i;
      }
      // Checked independently of the branch above: once the cap is reached the
      // final outputs are already decided, so stop scanning. With the previous
      // `else if`, a cap of 1 could not trigger on the first match and the
      // scan kept going until a second duplicate or the end of the range.
      if (num >= max_points) {
        break;
      }
    }

    if (num == 0) {
      // no earlier duplicate: this point is the first of its coordinates
      point_to_pointidx[index] = index;
    }
    if (num < max_points) {
      point_to_voxelidx[index] = num;
    }
  }
}
// Walks all points in order and assigns voxel indices.
//
// For each point i, based on point_to_voxelidx[i]:
//   * -1: the point is skipped;
//   *  0: the point opens a new voxel. If voxel_num[0] is still below
//         max_voxels, the current voxel_num[0] becomes the point's voxel
//         index, voxel_num[0] is incremented and the voxel's point count is
//         set to 1. Otherwise the point is skipped;
//   * any other value: the point joins the voxel of point
//         point_to_pointidx[i], if that point has a voxel index other than -1,
//         and that voxel's point count is incremented.
//
// Parameters:
//   num_points_per_voxel  output: number of points recorded per voxel.
//   point_to_voxelidx     input: per-point position among equal-coordinate
//                         points.
//   point_to_pointidx     input: per-point index of the referenced point.
//   coor_to_voxelidx      in/out: per-point voxel index; entries are read for
//                         referenced points, so presumably pre-filled with -1
//                         by the caller -- verify against the launch site.
//   voxel_num             in/out: voxel_num[0] is the running voxel counter.
//   max_points            not used by this kernel.
//   max_voxels            upper limit on the number of voxels created.
//   num_points            number of points to walk.
//
// The body does not reference threadIdx or blockIdx, so every thread that
// runs this kernel walks the whole range.
// NOTE(review): looks intended for a single-thread launch -- confirm at the
// call site.
template <typename T_int>
__global__ void determin_voxel_num(
    // const T_int* coor,
    T_int* num_points_per_voxel, T_int* point_to_voxelidx,
    T_int* point_to_pointidx, T_int* coor_to_voxelidx, T_int* voxel_num,
    const int max_points, const int max_voxels, const int num_points) {
  for (int pt = 0; pt < num_points; ++pt) {
    const int slot_in_voxel = point_to_voxelidx[pt];

    // Dropped or invalid point.
    if (slot_in_voxel == -1) continue;

    if (slot_in_voxel == 0) {
      // First point of its coordinates: open a new voxel if there is room.
      if (voxel_num[0] >= max_voxels) continue;
      const int new_voxel = voxel_num[0];
      voxel_num[0] += 1;
      coor_to_voxelidx[pt] = new_voxel;
      num_points_per_voxel[new_voxel] = 1;
      continue;
    }

    // Later point: inherit the voxel of the point it refers to, if any.
    const int first_pt = point_to_pointidx[pt];
    const int owner_voxel = coor_to_voxelidx[first_pt];
    if (owner_voxel == -1) continue;

    coor_to_voxelidx[pt] = owner_voxel;
    num_points_per_voxel[owner_voxel] += 1;
  }
}
// Hands out, via atomics, a position for every point inside its coordinate
// group and an output index for every group.
//
// For element `thread_idx` with group id g = coors_map[thread_idx] (skipped
// when g is negative):
//   * pts_id[thread_idx] receives the value of reduce_count[g] before it is
//     atomically incremented;
//   * the element that observed 0 additionally stores into coors_order[g] the
//     value of *coors_count before it is atomically incremented.
// Which element observes which value depends on execution order.
//
// Parameters:
//   nthreads      number of elements to visit (loop bound).
//   coors_map     per-element group id; negative = skip.
//   pts_id        output: per-element position inside its group.
//   coors_count   in/out: single counter of groups seen so far.
//   reduce_count  in/out: per-group element counter.
//   coors_order   output: per-group output index.
//
// CUDA_1D_KERNEL_LOOP comes from the included helper header.
__global__ void nondeterministic_get_assign_pos(
    const int nthreads, const int32_t* coors_map, int32_t* pts_id,
    int32_t* coors_count, int32_t* reduce_count, int32_t* coors_order) {
  CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    const int group_id = coors_map[thread_idx];
    if (group_id < 0) continue;

    // atomicAdd returns the counter value from before the increment.
    const int32_t pos_in_group = atomicAdd(&reduce_count[group_id], 1);
    pts_id[thread_idx] = pos_in_group;

    // Only the first arrival of a group claims an output index for it.
    if (pos_in_group != 0) continue;
    coors_order[group_id] = atomicAdd(coors_count, 1);
  }
}
// Copies each accepted point into its voxel slot and, for the slot-0 point of
// a voxel, also writes the voxel's point count and coordinates.
//
// Element `thread_idx` is processed when its group id coors_map[thread_idx]
// is non-negative, its position pts_id[thread_idx] is below max_points and
// the group's output index coors_order[group] is below max_voxels. Then:
//   * its num_features values are copied from
//     points[thread_idx * num_features ...] to
//     voxels[(out * max_points + pos) * num_features ...];
//   * if pos == 0, pts_count[out] = min(reduce_count[group], max_points) and
//     the NDim values at coors_in[group * NDim ...] are copied to
//     coors[out * NDim ...].
//
// Parameters:
//   nthreads      number of elements to visit (loop bound).
//   points        source features, num_features per element.
//   coors_map     per-element group id; negative = skip.
//   pts_id        per-element position inside its group.
//   coors_in      per-group coordinates, NDim per group.
//   reduce_count  per-group element count.
//   coors_order   per-group output voxel index.
//   voxels        output features.
//   coors         output coordinates, NDim per voxel.
//   pts_count     output number of points per voxel.
//   max_voxels    exclusive upper bound on the output voxel index.
//   max_points    exclusive upper bound on the position inside a voxel.
//   num_features  values per point.
//   NDim          values per coordinate row.
//
// CUDA_1D_KERNEL_LOOP comes from the included helper header.
template <typename T>
__global__ void nondeterministic_assign_point_voxel(
    const int nthreads, const T* points, const int32_t* coors_map,
    const int32_t* pts_id, const int32_t* coors_in, const int32_t* reduce_count,
    const int32_t* coors_order, T* voxels, int32_t* coors, int32_t* pts_count,
    const int max_voxels, const int max_points, const int num_features,
    const int NDim) {
  CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
    const int group_id = coors_map[thread_idx];
    const int pos_in_group = pts_id[thread_idx];

    // Skip ungrouped elements and those beyond the per-voxel capacity.
    if (group_id < 0 || pos_in_group >= max_points) continue;

    // Skip groups that did not get one of the first max_voxels output indices.
    const int out_voxel = coors_order[group_id];
    if (out_voxel >= max_voxels) continue;

    // Copy this point's feature row into its slot.
    T* dst_feat =
        voxels + (out_voxel * max_points + pos_in_group) * num_features;
    const T* src_feat = points + thread_idx * num_features;
    for (int f = 0; f < num_features; ++f) {
      dst_feat[f] = src_feat[f];
    }

    // Per-voxel metadata is written once, by the slot-0 point.
    if (pos_in_group != 0) continue;

    pts_count[out_voxel] = min(reduce_count[group_id], max_points);

    int32_t* dst_coor = coors + out_voxel * NDim;
    const int32_t* src_coor = coors_in + group_id * NDim;
    for (int d = 0; d < NDim; ++d) {
      dst_coor[d] = src_coor[d];
    }
  }
}
#endif // VOXELIZATION_CUDA_KERNEL_CUH