Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ include / ATen / native / cuda / Normalization.cuh

#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/cuda/LaunchUtils.h>
#include <c10/macros/Macros.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

namespace at { namespace native {

// The maximum number of threads in a block.
// Used as the per-block thread budget by getNumThreads, flexible_launch_configs
// and the *_cuda_template launchers in this file (threads.x * threads.y is
// sized to MAX_BLOCK_SIZE). ROCm builds use a smaller cap than CUDA builds.
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif

// Upper bound on a grid dimension. Not referenced in the portion of this file
// shown here; presumably used by launchers elsewhere to clamp grid.y/grid.z
// (65535 is the CUDA limit for those dimensions) -- confirm at use sites.
constexpr unsigned MAX_GRID_SIZE = 65535u;

// Picks a threads-per-row count for a row of `nElem` elements: the smallest
// candidate size that is >= nElem, or MAX_BLOCK_SIZE when nElem exceeds every
// candidate. ROCm builds use a candidate list shifted down by one power of two.
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
  constexpr int kCandidates[] = {16, 32, 64, 128, MAX_BLOCK_SIZE};
#else
  constexpr int kCandidates[] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
#endif
  for (const int candidate : kCandidates) {
    if (candidate >= nElem) {
      return candidate;
    }
  }
  return MAX_BLOCK_SIZE;
}

// Returns the 0-based position (counted from the least significant bit) of the
// highest set bit in `val`, e.g. getMSB(32) == 5. Because __clz(0) is 32, an
// input of 0 yields -1.
__device__ __forceinline__ int getMSB(int val) {
  const int leading_zeros = __clz(val);
  return 31 - leading_zeros;
}

// Pair of accumulators that are reduced together in one pass (GradOp/reduce in
// this file use v1 for sum(grad_output) and v2 for
// sum(grad_output * (input - mean))).
//
// scalar_t:    element type of the tensors being reduced.
// accscalar_t: type in which both running sums are stored and added.
template <typename scalar_t, typename accscalar_t>
struct Float2 {
  accscalar_t v1, v2;

  // Leaves v1 and v2 uninitialized.
  __device__ Float2() {}

  // Takes accscalar_t (not scalar_t). The callers in this file pass values
  // that are already accscalar_t: GradOp passes g * (x - mean), and the
  // SumReduceOp shuffle passes partial sums. With a scalar_t parameter, those
  // values were rounded through scalar_t (Half/BFloat16 when accscalar_t is
  // float), losing precision, and Half could overflow to inf above 65504.
  // Arguments of type scalar_t still convert implicitly, so existing callers
  // compile unchanged.
  __device__ Float2(accscalar_t v1, accscalar_t v2) : v1(v1), v2(v2) {}

  // Broadcasts a single integer (e.g. the identity 0 passed to BlockReduce)
  // into both accumulators.
  __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}

  // Component-wise accumulation.
  __device__ Float2& operator+=(const Float2& a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }

  __device__ friend Float2 operator+(Float2 a, const Float2& b) {
    a += b;
    return a;
  }
};

// Pointwise functor for reduce(): for element (batch, plane, n) it yields the
// pair (grad_output, grad_output * (input - mean)). Summing these gives
// sum(dy) and the dot product (x - mean) . dy in a single pass.
//
// input/grad_output are held by reference, so the referenced accessors must
// outlive this object. In this file they are kernel parameters, which live for
// the duration of the kernel.
template <typename scalar_t, typename accscalar_t, typename PTA>
struct GradOp {
  __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
    : mean(m), input(i), grad_output(g) {}

  __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
    const accscalar_t dy = grad_output[batch][plane][n];
    const accscalar_t x_centered = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
    const accscalar_t dy_times_x = dy * x_centered;
    // Built through Float2's two-argument constructor; that constructor's
    // parameter type decides whether the values are rounded to scalar_t.
    return Float2<scalar_t, accscalar_t>(dy, dy_times_x);
  }

  const accscalar_t mean;
  const PTA& input;
  const PTA& grad_output;
};

// Additive reduction policy handed to cuda_utils::BlockReduce:
//   combine        - merges two partial sums;
//   warp_shfl_down - fetches the partial held `offset` lanes higher via
//                    WARP_SHFL_DOWN.
template <typename acc_t>
struct SumReduceOp {
    __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const {
        const acc_t total = a + b;
        return total;
    }

    __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
        const acc_t shifted = WARP_SHFL_DOWN(data, offset);
        return shifted;
    }
};

// Specialization for Float2 pairs: a single WARP_SHFL_DOWN cannot move a
// struct, so each accumulator is shuffled separately and the pair is rebuilt.
template <typename scalar_t, typename accscalar_t>
struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
    using acc_t = Float2<scalar_t, accscalar_t>;

    __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const {
        const acc_t total = a + b;
        return total;
    }

    __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
        const auto first = WARP_SHFL_DOWN(data.v1, offset);
        const auto second = WARP_SHFL_DOWN(data.v2, offset);
        // Rebuilt through Float2's two-argument constructor.
        return acc_t(first, second);
    }
};

// Sums op(batch, plane, x) over every batch index in [0, tensor.size(0)) and
// every x in [0, tensor.size(2)) for one `plane`, using the whole 2-D block.
// The result is returned to every thread of the block.
//
// Each thread first accumulates its strided share of the (batch, x) grid
// (rows by threadIdx.y/blockDim.y, columns by threadIdx.x/blockDim.x). The
// per-thread partials are then combined with cuda_utils::BlockReduce, which
// does a warp-level reduction and then reduces the per-warp results.
//
// Requirements: every thread of the block must call this (it synchronizes the
// block), and the block may contain at most C10_WARP_SIZE warps, since the
// shared scratch holds C10_WARP_SIZE partials.
template<typename scalar_t, typename Op, typename PTA>
__device__ scalar_t reduce(Op op, PTA tensor, int plane) {
  const auto num_batches = tensor.size(0);
  const auto num_features = tensor.size(2);

  scalar_t partial = static_cast<scalar_t>(0);
  for (int b = threadIdx.y; b < num_batches; b += blockDim.y) {
    for (int f = threadIdx.x; f < num_features; f += blockDim.x) {
      partial += op(b, plane, f);
    }
  }

  __shared__ scalar_t shared[C10_WARP_SIZE];
  SumReduceOp<scalar_t> reduce_op;
  partial = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(partial, reduce_op, 0, shared);

  // Thread (0, 0) holds the block total; publish it so every thread sees it.
  const bool is_leader = (threadIdx.x == 0) && (threadIdx.y == 0);
  if (is_leader) {
    shared[0] = partial;
  }
  __syncthreads();
  return shared[0];
}

constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency (not referenced in the portion of the file shown here)
// Target number of reduction elements per thread; flexible_launch_configs
// divides the reduction length by this to size block.y and grid.y.
constexpr int ELEMENTS_PER_THREAD = 16;
// Upper bound on block.x chosen by flexible_launch_configs before widening.
constexpr int OPTIMAL_TILE_W = 32;
// Upper bound on grid.y chosen by flexible_launch_configs.
constexpr int MAX_H_BLOCK = 128;

// Chooses a 2-D launch configuration for a reduction over `reduction` elements
// (mapped to y) for each of `stride` independent columns (mapped to x).
//
// block.x starts at min(lastPow2(stride), OPTIMAL_TILE_W). block.y is
// lastPow2(ceil(reduction / ELEMENTS_PER_THREAD)), capped so that
// block.x * block.y <= MAX_BLOCK_SIZE. If the block is still smaller than
// MAX_BLOCK_SIZE, block.x is widened (still bounded by lastPow2(stride)).
// grid.x covers all columns; grid.y covers the reduction at
// block.y * ELEMENTS_PER_THREAD elements per block, capped at MAX_H_BLOCK.
//
// coop_flag: when set, a grid.y below 8 is collapsed to 1, because a
// cross-block (grid) reduction is not worth it for short reductions.
//
// Declared `inline`: this is a non-template function defined in a header, so
// without `inline` every translation unit including this file would emit its
// own external definition (ODR violation / duplicate-symbol link error).
inline __host__ void flexible_launch_configs(
      const int reduction,
      const int stride,
      dim3 &block,
      dim3 &grid,
      const bool coop_flag = false) {
  int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
  int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)),
                         MAX_BLOCK_SIZE / block_x);
  // Use any leftover thread budget to widen the block along x.
  if (block_x * block_y != MAX_BLOCK_SIZE) {
    block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
  }

  int grid_x = at::ceil_div(stride, block_x);
  int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  if (coop_flag) {
    // it's not worth having a grid reduction if the reduction dimension is not big enough
    grid_y = grid_y < 8 ? 1 : grid_y;
  }

  block.x = block_x;
  block.y = block_y;
  block.z = 1;
  grid.x = grid_x;
  grid.y = grid_y;
  grid.z = 1;
}

// Merges a second Welford partial (count_new, mean_new, m2n_new) into
// (count, mean, m2n) using the parallel-variance combination:
//   mean' = (mean_new * count_new + mean * count) / (count + count_new)
//   m2n'  = m2n + m2n_new + (mean - mean_new)^2 * count * count_new / (count + count_new)
//   count' = count + count_new
// The divisor is clamped to at least 1, so merging two empty partials does
// not divide by zero.
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_element(C& count,
                                                      T& mean,
                                                      T& m2n,
                                                      const C& count_new,
                                                      const T& mean_new,
                                                      const T& m2n_new) {
  const auto count_total = count + count_new;
  const T inv_total = T(1.0) / ::max(1, count_total);
  const T mean_diff = mean - mean_new;
  mean = (mean_new * count_new + mean * count) * inv_total;
  m2n += m2n_new + mean_diff * mean_diff * count_new * count * inv_total;
  count += count_new;
}

// merge mean/m2n among threadIdx.y within block
//
// Tree-merges the Welford triplets (count, mean, m2n) held by the threads of
// each threadIdx.x column, halving the number of participating rows at each
// step. After the loop, the thread with threadIdx.y == 0 in each column holds
// the merged values; other rows hold partial results.
//
// shmem_count / shmem_mean / shmem_m2n: scratch buffers indexed by
//   threadIdx.x + threadIdx.y * blockDim.x, so each must hold at least
//   blockDim.x * blockDim.y elements.
//
// Must be called by every thread of the block: the loop contains
// __syncthreads(), and its trip count depends only on blockDim.y, so all
// threads hit the same barriers.
// NOTE(review): this appears to assume blockDim.y is a power of two
// (flexible_launch_configs produces one via lastPow2). Otherwise some rows'
// contributions are never merged (e.g. with blockDim.y == 3, row 2 is
// dropped). Confirm at call sites.
// There is no barrier after the final merge, so callers must __syncthreads()
// before reusing the shared buffers.
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
                                                             T& mean,
                                                             T& m2n,
                                                             C* shmem_count,
                                                             T* shmem_mean,
                                                             T* shmem_m2n) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;

  // The trip count is not a compile-time constant, so this pragma likely has
  // little or no effect.
#pragma unroll
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    // Rows that are still active publish their current partial.
    if (threadIdx.y < offset*2) {
      shmem_mean[address_base] = mean;
      shmem_m2n[address_base] = m2n;
      shmem_count[address_base] = count;
    }
    __syncthreads();
    // The lower half of the active rows absorbs the partial `offset` rows below.
    // Reads here touch rows [offset, 2*offset), while the next iteration only
    // writes rows [0, offset), so no extra barrier is needed between them.
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;
      // read shared memory back to register for reduction
      auto count_new = shmem_count[address];
      auto mean_new = shmem_mean[address];
      auto m2n_new = shmem_m2n[address];

      welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new);
    }
  }
}

// Applies y = gamma * (x - mean) * invstd + beta for one plane per
// blockIdx.x. Batches are spread over threadIdx.y/blockIdx.y, with a stride of
// blockDim.y * gridDim.y; features are spread over threadIdx.x, with a stride
// of blockDim.x.
//
// train == true : var_or_invstd already holds invstd (stat_accscalar_t).
// train == false: var_or_invstd holds the variance (stat_scalar_t), and invstd
//                 is computed as 1 / sqrt(var + epsilon).
// An empty weight/bias accessor (size 0) means gamma = 1 / beta = 0.
// Blocks whose blockIdx.x is past the last plane exit immediately.
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
__global__ void batch_norm_transform_input_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
    stat_accscalar_t epsilon) {

  const index_t plane = blockIdx.x;
  if (plane >= input.size(1)) {
    return;
  }

  stat_accscalar_t gamma = static_cast<stat_accscalar_t>(1);
  if (weight.size(0) > 0) {
    gamma = static_cast<stat_accscalar_t>(weight[plane]);
  }
  stat_accscalar_t beta = static_cast<stat_accscalar_t>(0);
  if (bias.size(0) > 0) {
    beta = static_cast<stat_accscalar_t>(bias[plane]);
  }
  const stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);

  stat_accscalar_t invstd;
  if (!train) {
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(var_or_invstd[plane]) + epsilon);
  } else {
    invstd = var_or_invstd[plane];
  }

  const index_t num_batches = input.size(0);
  const index_t num_features = input.size(2);
  const index_t batch_step = blockDim.y * gridDim.y;

  for (index_t b = threadIdx.y + blockIdx.y * blockDim.y; b < num_batches; b += batch_step) {
    auto out_row = output[b][plane];
    auto in_row = input[b][plane];
    for (index_t f = threadIdx.x; f < num_features; f += blockDim.x) {
      out_row[f] = static_cast<input_scalar_t>(gamma * (in_row[f] - mean) * invstd + beta);
    }
  }
}

// Variance transform that yields 1 / sqrt(var + epsilon). When both var and
// epsilon are exactly zero, it returns 0 instead of dividing by zero.
struct InvStd {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    const bool degenerate = (var == static_cast<T>(0)) && (epsilon == static_cast<T>(0));
    if (degenerate) {
      return static_cast<T>(0);
    }
    return static_cast<T>(1) / device_sqrt(var + epsilon);
  }
};

// Identity variance transform: returns the variance unchanged. epsilon is
// accepted only so Var matches InvStd's call signature.
struct Var {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    (void)epsilon;
    return var;
  }
};

// Computes the per-plane mean and a transformed variance of `input` in a single
// pass, with one block per plane (plane = blockIdx.x).
//
// input:    3-D accessor indexed as input[batch][plane][x]; statistics for a
//           plane are taken over all (batch, x) pairs.
// epsilon:  forwarded to VarTransform (InvStd adds it before the sqrt; Var
//           ignores it).
// momentum: not referenced by this kernel.
// save_mean / save_transformed_var: thread 0 writes index `plane` of each, and
//           only when the accessor's data pointer is non-NULL.
//           save_transformed_var receives VarTransform{}(var_n / N, epsilon),
//           i.e. the transform of the biased variance.
//
// Launch assumptions (inferred from the code; confirm at call sites):
// - gridDim.x equals the number of planes.
// - blockDim.x * blockDim.y is a multiple of C10_WARP_SIZE and at most
//   C10_WARP_SIZE * C10_WARP_SIZE, so at most C10_WARP_SIZE per-warp partials
//   must fit in shared memory, and no partial warp is dropped by the integer
//   division in the second stage.
template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_collect_statistics_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    const stat_accscalar_t epsilon,
    const stat_accscalar_t momentum,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {

  // Shared scratch layout:
  // - shared_n[0, C10_WARP_SIZE) holds one element count per warp.
  // - The remaining 2 * 2 * C10_WARP_SIZE ints are reinterpreted below as
  //   stat_accscalar_t (avg, var_n) pairs, one pair per warp. That is enough
  //   room for 2 * C10_WARP_SIZE values even when stat_accscalar_t is 8 bytes.
  __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE];

  int plane = blockIdx.x;
  // Number of elements reduced for this plane.
  // NOTE(review): computed in `int`, so it would overflow if
  // size(0) * size(2) exceeded INT_MAX (index_t may be 64-bit). Confirm callers
  // never reach that.
  int N = input.size(0) * input.size(2);
  // Linear thread index within the 2-D block.
  int tid = threadIdx.x + threadIdx.y * blockDim.x;

  // Compute the mean and variance across (batch, x/y/z)
  // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block)
  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
  // and the parallel algorithm on the same page.
  // We use two shuffles to reduce across the entire block.
  // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
  stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE];

  // first the reductions each thread does separately
  // (Welford's online update of count n, running mean avg and sum of squared
  // deviations var_n over this thread's strided share of (batch, x)).
  stat_accscalar_t avg = 0;
  stat_accscalar_t var_n = 0;
  int n = 0;
  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
      stat_accscalar_t v = input[batch][plane][x];
      stat_accscalar_t d1 = v - avg;
      n++;
      avg += d1 / n;
      var_n += d1 * (v - avg);
    }
  }

  // first warpSum to get one value per thread to
  // one value per warp.
  // Butterfly (XOR) shuffles with lane offsets 1, 2, ..., C10_WARP_SIZE / 2
  // merge (n, avg, var_n) with the parallel-variance formula. Afterwards every
  // lane of a warp holds the warp's combined values. The max(1, n + o_n)
  // guards lanes whose combined count is zero.
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // this writes each warp's item into shared memory
  // there are at most C10_WARP_SIZE items left because
  // there are at most C10_WARP_SIZE**2 threads at the beginning
  // (The barrier before the writes appears defensive: no earlier shared-memory
  // use is visible in this kernel.)
  __syncthreads();
  if (tid % C10_WARP_SIZE == 0) {
    shared_n[tid / C10_WARP_SIZE] = n;
    shared_avg_var[tid / C10_WARP_SIZE * 2] = avg;
    shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n;
  }
  __syncthreads();
  // now have a second warpSum to reduce the intermediate values
  // from shared memory to a single number. The very first
  // thread writes it to shared memory.
  // Only warp 0 (tid < C10_WARP_SIZE) loads the per-warp partials; lanes
  // beyond the number of warps load an empty partial (n = 0).

  if (tid < C10_WARP_SIZE) {
    n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0);
    avg = (tid < blockDim.x * blockDim.y  / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
    var_n = (tid < blockDim.x * blockDim.y  / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
  }
  // Executed by every warp, but only warp 0's result is used: the other warps
  // shuffle their own leftover values, and only tid == 0 stores below.
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // Save the mean, variance, and moving averages
  // (only the mean and the transformed biased variance are stored here).
  if (tid == 0) {
    if (save_mean.data() != NULL) {
      save_mean[plane] = avg;
    }
    if (save_transformed_var.data() != NULL) {
      save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
    }
  }

}

// Single-kernel batch-norm backward, with one plane (channel) per block
// (plane = blockIdx.x).
//
// Threads stride over the batch dimension with threadIdx.y / blockDim.y and
// over the flattened feature dimension with threadIdx.x / blockDim.x (the
// launcher in this file uses a 2-D block of MAX_BLOCK_SIZE threads). Every
// thread must reach reduce(), which synchronizes the block.
//
// train == true : uses save_mean / save_invstd, and grad_input includes the
//                 terms that flow through the batch mean and variance.
// train == false: uses running_mean / running_var (invstd computed with
//                 epsilon), and grad_input = grad_output * invstd * weight.
// An empty weight accessor means weight = 1. grad_input is written only when
// its data pointer is non-NULL; grad_weight / grad_bias only when they have a
// non-zero size.
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_mean,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_var,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd,
    bool train,
    stat_accscalar_t epsilon) {

  index_t plane = blockIdx.x;
  // Number of elements per plane.
  index_t N = grad_output.size(0) * grad_output.size(2);

  stat_accscalar_t mean, invstd;
  if (train) {
    mean = save_mean[plane];
    invstd = save_invstd[plane];
  } else {
    mean = static_cast<stat_accscalar_t>(running_mean[plane]);
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(running_var[plane]) + epsilon);
  }

  stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
  stat_accscalar_t norm = stat_accscalar_t(1) / N;

  // Compute two values across (batch, x/y/z) in one pass:
  // 1. Sum(grad_output)
  // 2. DotProduct(input - mean, grad_output)
  // reduce() returns the block-wide totals to every thread.
  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);

  stat_accscalar_t grad_output_sum = res.v1;
  stat_accscalar_t dot_p = res.v2;

  stat_accscalar_t grad_mean = grad_output_sum * norm;
  stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd;
  stat_accscalar_t grad_scale = invstd * weight_val;

  if (grad_input.data() != NULL) {
    for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
      for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
        input_scalar_t go = grad_output[batch][plane][x];
        if (train) {
          stat_accscalar_t inp = input[batch][plane][x];
          stat_accscalar_t proj = (inp - mean) * proj_scale;
          grad_input[batch][plane][x] = static_cast<input_scalar_t>((go - proj - grad_mean) * grad_scale);
        } else {
          grad_input[batch][plane][x] = static_cast<input_scalar_t>(go * grad_scale);
        }
      }
    }
  }

  // Per-plane parameter gradients. The totals are identical in every thread,
  // so a single thread publishes them. Previously one thread per threadIdx.y
  // row (all with threadIdx.x == 0) stored the same value to the same address.
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    if (grad_weight.size(0) > 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(dot_p * invstd);
    }
    if (grad_bias.size(0) > 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(grad_output_sum);
    }
  }
}

// Combines per-rank statistics (e.g. from a distributed batch norm) into global
// per-feature statistics and updates running stats.
//
// vec_mean / vec_invstd are indexed [rank][feature]; counts[rank] is the
// number of elements each rank reduced. For each feature, every rank's invstd
// is turned back into a count-weighted sum of squared deviations, and the
// ranks are merged with the parallel-variance formula. The kernel then writes
// mean, invstd = 1 / sqrt(var_n / n + epsilon) and, when present, the momentum
// updates of running_mean and running_var (using the unbiased variance
// var_n / (n - 1)).
//
// Features are covered by a 1-D grid-stride loop over
// blockIdx.x * blockDim.x + threadIdx.x.
// NOTE(review): a zero total count before a rank with count 0 makes the
// 1 / (n + count) factor infinite; callers presumably filter out empty ranks.
// Confirm.
template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void batch_norm_reduce_statistics_kernel(
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
    const accscalar_t epsilon,
    const accscalar_t momentum,
    const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> counts) {

  const int num_features = vec_mean.size(1);
  const int num_ranks = vec_mean.size(0);

  for (int f = blockIdx.x * blockDim.x + threadIdx.x; f < num_features; f += gridDim.x * blockDim.x) {
    accscalar_t avg = 0;
    accscalar_t var_n = 0;
    index_t n = 0;

    for (int rank = 0; rank < num_ranks; ++rank) {
      const scalar_t count = counts[rank];
      const accscalar_t rank_mean = vec_mean[rank][f];
      // invstd -> std -> (variance without epsilon) * count
      accscalar_t rank_var = accscalar_t(1.0) / (vec_invstd[rank][f]);
      rank_var = (rank_var * rank_var - epsilon) * count;
      const accscalar_t factor = 1.0 / (n + count);
      var_n += rank_var + (avg - rank_mean) * (avg - rank_mean) * n * count * factor;
      avg = n * factor * avg + count * factor * rank_mean;
      n += count;
    }

    mean[f] = avg;
    invstd[f] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);

    if (running_mean.data() != NULL) {
      running_mean[f] = static_cast<scalar_t>((1 - momentum) * running_mean[f] + momentum * avg);
    }
    if (running_var.data() != NULL) {
      const accscalar_t unbiased_var = var_n / (n - 1);
      running_var[f] = static_cast<scalar_t>((1 - momentum) * running_var[f] + momentum * unbiased_var);
    }
  }

}

// Per-plane reduction step of a split batch-norm backward, with one plane per
// block (plane = blockIdx.x).
//
// Reduces sum(grad_output) and sum(grad_output * (input - mean)) over the
// plane, using mean[plane] and invstd[plane]. It then writes, for each output
// whose accessor has a non-zero size:
//   grad_weight[plane] = sum_dy_xmu * invstd
//   grad_bias[plane]   = sum_dy
//   sum_dy[plane]      = sum_dy
//   sum_dy_xmu[plane]  = sum_dy_xmu
// Every thread of the block must execute this kernel body, because reduce()
// synchronizes the block.
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_reduce_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {

  index_t plane = blockIdx.x;

  stat_accscalar_t r_mean = mean[plane];
  stat_accscalar_t factor = invstd[plane];

  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);

  // reduce() hands the same totals to every thread, so one thread publishes
  // them. Previously each threadIdx.y row (threadIdx.x == 0) redundantly
  // stored identical values to the same addresses.
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    if (grad_weight.size(0) > 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(res.v2 * factor);
    }
    if (grad_bias.size(0) > 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(res.v1);
    }
    if (sum_dy.size(0) > 0) {
      sum_dy[plane] = static_cast<stat_accscalar_t>(res.v1);
    }
    if (sum_dy_xmu.size(0) > 0) {
      sum_dy_xmu[plane] = static_cast<stat_accscalar_t>(res.v2);
    }
  }
}

// Element-wise step of a split batch-norm backward for plane = blockIdx.x:
//   grad_input = (dy - mean(dy) - (x - mean) * invstd^2 * sum_dy_xmu * norm_fct)
//                * invstd * weight
// where mean(dy) = sum_dy * norm_fct. An empty weight accessor means
// weight = 1. Batches are spread over threadIdx.y/blockIdx.y, with a stride of
// blockDim.y * gridDim.y; features over threadIdx.x, with a stride of
// blockDim.x. Blocks whose blockIdx.x is past the last plane return
// immediately.
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const stat_accscalar_t norm_fct) {
  const index_t plane = blockIdx.x;
  if (plane >= input.size(1)) {
    return;
  }

  const stat_accscalar_t mean_c = mean[plane];
  const stat_accscalar_t mean_dy = sum_dy[plane] * norm_fct;
  const stat_accscalar_t invstd_c = invstd[plane];
  const stat_accscalar_t weight_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
  const stat_accscalar_t out_scale = weight_c * invstd_c;
  const stat_accscalar_t proj_coef = invstd_c * invstd_c * sum_dy_xmu[plane] * norm_fct;

  const index_t num_batches = input.size(0);
  const index_t num_features = input.size(2);
  const index_t batch_step = blockDim.y * gridDim.y;

  for (index_t b = threadIdx.y + blockIdx.y * blockDim.y; b < num_batches; b += batch_step) {
    auto grad_in_row = grad_input[b][plane];
    auto grad_out_row = grad_output[b][plane];
    auto input_row = input[b][plane];
    for (index_t f = threadIdx.x; f < num_features; f += blockDim.x) {
      grad_in_row[f] = static_cast<input_scalar_t>((grad_out_row[f] - mean_dy - (input_row[f] - mean_c) * proj_coef) * out_scale);
    }
  }
}

// Element-wise backward overload that derives norm_fct on the device:
// norm_fct = 1 / sum(numel[0 .. world_size)). Every thread recomputes this sum
// from the numel array, then defers to batch_norm_backward_elemt_kernel_impl.
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_elemt_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const int* __restrict__ numel, const int world_size) {
  int64_t element_total = 0;
  int rank = 0;
  while (rank < world_size) {
    element_total += numel[rank];
    ++rank;
  }

  const stat_accscalar_t norm_fct =
      static_cast<stat_accscalar_t>(1) / static_cast<stat_accscalar_t>(element_total);
  batch_norm_backward_elemt_kernel_impl(
      input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
}

// Element-wise backward overload for a host-supplied normalization factor
// (typically 1 / number of reduced elements). It forwards directly to
// batch_norm_backward_elemt_kernel_impl.
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_elemt_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const stat_accscalar_t norm_fct) {
  batch_norm_backward_elemt_kernel_impl(input, grad_output,
                                        mean, invstd, weight,
                                        sum_dy, sum_dy_xmu,
                                        grad_input, norm_fct);
}

// Returns a packed accessor for `t` after checking that its dtype matches
// scalar_t. On mismatch it raises a TORCH_CHECK error that names `var_name`.
template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
    const Tensor& t, c10::string_view var_name) {
  const auto actual_type = t.scalar_type();
  constexpr auto expect_type = c10::CppTypeToScalarType<scalar_t>::value;
  TORCH_CHECK(actual_type == expect_type,
              "Expected ", var_name, " to have type ", expect_type,
              " but got ", actual_type);
  return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
}

// Like get_packed_accessor, but an undefined tensor yields a dummy accessor:
// a null data pointer with all sizes/strides zero. Kernels in this file treat
// `size(0) > 0` or `data() != NULL` as "present".
template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
    const Tensor& t, c10::string_view var_name) {
  if (t.defined()) {
    return get_packed_accessor<scalar_t, dim, PtrTraits, index_t>(t, var_name);
  }
  const std::array<index_t, dim> zeros{{0}};
  return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
}

// Host launcher for batch_norm_backward_kernel.
//
// Input and grad_out are viewed as (N, C, -1). Outputs are allocated according
// to grad_input_mask: [0] grad_input (like input_), [1] grad_weight and
// [2] grad_bias (like weight_). Outputs that are not requested are returned as
// undefined tensors. One block is launched per channel, with a 2-D block of
// getNumThreads(features) x (MAX_BLOCK_SIZE / that) threads, on the current
// CUDA stream.
template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
                                                                     const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
                                                                     bool train, double epsilon, std::array<bool,3> grad_input_mask) {

  using accscalar_t = at::acc_type<stat_scalar_t, true>;

  // Merge all trailing (spatial) dimensions into one.
  const auto input_3d = input_.reshape({input_.size(0), input_.size(1), -1});
  const auto grad_output_3d = grad_out_.reshape(input_3d.sizes());

  Tensor grad_input_;
  Tensor grad_input_3d;
  if (grad_input_mask[0]) {
    grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    grad_input_3d = grad_input_.view(input_3d.sizes());
  }
  Tensor grad_weight_;
  if (grad_input_mask[1]) {
    grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  Tensor grad_bias_;
  if (grad_input_mask[2]) {
    grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  // Accessors; undefined optional tensors become dummy (null, size-0) accessors.
  auto input = get_packed_accessor<input_scalar_t, 3, DefaultPtrTraits, index_t>(input_3d, "input");
  auto grad_output = get_packed_accessor<input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_3d, "grad_output");
  auto grad_input = packed_accessor_or_dummy<input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_3d, "grad_input");
  auto weight = packed_accessor_or_dummy<stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto grad_weight = packed_accessor_or_dummy<stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
  auto grad_bias = packed_accessor_or_dummy<stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
  auto running_mean = packed_accessor_or_dummy<stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean");
  auto running_var = packed_accessor_or_dummy<stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var");
  auto save_mean = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean");
  auto save_invstd = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd");

  auto stream = at::cuda::getCurrentCUDAStream();
  const int threads_x = getNumThreads(input.size(2));
  const dim3 threads(threads_x, std::max<int>(1, MAX_BLOCK_SIZE / threads_x));
  const dim3 blocks(input.size(1));

  batch_norm_backward_kernel<input_scalar_t, stat_scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
    (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var,
     save_mean, save_invstd, train, epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
}

// Host launcher for batch_norm_collect_statistics_kernel.
//
// Resizes out_mean and out_invstd to {C} (C = input_.size(1)). It then writes
// the per-channel mean into out_mean and VarTransform(biased variance, epsilon)
// into out_invstd (e.g. invstd with InvStd, the plain variance with Var).
// input_ is viewed as (N, C, -1). One block is launched per channel, with a
// 2-D block of getNumThreads(features) x (MAX_BLOCK_SIZE / that) threads, on
// the current CUDA stream. Both outputs must end up 1-D, contiguous and
// non-empty (internal asserts).
//
// Change: removed the unused locals `dummy_mean_` / `dummy_var_`.
template<typename scalar_t, typename index_t, typename VarTransform>
void batch_norm_stats_cuda_template(
    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {

  using accscalar_t = at::acc_type<scalar_t, true>;
  int64_t n_input = input_.size(1);
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions

  resize_output(out_mean, {n_input});
  resize_output(out_invstd, {n_input});
  auto input = get_packed_accessor<
      scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
                        out_invstd.sizes()[0]);
  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
                        out_mean.sizes()[0]);

  auto mean = packed_accessor_or_dummy<
      accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean");
  auto invstd = packed_accessor_or_dummy<
      accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  dim3 blocks(input.size(1));
  int tf = getNumThreads(input.size(2));
  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
  // momentum (third argument) is unused by the kernel, so 0.0 is passed.
  batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
    (input, epsilon, 0.0, mean, invstd);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Elementwise batch-norm forward: writes the normalized `input_` into
// `output_` using per-channel mean / invstd and (optional) weight / bias.
//
// input_ is reshaped and output_ is *viewed* as (N, C, L), so output_ must be
// viewable with that shape. Launch: blockIdx.x selects the channel,
// blockIdx.y (capped at MAX_GRID_SIZE) splits the batch dimension.
template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_,
                                    const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  // internally we merge the feature dimensions: (N, C, *) -> (N, C, L)
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
  auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});

  auto input = get_packed_accessor<
      input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
  auto output = get_packed_accessor<
      input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
  auto weight = packed_accessor_or_dummy<
    stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
  auto bias = packed_accessor_or_dummy<
      stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  // NOTE: We use transform_input_kernel in training mode, which ignores epsilon
  const double dummy_epsilon = 1e-5;

  // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  // and good occupancy. Quite likely, we could go with even more blocks; the y-grid is currently capped
  // at (256*1024)/C and at MAX_GRID_SIZE.
  // The various planes are independent, so we use blocks for them.
  int tf = std::max<int>(getNumThreads(input.size(2)/4),
                         std::min<int>(getNumThreads(input.size(2)), 64));
  int tb = std::max<int>(64/tf, 1);
  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                                  (input.size(0)+tb-1)/tb)));
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);
  batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
    (input, output, mean, invstd, weight, bias, dummy_epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Reduces 2-D per-group statistics (`mean_`, `invstd_`; dim 1 is the feature
// dim, features = mean_.size(1)) into one mean / invstd per feature.
// NOTE(review): rows presumably correspond to participating replicas, weighted
// by `counts_` inside the kernel; confirm against the kernel.
//
// Returns freshly allocated 1-D tensors (save_mean, save_invstd) of length
// `features`, in Float when mean_ is Half/BFloat16, otherwise mean_'s dtype.
// running_mean_ / running_var_ are handed to the kernel together with
// `momentum`.
template<typename scalar_t, typename accscalar_t, typename index_t>
std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
                                                                 const Tensor& running_mean_, const Tensor& running_var_,
                                                                 double momentum, double epsilon, const Tensor& counts_) {

  Tensor save_mean_;
  Tensor save_invstd_;

  auto features = mean_.size(1);
  auto input_options = mean_.options();
  // Reduced half-precision statistics are stored in float.
  if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) {
    input_options = input_options.dtype(ScalarType::Float);
  }
  save_mean_ = at::empty({features}, input_options);
  save_invstd_ = at::empty({features}, input_options);

  auto mean = packed_accessor_or_dummy<
      accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd");
  auto running_mean = packed_accessor_or_dummy<
      scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean");
  // Label must match the tensor it wraps so accessor errors name the right input.
  auto running_var = packed_accessor_or_dummy<
      scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_var");
  auto counts = packed_accessor_or_dummy<
      scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts");

  auto save_mean = get_packed_accessor<
      accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean");
  auto save_invstd = get_packed_accessor<
      accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  // grid = max(1, features / block) rounds down.
  // NOTE(review): this relies on batch_norm_reduce_statistics_kernel striding
  // over features beyond gridDim.x * blockDim.x; confirm when editing the kernel.
  int block = getNumThreads(features);
  int grid = std::max<int>(1, features/block);
  batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
      (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(save_mean_, save_invstd_);
}

// Backward reduction pass for batch norm (contiguous layout).
//
// Views input_ / grad_out_ as (N, C, L), launches one block per channel of
// batch_norm_backward_reduce_kernel and returns
// (sum_dy, sum_dy_xmu, grad_weight, grad_bias). Each output is allocated only
// when its flag is set (input_g covers both sum tensors); otherwise it is
// returned as an undefined Tensor.
template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
                                                                                    const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_,
                                                                                    const bool input_g, const bool weight_g, const bool bias_g) {
  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  const int64_t n_input = input_.size(1);

  // Collapse all trailing dims: (N, C, *) -> (N, C, L).
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());

  // Allocate only the outputs that were requested.
  Tensor sum_dy_ = input_g ? at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : Tensor();
  Tensor sum_dy_xmu_ = input_g ? at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : Tensor();
  Tensor grad_weight_ = weight_g ? at::empty({n_input}, weight_.options()) : Tensor();
  Tensor grad_bias_ = bias_g ? at::empty({n_input}, weight_.options()) : Tensor();

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto grad_weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
  auto grad_bias = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  auto sum_dy = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  auto sum_dy_xmu = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");

  const auto batch_size = input_reshaped.size(0);
  const auto feature_size = input_reshaped.size(2);
  auto stream = at::cuda::getCurrentCUDAStream();

  // block.y covers the batch and is capped at MAX_BLOCK_SIZE / warp_size, so
  // block.x (the flattened feature dim) can be at least one warp wide.
  const int warp_size = at::cuda::warp_size();
  const int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE / warp_size);
  const int block_x = std::min<int>(std::max<int>(getNumThreads(feature_size), warp_size),
                                    MAX_BLOCK_SIZE / block_y);
  const dim3 block(block_x, block_y);
  const dim3 grid(n_input);  // one block per channel

  batch_norm_backward_reduce_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<grid, block, 0, stream>>>
    (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_);
}

// Elementwise backward pass for batch norm (contiguous layout), with the
// normalization factor computed on the host.
//
// Views input_ / grad_out_ as (N, C, L), allocates a contiguous grad_input of
// that shape, and launches batch_norm_backward_elemt_kernel with
// norm_fct = 1 / (input_.numel() / C). Returns grad_input viewed back to
// input_.sizes().
template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
                                               const Tensor& mean_, const Tensor& invstd_,
                                               const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) {
  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  const int64_t n_input = input_.size(1);

  // (N, C, *) -> (N, C, L); grad_input is a fresh contiguous buffer.
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
  auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  auto weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto sum_dy = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  auto sum_dy_xmu = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");

  auto stream = at::cuda::getCurrentCUDAStream();

  // The kernel is pointwise, but each thread reads its per-channel parameters
  // (mean/invstd, weight) once and then loops, so the launch shape trades that
  // amortization against occupancy. Quite likely more blocks would be fine.
  // blockIdx.x picks the (independent) channel plane; x-threads cover the
  // flattened feature length; y-threads/blocks cover the batch.
  const auto n_batch = input.size(0);
  const auto n_channels = input.size(1);
  const auto n_feat = input.size(2);
  const int tf = std::max<int>(getNumThreads(n_feat/4),
                               std::min<int>(getNumThreads(n_feat), 64));
  const int tb = std::max<int>(64/tf, 1);
  const int grid_y = std::max<int>(1, std::min<int>((256*1024)/n_channels,
                                                    (n_batch+tb-1)/tb));
  dim3 blocks_trans(n_channels, grid_y);
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);

  // 1 / (number of elements reduced per channel).
  const auto reduction_size = input_.numel() / n_input;
  const auto norm_fct = static_cast<stat_accscalar_t>(1.0 / reduction_size);
  batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t>
      <<<blocks_trans, threads_trans, 0, stream>>>
      (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return grad_input_reshaped.view(input_.sizes());
}

// Elementwise backward pass for batch norm (contiguous layout), where the
// element counts live on the device. `count` is read as int32
// (count.data_ptr<int>()) and passed to the kernel together with count.numel().
// NOTE(review): the kernel presumably derives its normalization factor from
// these counts; confirm against batch_norm_backward_elemt_kernel.
//
// Views input_ / grad_out_ as (N, C, L), allocates a contiguous grad_input and
// returns it viewed back to input_.sizes().
template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
                                               const Tensor& mean_, const Tensor& invstd_,
                                               const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) {

  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  // internally we merge the feature dimensions: (N, C, *) -> (N, C, L)
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
  auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  auto weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto sum_dy = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  auto sum_dy_xmu = packed_accessor_or_dummy<
      stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");

  auto stream = at::cuda::getCurrentCUDAStream();

  // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  // and good occupancy. Quite likely, we could go with even more blocks; the y-grid is currently capped
  // at (256*1024)/C and at MAX_GRID_SIZE.
  // The various planes are independent, so we use blocks for them.
  int tf = std::max<int>(getNumThreads(input.size(2)/4),
                         std::min<int>(getNumThreads(input.size(2)), 64));
  int tb = std::max<int>(64/tf, 1);
  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                                  (input.size(0)+tb-1)/tb)));
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);
  batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
    (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.data_ptr<int>(), count.numel());
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return grad_input_reshaped.view(input_.sizes());
}

// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
// original apex name: welford_kernel_c_last
//
// The input is addressed as a row-major (m, c) matrix: c == `stride` columns
// (the fastest-varying index), m == `reduction_size` rows. Each thread owns
// column c_offset = blockIdx.x * blockDim.x + threadIdx.x. Along m, a thread
// starts at blockIdx.y * blockDim.y + threadIdx.y and steps by
// blockDim.y * gridDim.y.
//
// Phases:
//   1. Each thread keeps PARALLEL_LOADS independent (count, mean, m_2_n)
//      accumulators to overlap loads, then folds them with
//      welford_merge_element.
//   2. welford_merge_block_vertical combines partials across threadIdx.y via
//      the static shared buffers below.
//   3. If gridDim.y > 1, row 0 of every block publishes its partial to
//      `staging_data`. The last block to arrive for a given blockIdx.x,
//      counted in semaphores[blockIdx.x], merges all gridDim.y partials.
//      Otherwise row 0 of the only y-block writes directly.
// Output: out_mean[c] = mean, out_invstd[c] = VarTransform{}(m_2_n / count, epsilon).
//
// Preconditions implied by the code:
//  - No thread may exit early: block-wide barriers follow, so out-of-range
//    threads contribute zeros instead of returning.
//  - When gridDim.y > 1, staging_data holds 2 * stride * gridDim.y accscalar_t
//    values followed by stride * gridDim.y ints, and semaphores[] starts at 0.
//  - The shared buffers hold MAX_BLOCK_SIZE entries. NOTE(review): this
//    presumably requires blockDim.x * blockDim.y <= MAX_BLOCK_SIZE.
template
   <typename VarTransform,
    typename scalar_t,
    typename accscalar_t,
    int PARALLEL_LOADS>
__global__ void
batch_norm_collect_statistics_channels_last_kernel(
      const scalar_t* __restrict__ input,
      accscalar_t* __restrict__ out_mean,
      accscalar_t* __restrict__ out_invstd,
      volatile accscalar_t* staging_data,
      int* semaphores,
      const int reduction_size,
      const int stride,
      accscalar_t epsilon) {
  // hide latency with concurrency
  accscalar_t x_mean[PARALLEL_LOADS];
  accscalar_t m_2_n[PARALLEL_LOADS];
  int count[PARALLEL_LOADS];

#pragma unroll
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    x_mean[i] = accscalar_t(0);
    m_2_n[i] = accscalar_t(0);
    count[i] = 0;  // integer element counter
  }
  // tensor dimension (m,c)

  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_math[PARALLEL_LOADS];
    accscalar_t x_count_inv[PARALLEL_LOADS];
    accscalar_t is_valid[PARALLEL_LOADS];

    // load multiple data in; out-of-range slots get weight 0
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_math[j] = input[address_base];
        count[j]++;
        x_count_inv[j] = accscalar_t(1) / count[j];
        is_valid[j] = accscalar_t(1);
      } else {
        x_math[j] = accscalar_t(0);
        x_count_inv[j] = accscalar_t(0);
        is_valid[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }

    // calculate mean/m2n with welford
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      accscalar_t delta0 = x_math[j] - x_mean[j];
      x_mean[j] += delta0 * x_count_inv[j];
      accscalar_t delta1 = x_math[j] - x_mean[j];
      m_2_n[j] += delta0 * delta1 * is_valid[j];
    }
  }

  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
#pragma unroll
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
  }

  // release x_mean / m_2_n
  auto mean_th = x_mean[0];
  auto m2_th = m_2_n[0];
  auto count_th = count[0];

  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
  static __shared__ int shmem_count[MAX_BLOCK_SIZE];

  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);

  if (gridDim.y > 1) {
    // staging layout: [mean | m2n] as accscalar_t, then [count] as int,
    // each stride * gridDim.y entries indexed by (blockIdx.y, channel).
    volatile accscalar_t* staging_mean = staging_data;
    volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
    volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);

    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data;
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_mean[address_base] = mean_th;
      staging_m2n[address_base] = m2_th;
      staging_count[address_base] = count_th;
    }

    __threadfence();
    __syncthreads(); // ensuring writes to staging_ is visible to all blocks

    __shared__ bool is_last_block_done;
    // mark block done; the block that observes old == gridDim.y - 1 arrived last
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }

    __syncthreads();

    // check that all data is now available in global memory
    // (is_last_block_done is block-uniform, so the barriers inside are safe)
    if (is_last_block_done) {
      count_th = 0;
      mean_th = accscalar_t(0.0);
      m2_th = accscalar_t(0.0);

      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        int count_new = c_offset < stride ? staging_count[address_base] : 0;
        accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
        accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);

        welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new);
      }

      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
      if (threadIdx.y == 0 && c_offset < stride) {
        out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
        out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
      }
    }
  } else {
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
      out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
    }
  }
}

// elementwise BN kernel
// original apex name: batchnorm_forward_c_last_kernel
//
// Channels-last forward transform over a row-major (m, c) layout with
// c == `stride` columns and m == `reduction_size` rows:
//   out = weight * (x - mean) * inv_std + shift   (+ z if given)
// and, when fuse_relu is set, non-positive results are written as 0.
// A null weight acts as 1 and a null shift as 0.
// Each thread owns one column. Along m, a thread starts at
// blockIdx.y * blockDim.y + threadIdx.y and steps by blockDim.y * gridDim.y.
template <
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void batch_norm_transform_input_channels_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int reduction_size,
      const int stride,
      const bool fuse_relu) {
  const int row_step = blockDim.y * gridDim.y;
  const int channel = blockIdx.x * blockDim.x + threadIdx.x;
  int row = blockIdx.y * blockDim.y + threadIdx.y;

  // Nothing to do outside the (m, c) extent. This kernel has no block-level
  // barriers, so leaving early is safe.
  if (!(channel < stride && row < reduction_size)) {
    return;
  }

  // Per-channel parameters, read once.
  const auto mean_c = mean[channel];
  const auto inv_std_c = static_cast<accscalar_t>(inv_std[channel]);
  const auto weight_c = (weight != nullptr) ? static_cast<accscalar_t>(weight[channel]) : accscalar_t(1.0);
  const auto shift_c = (shift != nullptr) ? static_cast<accscalar_t>(shift[channel]) : accscalar_t(0.0);

  const int iterations = 1 + (reduction_size - 1) / (row_step * PARALLEL_LOADS);
  const int offset_step = row_step * stride;
  int offset = row * stride + channel;

  for (int it = 0; it < iterations; ++it) {
#pragma unroll
    for (int p = 0; p < PARALLEL_LOADS; ++p) {
      // `channel` was validated above; only the row can run past the end.
      if (row < reduction_size) {
        auto value = weight_c * (static_cast<accscalar_t>(input[offset]) - mean_c ) * inv_std_c + shift_c;
        if (z != nullptr) {
          value += z[offset];
        }
        const bool clamp_to_zero = fuse_relu && value <= accscalar_t(0.0);
        out[offset] = clamp_to_zero ? scalar_t(0.0) : static_cast<scalar_t>(value);
      }
      row += row_step;
      offset += offset_step;
    }
  }
}

// Tree-reduces `sum_dy` and `sum_dy_xmu` across threadIdx.y, independently for
// each threadIdx.x column of the block, through the two shared scratch buffers.
// On return, the thread with threadIdx.y == 0 in each column holds the column
// totals; other threads hold partial sums.
//
// Requirements implied by the code below:
//  - Every thread of the block must call this, because it executes __syncthreads().
//  - Each buffer needs at least blockDim.x * blockDim.y elements.
//  - blockDim.y should be a power of two. With other sizes some rows' partials
//    are never folded into row 0 (e.g. blockDim.y == 6 drops rows 2 and 5).
//  - There is no barrier after the final read, so callers must __syncthreads()
//    before writing the buffers again.
template<typename T>
__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy,
    T& sum_dy_xmu,
    T* shmem_sum_dy,
    T* shmem_sum_dy_xmu) {
  // write to shared memory
  // Flat slot of this thread (x fastest) in the scratch buffers.
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;

#pragma unroll
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    // Rows still participating in this round publish their current partials.
    if (threadIdx.y < offset*2) {
      shmem_sum_dy[address_base] = sum_dy;
      shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
    }
    __syncthreads();
    // The lower half accumulates the upper half. Slots read here (rows
    // [offset, 2*offset)) are disjoint from the slots written next round
    // (rows [0, offset)), so no second barrier is needed inside the loop.
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;

      sum_dy += shmem_sum_dy[address];
      sum_dy_xmu += shmem_sum_dy_xmu[address];
    }
  }
}

// batchnorm backward kernel for c last tensor
// original apex name: reduce_bn_c_last_kernel
//
// For each channel column c of a row-major (m, c) layout (c == `stride`,
// m == `reduction_size`), computes
//   sum_dy[c]     = sum_m grad_output
//   sum_dy_xmu[c] = sum_m grad_output * (input - mean[c])
// It writes both sums to sum_dy_o / sum_dy_xmu_o. When the pointers are
// non-null it also writes grad_bias[c] = sum_dy and
// grad_weight[c] = sum_dy_xmu * inv_std[c].
//
// Reduction: per-thread PARALLEL_LOADS partials, then a shared-memory tree
// across threadIdx.y. If gridDim.y > 1, partials go through `staging_data`,
// and the last block to arrive (counted in semaphores[blockIdx.x]) finishes
// the reduction.
//
// Every thread stays in the kernel until the end. Block-wide barriers follow
// (inside merge_block_vertical_backward and the grid-reduction path), so
// out-of-range threads contribute zeros instead of returning early.
// When gridDim.y > 1, staging_data must hold 2 * stride * gridDim.y values and
// semaphores[] must start at 0.
template <
    int PARALLEL_LOADS,
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t>
__global__ void batch_norm_backward_reduce_channels_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      accscalar_t* __restrict__ sum_dy_o,
      accscalar_t* __restrict__ sum_dy_xmu_o,
      layerscalar_t* __restrict__ grad_weight,
      layerscalar_t* __restrict__ grad_bias,
      volatile accscalar_t* staging_data,
      int* semaphores,
      const int reduction_size,
      const int stride) {

  // hide latency with concurrency
  accscalar_t sum_dy[PARALLEL_LOADS];
  accscalar_t sum_dy_xmu[PARALLEL_LOADS];

#pragma unroll
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    sum_dy[i] = accscalar_t(0);
    sum_dy_xmu[i] = accscalar_t(0);
  }
  // tensor dimension (m,c)

  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  // No early return here: an exiting thread would skip the __syncthreads()
  // below (undefined behavior), and its shared-memory slot would be read
  // uninitialized by merge_block_vertical_backward. Out-of-range threads
  // instead run the loop with every load masked off, contributing zeros.
  const bool c_valid = c_offset < stride;

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  // Guarded so threads past the last channel never read out of bounds.
  auto r_mean = c_valid ? mean[c_offset] : accscalar_t(0);
  auto factor = c_valid ? inv_std[c_offset] : accscalar_t(0);

  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_input[PARALLEL_LOADS];
    accscalar_t x_grad_output[PARALLEL_LOADS];

    // load multiple data in
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_input[j] = input[address_base];
        x_grad_output[j] = grad_output[address_base];
      } else {
        x_input[j] = accscalar_t(0);
        x_grad_output[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }

    // calculate sum_dy / sum_dy_xmu
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      sum_dy[j] += x_grad_output[j];
      sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
    }
  }

  // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
#pragma unroll
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    sum_dy[0] += sum_dy[j];
    sum_dy_xmu[0] += sum_dy_xmu[j];
  }

  // release array of registers
  auto sum_dy_th = sum_dy[0];
  auto sum_dy_xmu_th = sum_dy_xmu[0];

  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];

  merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);

  if (gridDim.y > 1) {
    // staging layout: [sum_dy | sum_dy_xmu], each stride * gridDim.y entries
    // indexed by (blockIdx.y, channel).
    volatile accscalar_t* staging_sum_dy = staging_data;
    volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];

    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data;
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_sum_dy[address_base] = sum_dy_th;
      staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
    }

    __threadfence();
    __syncthreads(); // ensuring writes to staging_ is visible to all blocks

    __shared__ bool is_last_block_done;
    // mark block done; the block that observes old == gridDim.y - 1 arrived last
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }

    __syncthreads();

    // check that all data is now available in global memory
    // (is_last_block_done is block-uniform, so the barriers inside are safe)
    if (is_last_block_done) {
      sum_dy_th = accscalar_t(0.0);
      sum_dy_xmu_th = accscalar_t(0.0);

      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
        sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
      }

      merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
      if (threadIdx.y == 0 && c_offset < stride) {
        if (grad_bias != nullptr) {
          grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
        }
        if (grad_weight != nullptr) {
          grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
        }
        //mean_dy[c_offset] = sum_dy_th / reduction_size;
        //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
        sum_dy_o[c_offset] = sum_dy_th;
        sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
      }
    }
  } else {
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      if (grad_bias != nullptr) {
        grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
      }
      if (grad_weight != nullptr) {
        grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
      }
      //mean_dy[c_offset] = sum_dy_th / reduction_size;
      //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
      sum_dy_o[c_offset] = sum_dy_th;
      sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
    }
  }
}

// elementwise BN kernel
// original apex name: batchnorm_backward_c_last_kernel
//
// Device body shared by both batch_norm_backward_elemt_channels_last_kernel
// overloads. Over a row-major (m, c) layout (c == `stride`,
// m == `reduction_size`) it writes
//   grad_input = (dy - sum_dy[c]*norm_fct
//                 - (x - mean[c]) * inv_std[c]^2 * sum_dy_xmu[c]*norm_fct)
//                * weight[c] * inv_std[c]
// A null weight acts as 1. Each thread owns one column. Along m, a thread
// starts at blockIdx.y * blockDim.y + threadIdx.y and steps by
// blockDim.y * gridDim.y.
template <
    int PARALLEL_LOADS,
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t>
__device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      scalar_t* __restrict__ grad_input,
      const accscalar_t norm_fct,
      const int reduction_size,
      const int stride) {
  const int row_step = blockDim.y * gridDim.y;
  const int channel = blockIdx.x * blockDim.x + threadIdx.x;
  int row = blockIdx.y * blockDim.y + threadIdx.y;

  // Out-of-range threads have nothing to write; there are no barriers here.
  if (!(channel < stride && row < reduction_size)) {
    return;
  }

  // Per-channel coefficients, computed once.
  const auto mean_c = mean[channel];
  const auto mean_dy_c = sum_dy[channel] * norm_fct;
  const auto inv_std_c = inv_std[channel];
  const auto weight_c = (weight != nullptr) ? static_cast<accscalar_t>(weight[channel]) : accscalar_t(1.0);
  // Scale applied to the whole bracket: weight * inv_std.
  const auto out_scale = weight_c * inv_std_c;
  // Coefficient on (x - mean): inv_std^2 * sum_dy_xmu * norm_fct.
  const auto xmu_coef = inv_std_c * inv_std_c * sum_dy_xmu[channel] * norm_fct;

  const int iterations = 1 + (reduction_size - 1) / (row_step * PARALLEL_LOADS);
  const int offset_step = row_step * stride;
  int offset = row * stride + channel;

  for (int it = 0; it < iterations; ++it) {
#pragma unroll
    for (int p = 0; p < PARALLEL_LOADS; ++p) {
      // `channel` was validated above; only the row can run past the end.
      if (row < reduction_size) {
        grad_input[offset] = static_cast<scalar_t>(
            (static_cast<accscalar_t>(grad_output[offset]) - mean_dy_c -
            (static_cast<accscalar_t>(input[offset]) - mean_c) * xmu_coef)
            * out_scale);
      }
      row += row_step;
      offset += offset_step;
    }
  }
}

// Backward elementwise kernel (channels-last) whose normalization factor is
// computed on the device. norm_fct = 1 / sum(numel[0 .. world_size)).
// NOTE(review): numel presumably holds one element count per participating
// replica; confirm against the caller.
// Every thread recomputes this sum, then delegates to
// batch_norm_backward_elemt_channels_last_kernel_impl.
template <
    int PARALLEL_LOADS,
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t>
__global__ void batch_norm_backward_elemt_channels_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      const int* __restrict__ numel,
      scalar_t* __restrict__ grad_input,
      const int64_t world_size,
      const int reduction_size,
      const int stride) {
  // Accumulate in 64 bits so the combined count cannot overflow int.
  int64_t total_numel = 0;
  int entry = 0;
  while (entry < world_size) {
    total_numel += numel[entry];
    ++entry;
  }

  const accscalar_t norm_fct = static_cast<accscalar_t>(1) / static_cast<accscalar_t>(total_numel);
  batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS, scalar_t, accscalar_t, layerscalar_t>(
      grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
      grad_input, norm_fct, reduction_size, stride);
}

// Backward elementwise kernel (channels-last) for callers that already have
// the normalization factor `norm_fct` on the host. It forwards unchanged to
// batch_norm_backward_elemt_channels_last_kernel_impl.
template <
    int PARALLEL_LOADS,
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t>
__global__ void batch_norm_backward_elemt_channels_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      scalar_t* __restrict__ grad_input,
      const accscalar_t norm_fct,
      const int reduction_size,
      const int stride) {
  batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS, scalar_t, accscalar_t, layerscalar_t>(
      grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
      grad_input, norm_fct, reduction_size, stride);
}

// Channels-last statistics.
//
// Treats `input` as (reduction_size, stride) with stride = input.sizes()[1]
// channels, resizes out_mean / out_invstd to {stride} and launches
// batch_norm_collect_statistics_channels_last_kernel. The grid comes from
// flexible_launch_configs(..., true). When it spans more than one block along
// y, a staging buffer and zero-initialized per-column int semaphores are
// allocated for the kernel's cross-block merge; otherwise both are passed as
// nullptr.
template<typename scalar_t, typename VarTransform>
void batch_norm_stats_channels_last_cuda_template(
    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
  using accscalar_t = at::acc_type<scalar_t, true>;

  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;

  resize_output(out_mean, {stride});
  resize_output(out_invstd, {stride});
  // Both outputs must be contiguous 1-D tensors of non-zero length.
  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
                        out_invstd.sizes()[0]);
  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
                        out_mean.sizes()[0]);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid, true);

  // Scratch for the cross-block merge. The tensors are declared at function
  // scope so they outlive the kernel launch below.
  at::Tensor staging_data;
  at::Tensor semaphores;
  accscalar_t* staging_data_ptr = nullptr;
  int* semaphores_ptr = nullptr;
  if (grid.y > 1) {
    staging_data = at::empty({4*stride*grid.y}, out_mean.options());
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
    staging_data_ptr = staging_data.data_ptr<accscalar_t>();
    semaphores_ptr = semaphores.data_ptr<int>();
  }

  auto stream = at::cuda::getCurrentCUDAStream();
  batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
      <<<grid, block, 0, stream>>>(
      input.data_ptr<scalar_t>(),
      out_mean.data_ptr<accscalar_t>(),
      out_invstd.data_ptr<accscalar_t>(),
      staging_data_ptr,
      semaphores_ptr,
      reduction_size,
      stride,
      epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Applies the batch-norm elementwise transform to a channels-last `input`,
// writing the result into `output`, by launching
// batch_norm_transform_input_channels_last_kernel.
//
// `input.sizes()[1]` is passed to the kernel as `stride` (presumably the
// channel count) and `input.numel() / stride` as `reduction_size`.
// `mean` and `inv_std` are always read as accscalar_t. The "parameter dtype"
// is `weight`'s dtype if defined, else `shift`'s, else `input`'s:
//   * if it differs from `input`'s dtype, `weight`/`shift` are read as
//     accscalar_t;
//   * otherwise they are read as scalar_t (same dtype as `input`).
// Undefined `weight`/`shift` are passed to the kernel as nullptr. `z` (bias
// added after BN) is read as scalar_t when present, and `fuse_relu` is
// forwarded to the kernel unchanged.
//
// Declared `inline` because this non-template function is defined in a
// header: without it, including the header from more than one translation
// unit would produce multiple definitions (ODR violation).
inline void batch_norm_elemt_channels_last_cuda_template(
    const at::Tensor& output,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& shift,  // bias of BN
    const at::Tensor& mean,
    const at::Tensor& inv_std,
    const at::optional<at::Tensor>& z = c10::nullopt,  // bias after BN
    const bool fuse_relu = false) {
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();
  const auto second_dtype = weight.defined() ? weight.scalar_type() :
      (shift.defined() ? shift.scalar_type() : input.scalar_type());

  if (input.scalar_type() != second_dtype) {
    // Mixed precision: affine parameters are stored in the accumulate type.
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          z.has_value() ? z.value().data_ptr<scalar_t>() : nullptr,
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.defined() ? weight.data_ptr<accscalar_t>() : nullptr,
          shift.defined() ? shift.data_ptr<accscalar_t>() : nullptr,
          output.data_ptr<scalar_t>(),
          reduction_size,
          stride,
          fuse_relu);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  } else {
    if (weight.defined()) {
      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(),
        " is not supported with weight.scalar_type() ", weight.scalar_type());
    }
    // Same precision: affine parameters share the input dtype.
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          z.has_value() ? z.value().data_ptr<scalar_t>() : nullptr,
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.defined() ? weight.data_ptr<scalar_t>() : nullptr,
          shift.defined() ? shift.data_ptr<scalar_t>() : nullptr,
          output.data_ptr<scalar_t>(),
          reduction_size,
          stride,
          fuse_relu);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }
}

// Launches batch_norm_backward_reduce_channels_last_kernel on a channels-last
// `input` / `grad_output` pair and returns
// (sum_dy, sum_dy_xmu, grad_weight, grad_bias).
//
// `input.sizes()[1]` is used as `stride` (presumably the channel count), and
// sum_dy / sum_dy_xmu are allocated with that many elements in
// `mean.options()`. When `weight` is defined, grad_weight / grad_bias have
// `stride` elements in `weight.options()`. They are written as accscalar_t if
// `weight`'s dtype differs from `input`'s, and as scalar_t otherwise. When
// `weight` is undefined they are 0-element placeholders and the kernel
// receives nullptr for them.
//
// If the launch config uses more than one block row (grid.y > 1), the kernel
// is also given a staging buffer of 2 * stride * grid.y elements and grid.x
// zero-initialised int semaphores; otherwise both are nullptr.
//
// NOTE(review): `input_g`, `weight_g` and `bias_g` are accepted but not read
// in this function.
//
// Declared `inline` because this non-template function is defined in a
// header: without it, including the header from more than one translation
// unit would produce multiple definitions (ODR violation).
inline std::tuple<Tensor, Tensor, Tensor, Tensor>
batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& mean,
    const at::Tensor& inv_std,
    const at::Tensor& weight,
    const bool input_g, const bool weight_g, const bool bias_g) {
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;

  at::Tensor sumn_dy = at::empty({stride}, mean.options());
  at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());

  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (weight.defined()) {
    grad_weight = at::empty({stride}, weight.options());
    grad_bias = at::empty({stride}, weight.options());
  } else {
    // No affine weight: return 0-element placeholders rather than undefined
    // tensors so every element of the returned tuple is a usable tensor.
    grad_weight = at::empty({0}, mean.options());
    grad_bias = at::empty({0}, mean.options());
  }

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid, true);

  at::Tensor staging_data;
  at::Tensor semaphores;
  if (grid.y > 1) {
    staging_data = at::empty({2*stride*grid.y}, mean.options());
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }
  auto stream = at::cuda::getCurrentCUDAStream();

  if (weight.defined() && input.scalar_type() != weight.scalar_type()) {
    // Mixed precision: grad_weight / grad_bias are stored in the accumulate type.
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
      batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          grad_output.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          sumn_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          grad_weight.data_ptr<accscalar_t>(),
          grad_bias.data_ptr<accscalar_t>(),
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  } else {
    if (weight.defined()) {
      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(),
        " is not supported with weight.scalar_type() ", weight.scalar_type());
    }
    // Same precision (or no weight): grad_weight / grad_bias use the input dtype.
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
      batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          grad_output.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          sumn_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          weight.defined() ? grad_weight.data_ptr<scalar_t>() : nullptr,
          weight.defined() ? grad_bias.data_ptr<scalar_t>() : nullptr,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }

  return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias);
}

// Computes grad_input for a channels-last `input` by launching
// batch_norm_backward_elemt_channels_last_kernel, and returns it.
//
// grad_input is allocated with `at::empty_like(input)`. `input.sizes()[1]` is
// passed as `stride` (presumably the channel count) and
// `input.numel() / stride` as `reduction_size`. `mean`, `inv_std`, `sum_dy`
// and `sum_dy_xmu` are read as accscalar_t. `weight` is read as accscalar_t
// when its dtype differs from `input`'s, as scalar_t otherwise, and is passed
// as nullptr when undefined.
// `count` is read as int data and `count.numel()` is also passed to the
// kernel (presumably per-replica element counts and the number of replicas —
// confirm against the kernel).
//
// Declared `inline` because this non-template function is defined in a
// header: without it, including the header from more than one translation
// unit would produce multiple definitions (ODR violation).
inline at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
    const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& mean,
    const at::Tensor& inv_std,
    const at::Tensor& weight,
    const at::Tensor& sum_dy,
    const at::Tensor& sum_dy_xmu,
    const at::Tensor& count) {
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;

  // Input is guaranteed to be channels-last compatible
  at::Tensor grad_input = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
    // Mixed precision: weight is stored in the accumulate type.
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.data_ptr<accscalar_t>(),
          sum_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          count.data_ptr<int>(),
          grad_input.data_ptr<scalar_t>(),
          count.numel(),
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  } else {
    if (weight.defined()) {
      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(),
        " is not supported with weight.scalar_type() ", weight.scalar_type());
    }
    // Same precision (or no weight): weight shares the input dtype.
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.defined() ? weight.data_ptr<scalar_t>() : nullptr,
          sum_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          count.data_ptr<int>(),
          grad_input.data_ptr<scalar_t>(),
          count.numel(),
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }

  return grad_input;
}

// Computes grad_input for a channels-last `input` by launching
// batch_norm_backward_elemt_channels_last_kernel, and returns it.
//
// This overload takes no `count` tensor. Instead it passes
// `norm_fct = 1.0 / reduction_size` (cast to accscalar_t) to the kernel, where
// `reduction_size = input.numel() / input.sizes()[1]`.
// grad_input is allocated with `at::empty_like(input)`. `mean`, `inv_std`,
// `sum_dy` and `sum_dy_xmu` are read as accscalar_t. `weight` is read as
// accscalar_t when its dtype differs from `input`'s, as scalar_t otherwise,
// and is passed as nullptr when undefined.
//
// Declared `inline` because this non-template function is defined in a
// header: without it, including the header from more than one translation
// unit would produce multiple definitions (ODR violation).
inline at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
    const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& mean,
    const at::Tensor& inv_std,
    const at::Tensor& weight,
    const at::Tensor& sum_dy,
    const at::Tensor& sum_dy_xmu) {
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;
  auto norm_fct = 1.0 / reduction_size;

  // Input is guaranteed to be channels-last compatible
  at::Tensor grad_input = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;

    if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
      // Mixed precision: weight is stored in the accumulate type.
      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.data_ptr<accscalar_t>(),
          sum_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          grad_input.data_ptr<scalar_t>(),
          static_cast<accscalar_t>(norm_fct),
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      // Same precision (or no weight): weight shares the input dtype.
      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.defined() ? weight.data_ptr<scalar_t>() : nullptr,
          sum_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          grad_input.data_ptr<scalar_t>(),
          static_cast<accscalar_t>(norm_fct),
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });

  return grad_input;
}

} } // namespace at::native