Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / ATen / cuda / ScanUtils.cuh

#pragma once

#include <ATen/ceil_div.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/cuda/AsmUtils.cuh>
#include <c10/macros/Macros.h>

// Collection of in-kernel scan / prefix sum utilities

namespace at {
namespace cuda {

// Inclusive prefix sum for binary vars using intra-warp voting +
// shared memory.
//
// Template parameters:
//   T                 - type of the scan result and of the smem slots
//   KillWARDependency - when true, finish with a __syncthreads() placed after
//                       the last read of smem in this function
//   BinaryFunction    - callable used to combine per-warp totals
//
// Parameters:
//   smem  - scratch with one T slot per warp of the block, counting a
//           trailing partial warp (slot index is threadIdx.x / C10_WARP_SIZE)
//   in    - this thread's predicate
//   out   - receives this thread's inclusive scan value
//   binop - combines per-warp totals; the intra-warp part is always a
//           population count, so binop is presumably addition-like
//
// This function executes __syncthreads(), so every thread of the block must
// call it. Warps are derived from threadIdx.x only.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
  // Within-warp, we use warp voting: `index` is the number of set predicates
  // among the lanes selected by getLaneMaskLe() (presumably lanes <= this
  // one), `carry` is the number of set predicates in the whole warp.
#if defined (USE_ROCM)
  unsigned long long int vote = WARP_BALLOT(in);
  T index = __popcll(getLaneMaskLe() & vote);
  T carry = __popcll(vote);
#else
  // Hold the ballot in an unsigned 32-bit word (the operand type of __popc)
  // instead of T, so the lane bits are not truncated or converted when T is
  // not a 32-bit-or-wider integer type.
  unsigned int vote = WARP_BALLOT(in);
  T index = __popc(getLaneMaskLe() & vote);
  T carry = __popc(vote);
#endif

  const int warp = threadIdx.x / C10_WARP_SIZE;

  // Round up: lane 0 of a trailing partial warp also writes a slot below, and
  // that slot has to take part in the cross-warp scan as well.
  const int numWarps = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE);

  // Per each warp, write out a value
  if (getLaneId() == 0) {
    smem[warp] = carry;
  }

  // Make every warp's total visible before thread 0 scans them.
  __syncthreads();

  // Sum across warps in one thread. This appears to be faster than a
  // warp shuffle scan for CC 3.0+
  if (threadIdx.x == 0) {
    // Accumulate in T so the running total has the same type as the slots.
    T current = 0;
    for (int i = 0; i < numWarps; ++i) {
      T v = smem[i];
      // smem[i] becomes the inclusive total of warps 0..i.
      smem[i] = binop(smem[i], current);
      current = binop(current, v);
    }
  }

  // Publish the scanned per-warp totals to the whole block.
  __syncthreads();

  // load the carry from the preceding warp
  if (warp >= 1) {
    index = binop(index, smem[warp - 1]);
  }

  *out = index;

  // Optional barrier after the final smem reads above.
  if (KillWARDependency) {
    __syncthreads();
  }
}

// Exclusive counterpart of the binary prefix sum (intra-warp voting +
// shared memory): runs the inclusive scan and then removes each thread's own
// contribution.
//
//   smem  - per-warp scratch handed on to inclusiveBinaryPrefixScan
//   in    - this thread's predicate
//   out   - receives this thread's exclusive scan value
//   carry - receives the content of the last warp's smem slot after the scan
//   binop - forwarded unchanged to inclusiveBinaryPrefixScan
//
// Every thread of the block has to call this: the inclusive scan (and, when
// KillWARDependency is set, this function itself) executes __syncthreads().
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
  // The inclusive scan is asked to skip its trailing barrier because smem is
  // still read below; the barrier, if requested, is issued at the very end.
  inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);

  // Subtract this thread's own predicate to go from inclusive to exclusive.
  const T self = static_cast<T>(in);
  *out -= self;

  // Slot of the block's last warp; the division rounds up so that a trailing
  // partial warp is counted. All threads report this slot as the carry.
  const int lastWarpSlot = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE) - 1;
  *carry = smem[lastWarpSlot];

  // Optional barrier after the smem read above.
  if (KillWARDependency) {
    __syncthreads();
  }
}

}}  // namespace at::cuda