Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / layer_norm_op.h

#ifndef CAFFE2_OPERATORS_LAYER_NORM_OP_H_
#define CAFFE2_OPERATORS_LAYER_NORM_OP_H_

#include <array>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"

// Invokes the c10 export-declaration macro for the LayerNorm operator. The
// invocation sits at global scope, before `namespace caffe2` is opened.
// NOTE(review): the macro is presumably provided by
// caffe2/core/export_caffe2_op_to_c10.h (included above) — confirm.
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LayerNorm)

namespace caffe2 {

/**
 * Forward layer-normalization operator.
 *
 * Inputs:
 *   0: X      - tensor with at least two dimensions.
 *   1: gamma  - only read when `elementwise_affine` is set; must hold N values.
 *   2: beta   - only read when `elementwise_affine` is set; must hold N values.
 * Outputs:
 *   0: Y      - resized like X.
 *   1: mean   - shape is X's dims before `axis` followed by a single 1.
 *   2: sigma  - same shape as mean.
 *
 * Arguments:
 *   axis (int, default 1), epsilon (float, default 1e-5),
 *   elementwise_affine (bool, default false).
 *
 * X is viewed as an (M, N) matrix, where M is the product of the dims before
 * the canonical axis and N the product of the remaining dims.
 */
template <class Context>
class LayerNormOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit LayerNormOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, 1),
        OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f),
        OP_SINGLE_ARG(bool, "elementwise_affine", elementwise_affine_, false) {}

  /// Dispatches to DoRunWithType<T>() based on the element type of input 0.
  bool RunOnDevice() override {
    using SupportedTypes = TensorTypes<float, double>;
    return DispatchHelper<SupportedTypes>::call(this, Input(0));
  }

  /// Runs the forward pass for element type T. Returns true on success;
  /// precondition failures are reported through the CAFFE_ENFORCE_* macros.
  template <typename T>
  bool DoRunWithType() {
    const auto& X = Input(0);
    auto* normalized = Output(0);
    CAFFE_ENFORCE_GE(X.dim(), 2, "LayerNorm requires input dim >= 2.");

    const int split_axis = X.canonical_axis_index(axis_);

    // The statistics keep the dims in front of the split axis and collapse
    // everything behind it into a trailing 1.
    const auto X_sizes = X.sizes();
    std::vector<int64_t> stats_shape(
        X_sizes.cbegin(), X_sizes.cbegin() + split_axis);
    stats_shape.push_back(1);
    auto* mean_out = Output(1, stats_shape, at::dtype<T>());
    auto* sigma_out = Output(2, stats_shape, at::dtype<T>());

    // Row / column counts of the flattened 2-D view of X.
    const int rows = X.size_to_dim(split_axis);
    const int N = X.size_from_dim(split_axis);

    normalized->ResizeLike(X);
    scale_.Resize(rows);
    bias_.Resize(rows);

    // Buffers are fetched before the empty-input check below, so that the
    // outputs are materialized with type T even when there is nothing to do.
    const T* src = X.template data<T>();
    T* dst = normalized->template mutable_data<T>();
    T* mean_ptr = mean_out->template mutable_data<T>();
    T* sigma_ptr = sigma_out->template mutable_data<T>();
    T* scale_ptr = scale_.template mutable_data<T>();
    T* bias_ptr = bias_.template mutable_data<T>();

    if (rows == 0) {
      return true;
    }

    // Reduce (rows, N) -> (rows, 1). After this call sigma_ptr holds the
    // second output of math::Moments; it is then rewritten in place, because
    // the same buffer is passed as both `var` and `stddev` below.
    const std::array<int, 2> full_dims = {rows, N};
    const std::array<int, 2> reduced_dims = {rows, 1};
    math::Moments<T, Context>(
        2,
        full_dims.data(),
        reduced_dims.data(),
        src,
        mean_ptr,
        sigma_ptr,
        &context_);
    ComputeSigmaAndFusedParams<T>(
        rows, epsilon_, mean_ptr, sigma_ptr, sigma_ptr, scale_ptr, bias_ptr);

    // The affine parameters stay null unless elementwise_affine is enabled.
    const T* gamma_ptr = nullptr;
    const T* beta_ptr = nullptr;
    if (elementwise_affine_) {
      CAFFE_ENFORCE_EQ(InputSize(), 3);
      const auto& gamma = Input(1);
      const auto& beta = Input(2);
      CAFFE_ENFORCE_EQ(gamma.numel(), N);
      CAFFE_ENFORCE_EQ(beta.numel(), N);
      gamma_ptr = gamma.template data<T>();
      beta_ptr = beta.template data<T>();
    }

    LayerNormForward<T>(
        rows, N, src, scale_ptr, bias_ptr, gamma_ptr, beta_ptr, dst);
    return true;
  }

 private:
  // Declared here, defined out of line (not visible in this header). Called
  // with N = number of rows and with `var` and `stddev` aliasing the same
  // buffer.
  template <typename T>
  void ComputeSigmaAndFusedParams(
      const int N,
      const float eps,
      const T* mean,
      const T* var,
      T* stddev,
      T* scale,
      T* bias);

  // Declared here, defined out of line (not visible in this header). `gamma`
  // and `beta` are null when elementwise_affine is disabled.
  template <typename T>
  void LayerNormForward(
      const int M,
      const int N,
      const T* X,
      const T* scale,
      const T* bias,
      const T* gamma,
      const T* beta,
      T* Y);

  // Operator arguments, read once at construction.
  const int axis_;
  const float epsilon_;
  const bool elementwise_affine_;

  // Per-row scratch tensors (resized to M elements on every run).
  Tensor scale_{Context::GetDeviceType()};
  Tensor bias_{Context::GetDeviceType()};
};

/**
 * Gradient of the layer-normalization operator.
 *
 * Inputs:
 *   0: dY     - gradient w.r.t. the forward output; must have as many
 *               elements as X.
 *   1: Y      - forward output (fetched but not read by this header's code).
 *   2: mean   - per-row statistic from the forward pass; M elements.
 *   3: sigma  - per-row statistic from the forward pass; M elements.
 *   4: X      - forward input.
 *   5: gamma  - only when `elementwise_affine` is set; must hold N values.
 * Outputs:
 *   0: dX     - same sizes as X.
 *   1: dgamma - only when `elementwise_affine` is set; same sizes as gamma.
 *   2: dbeta  - only when `elementwise_affine` is set; same sizes as gamma.
 *
 * Arguments:
 *   axis (int, default 1), elementwise_affine (bool, default false).
 *
 * X is viewed as an (M, N) matrix, where M is the product of the dims before
 * the canonical axis and N the product of the remaining dims. All element
 * counts are validated up front, because the helpers below index the raw
 * buffers using M and N only.
 */
template <class Context>
class LayerNormGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit LayerNormGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, 1),
        OP_SINGLE_ARG(bool, "elementwise_affine", elementwise_affine_, false) {}

  /// Dispatches to DoRunWithType<T>() based on the element type of input 0.
  /// Only float is dispatched here.
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  /// Runs the backward pass for element type T. Returns true on success;
  /// precondition failures are reported through CAFFE_ENFORCE_EQ.
  template <typename T>
  bool DoRunWithType() {
    const auto& dY = Input(0);
    // Y is not read below; the fetch is kept so input 1 is still type-checked.
    const auto& Y = Input(1);
    const auto& mean = Input(2);
    const auto& sigma = Input(3);
    const auto& X = Input(4);

    const int canonical_axis = X.canonical_axis_index(axis_);
    const int M = X.size_to_dim(canonical_axis);
    const int N = X.size_from_dim(canonical_axis);

    // The raw pointers taken below are indexed with M and N derived from X
    // alone, so reject inputs whose element counts do not match.
    CAFFE_ENFORCE_EQ(
        dY.numel(),
        X.numel(),
        "LayerNormGradient: dY must have as many elements as X.");
    CAFFE_ENFORCE_EQ(
        mean.numel(),
        M,
        "LayerNormGradient: mean must have one element per row.");
    CAFFE_ENFORCE_EQ(
        sigma.numel(),
        M,
        "LayerNormGradient: sigma must have one element per row.");

    auto* dX = Output(0, X.sizes(), at::dtype<T>());
    // Per-row scratch buffers, re-created with M elements on every run.
    ReinitializeTensor(
        &ds_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
    ReinitializeTensor(
        &db_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
    ReinitializeTensor(
        &rstd_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
    ReinitializeTensor(
        &X_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
    ReinitializeTensor(
        &bias_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
    const T* dY_data = dY.template data<T>();
    const T* X_data = X.template data<T>();
    const T* mean_data = mean.template data<T>();
    const T* sigma_data = sigma.template data<T>();
    T* dX_data = dX->template mutable_data<T>();
    T* ds_data = ds_.template mutable_data<T>();
    T* db_data = db_.template mutable_data<T>();
    T* rstd_data = rstd_.template mutable_data<T>();
    T* X_scale_data = X_scale_.template mutable_data<T>();
    T* bias_data = bias_.template mutable_data<T>();

    // The affine-related pointers stay null unless elementwise_affine is set.
    const T* gamma_data = nullptr;
    T* dgamma_data = nullptr;
    T* dbeta_data = nullptr;
    T* g_scale_data = nullptr;
    if (elementwise_affine_) {
      CAFFE_ENFORCE_EQ(
          InputSize(),
          6,
          "LayerNormGradient: elementwise_affine requires gamma as input 5.");
      const auto& gamma = Input(5);
      // dgamma / dbeta are sized from gamma but written with N elements, so a
      // mismatch would write out of bounds.
      CAFFE_ENFORCE_EQ(
          gamma.numel(),
          N,
          "LayerNormGradient: gamma must have N elements.");
      auto* dgamma = Output(1, gamma.sizes(), at::dtype<T>());
      auto* dbeta = Output(2, gamma.sizes(), at::dtype<T>());
      ReinitializeTensor(
          &g_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
      gamma_data = gamma.template data<T>();
      dgamma_data = dgamma->template mutable_data<T>();
      dbeta_data = dbeta->template mutable_data<T>();
      g_scale_data = g_scale_.template mutable_data<T>();
    }

    // No rows: dX is empty; the parameter gradients (if any) are zero-filled.
    if (M == 0) {
      if (N > 0 && dgamma_data != nullptr) {
        math::Set<T, Context>(N, T(0), dgamma_data, &context_);
      }
      if (N > 0 && dbeta_data != nullptr) {
        math::Set<T, Context>(N, T(0), dbeta_data, &context_);
      }
      return true;
    }

    // The dX buffer is handed to ComputeInternalGradients as its `dYxX`
    // output and read back by GammaBetaBackward as `dYxX`, so both must run
    // before LayerNormBackward receives the same buffer as its `dX` output.
    ComputeInternalGradients<T>(
        M, N, dY_data, X_data, gamma_data, dX_data, ds_data, db_data);
    ComputeFusedParams<T>(
        M,
        N,
        mean_data,
        sigma_data,
        ds_data,
        db_data,
        rstd_data,
        X_scale_data,
        bias_data,
        g_scale_data);
    if (elementwise_affine_) {
      GammaBetaBackward<T>(
          M,
          N,
          dX_data,
          dY_data,
          rstd_data,
          g_scale_data,
          dgamma_data,
          dbeta_data);
    }
    LayerNormBackward<T>(
        M,
        N,
        dY_data,
        X_data,
        gamma_data,
        rstd_data,
        X_scale_data,
        bias_data,
        dX_data);

    return true;
  }

 private:
  // The four helpers below are declared here and defined out of line (not
  // visible in this header).

  // Called with `gamma` possibly null and with the dX buffer as `dYxX`.
  template <typename T>
  void ComputeInternalGradients(
      const int M,
      const int N,
      const T* dY,
      const T* X,
      const T* gamma,
      T* dYxX,
      T* ds,
      T* db);

  // Called with `g_scale` null when elementwise_affine is disabled.
  template <typename T>
  void ComputeFusedParams(
      const int M,
      const int N,
      const T* mean,
      const T* sigma,
      const T* ds,
      const T* db,
      T* rstd,
      T* X_scale,
      T* bias,
      T* g_scale);

  // Called with the rstd buffer as `dY_scale` and `gamma` possibly null.
  template <typename T>
  void LayerNormBackward(
      const int M,
      const int N,
      const T* dY,
      const T* X,
      const T* gamma,
      const T* dY_scale,
      const T* X_scale,
      const T* bias,
      T* dX);

  // Only called when elementwise_affine is enabled.
  template <typename T>
  void GammaBetaBackward(
      const int M,
      const int N,
      const T* dYxX,
      const T* dY,
      const T* rstd,
      const T* g_scale,
      T* dgamma,
      T* dbeta);

  // Operator arguments, read once at construction.
  const int axis_;
  const bool elementwise_affine_;

  // Scratch tensors; each one used by DoRunWithType is re-created with M
  // elements per run.
  Tensor ds_;
  Tensor db_;
  Tensor rstd_;
  Tensor X_scale_;
  Tensor bias_;
  Tensor g_scale_;
  // Not referenced in this header.
  // NOTE(review): may be used by out-of-line definitions — confirm before
  // removing.
  Tensor ones_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LAYER_NORM_OP_H_