// include/caffe2/operators/group_norm_op.h

#ifndef CAFFE2_OPERATORS_GROUP_NORM_OP_H_
#define CAFFE2_OPERATORS_GROUP_NORM_OP_H_

#include <array>
#include <string>
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

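// GroupNormOp computes the forward pass of Group Normalization
// (Wu & He, 2018): the C channels are split into `group_` groups of
// K = C / G channels each, every (sample, group) slice is normalized to
// zero mean and unit variance, and a per-channel affine transform is
// applied:
//
//   Y[n, c, ...] = gamma[c] * (X[n, c, ...] - mu[n, g]) * rsig[n, g] + beta[c]
//
// where g = c / K, mu[n, g] is the group mean and
// rsig[n, g] = 1 / sqrt(var[n, g] + epsilon).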
template <typename T, class Context>
class GroupNormOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit GroupNormOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "group", group_, 32),
        OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<std::string>("order", "NCHW"))),
        OP_SINGLE_ARG(bool, OpSchema::Arg_IsTest, is_test_, true) {
    CAFFE_ENFORCE_NE(
        order_,
        StorageOrder::UNKNOWN,
        "order should be either \"NCHW\" or \"NHWC\".");
    if (!is_test_) {
      CAFFE_ENFORCE_EQ(OutputSize(), 3);
    }
  }

  bool RunOnDevice() override {
    const auto& X = Input(INPUT);
    const auto& gamma = Input(GAMMA);
    const auto& beta = Input(BETA);
    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
    const int HxW = order_ == StorageOrder::NCHW
        ? X.size_from_dim(2)
        : X.size_between_dim(0, ndim - 1);
    CAFFE_ENFORCE_EQ(C % group_, 0);
    CAFFE_ENFORCE_EQ(gamma.numel(), C);
    CAFFE_ENFORCE_EQ(beta.numel(), C);
    const int G = group_;
    const int K = C / G;
    auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
    if (N == 0) {
      return true;
    }
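    // In training mode (three outputs), the per-group mean and inverse
    // standard deviation are exposed as outputs so GroupNormGradientOp can
    // reuse them; otherwise they are kept in internal scratch tensors.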
    T* mu_data = nullptr;
    T* rsig_data = nullptr;
    if (OutputSize() == 3) {
      auto* mu = Output(MU, {N, G}, at::dtype<T>());
      auto* rsig = Output(INV_SIGMA, {N, G}, at::dtype<T>());
      mu_data = mu->template mutable_data<T>();
      rsig_data = rsig->template mutable_data<T>();
    } else {
      ReinitializeTensor(
          &mu_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
      ReinitializeTensor(
          &rsig_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
      mu_data = mu_.template mutable_data<T>();
      rsig_data = rsig_.template mutable_data<T>();
    }
    if (order_ == StorageOrder::NCHW) {
      return RunOnDeviceWithOrderNCHW(
          N,
          G,
          K,
          HxW,
          X.template data<T>(),
          gamma.template data<T>(),
          beta.template data<T>(),
          Y->template mutable_data<T>(),
          mu_data,
          rsig_data);
    } else {
      return RunOnDeviceWithOrderNHWC(
          N,
          G,
          K,
          HxW,
          X.template data<T>(),
          gamma.template data<T>(),
          beta.template data<T>(),
          Y->template mutable_data<T>(),
          mu_data,
          rsig_data);
    }
  }

 private:
  bool RunOnDeviceWithOrderNCHW(
      const int N,
      const int G,
      const int K,
      const int HxW,
      const T* X,
      const T* gamma,
      const T* beta,
      T* Y,
      T* mu,
      T* rsig) {
    const int C = G * K;
    ReinitializeTensor(
        &scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    ReinitializeTensor(
        &bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    T* scale_data = scale_.template mutable_data<T>();
    T* bias_data = bias_.template mutable_data<T>();
    const std::array<int, 2> X_dims = {N * G, K * HxW};
    const std::array<int, 2> Y_dims = {N * G, 1};
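    // Reduce each (n, g) group of K * HxW elements to its mean and variance,
    // then convert the variance to an inverse standard deviation in place:
    // rsig = 1 / sqrt(var + epsilon).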
    math::Moments<T, Context>(
        2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
    math::InvStd<T, Context>(
        N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
    ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data);
    GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y);
    return true;
  }

  bool RunOnDeviceWithOrderNHWC(
      const int N,
      const int G,
      const int K,
      const int HxW,
      const T* X,
      const T* gamma,
      const T* beta,
      T* Y,
      T* mu,
      T* rsig) {
    const int C = G * K;
    ReinitializeTensor(
        &scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    ReinitializeTensor(
        &bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    T* scale_data = scale_.template mutable_data<T>();
    T* bias_data = bias_.template mutable_data<T>();
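    // In NHWC, view X as {N, HxW, G, K} and reduce over the HxW and K axes,
    // producing one mean/variance pair per (n, g).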
    const std::array<int, 4> X_dims = {N, HxW, G, K};
    const std::array<int, 4> Y_dims = {N, 1, G, 1};
    math::Moments<T, Context>(
        4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
    math::InvStd<T, Context>(
        N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
    ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data);
    GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
    return true;
  }

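  // Folds the normalization and the per-channel affine transform into a
  // single per-(n, c) multiply-add, so the forward kernels only need to
  // apply Y = X * scale + bias:
  //
  //   scale[n, c] = gamma[c] * rsig[n, g]
  //   bias[n, c]  = beta[c] - scale[n, c] * mu[n, g]    with g = c / K.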
  void ComputeFusedParams(
      int N,
      int G,
      int K,
      const T* mu,
      const T* rsig,
      const T* gamma,
      const T* beta,
      T* scale,
      T* bias);

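  // Applies the fused transform Y = X * scale + bias, broadcasting the
  // per-(n, c) scale and bias over the HxW spatial positions of the
  // respective layout.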
  void GroupNormForwardNCHW(
      const int N,
      const int C,
      const int HxW,
      const T* X,
      const T* scale,
      const T* bias,
      T* Y);

  void GroupNormForwardNHWC(
      const int N,
      const int C,
      const int HxW,
      const T* X,
      const T* scale,
      const T* bias,
      T* Y);

  const int group_;
  const float epsilon_;
  const StorageOrder order_;
  const bool is_test_;

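  // Scratch buffers: mu_/rsig_ hold the group statistics when they are not
  // exposed as outputs; scale_/bias_ hold the fused per-(n, c) parameters.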
  Tensor mu_;
  Tensor rsig_;
  Tensor scale_;
  Tensor bias_;

  // Input: X, gamma, beta
  // Output: Y, mu, inv_sig
  INPUT_TAGS(INPUT, GAMMA, BETA);
  OUTPUT_TAGS(OUTPUT, MU, INV_SIGMA);
};

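// GroupNormGradientOp computes the gradients of group normalization with
// respect to X, gamma, and beta, given dY and the mu/rsig saved by the
// forward pass.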
template <typename T, class Context>
class GroupNormGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit GroupNormGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "group", group_, 32),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<std::string>("order", "NCHW"))) {
    CAFFE_ENFORCE_NE(
        order_,
        StorageOrder::UNKNOWN,
        "order should be either \"NCHW\" or \"NHWC\".");
  }

  bool RunOnDevice() override {
    const auto& dY = Input(OUTPUT_GRAD);
    const auto& X = Input(INPUT);
    const auto& gamma = Input(GAMMA);
    const auto& beta = Input(BETA);
    const auto& mu = Input(MU);
    const auto& rsig = Input(INV_SIGMA);
    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
    const int HxW = X.numel() / (N * C);
    CAFFE_ENFORCE_EQ(C % group_, 0);
    CAFFE_ENFORCE_EQ(gamma.numel(), C);
    CAFFE_ENFORCE_EQ(beta.numel(), C);
    const int G = group_;
    const int K = C / G;
    auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
    auto* dgamma = Output(GAMMA_GRAD, gamma.sizes(), at::dtype<T>());
    auto* dbeta = Output(BETA_GRAD, beta.sizes(), at::dtype<T>());
    if (order_ == StorageOrder::NCHW) {
      return RunOnDeviceWithOrderNCHW(
          N,
          G,
          K,
          HxW,
          dY.template data<T>(),
          X.template data<T>(),
          mu.template data<T>(),
          rsig.template data<T>(),
          gamma.template data<T>(),
          dX->template mutable_data<T>(),
          dgamma->template mutable_data<T>(),
          dbeta->template mutable_data<T>());
    } else {
      return RunOnDeviceWithOrderNHWC(
          N,
          G,
          K,
          HxW,
          dY.template data<T>(),
          X.template data<T>(),
          mu.template data<T>(),
          rsig.template data<T>(),
          gamma.template data<T>(),
          dX->template mutable_data<T>(),
          dgamma->template mutable_data<T>(),
          dbeta->template mutable_data<T>());
    }
  }

 protected:
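  // Layout-specific backward kernels, implemented per device context, that
  // produce dX, dgamma, and dbeta from dY, X, and the saved per-group mu
  // and rsig.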
  bool RunOnDeviceWithOrderNCHW(
      int N,
      int G,
      int K,
      int HxW,
      const T* dY_data,
      const T* X_data,
      const T* mu_data,
      const T* rsig_data,
      const T* gamma_data,
      T* dX_data,
      T* dgamma_data,
      T* dbeta_data);

  bool RunOnDeviceWithOrderNHWC(
      int N,
      int G,
      int K,
      int HxW,
      const T* dY_data,
      const T* X_data,
      const T* mu_data,
      const T* rsig_data,
      const T* gamma_data,
      T* dX_data,
      T* dgamma_data,
      T* dbeta_data);

  const int group_;
  const StorageOrder order_;

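  // Scratch buffers for intermediate results of the backward computation.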
  Tensor ds_;
  Tensor db_;
  Tensor dY_scale_;
  Tensor X_scale_;
  Tensor bias_;
  Tensor ones_;

  // Input: dY, X, gamma, beta, mu, inv_sig
  // Output: dX, dgamma, dbeta
  INPUT_TAGS(OUTPUT_GRAD, INPUT, GAMMA, BETA, MU, INV_SIGMA);
  OUTPUT_TAGS(INPUT_GRAD, GAMMA_GRAD, BETA_GRAD);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_GROUP_NORM_OP_H_