Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / instance_norm_op.h

#ifndef CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
#define CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_

#include <array>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <typename T, class Context>
class InstanceNormOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Operator arguments:
  //   epsilon: float, default 1e-5; must be nonnegative.
  //   order:   storage order of X, "NCHW" (default) or "NHWC".
  template <class... Args>
  explicit InstanceNormOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<string>("order", "NCHW"))) {
    CAFFE_ENFORCE_GE(epsilon_, 0, "Must pass a nonnegative epsilon.");
    CAFFE_ENFORCE_NE(
        order_,
        StorageOrder::UNKNOWN,
        "order should be either \"NCHW\" or \"NHWC\".");
  }

  // Inputs:  X, gamma (C elements), beta (C elements).
  //          N is X's first axis; C is X's second axis for NCHW and its last
  //          axis for NHWC; HxW is the product of the remaining axes.
  // Outputs: Y with the same shape as X, plus optional mean and rstd outputs
  //          of shape {N, C}. When mean/rstd are not requested, the
  //          statistics go to the scratch tensors mean_/rstd_ so the
  //          per-order implementations always receive non-null buffers.
  bool RunOnDevice() {
    const auto& X = Input(INPUT);
    const auto& gamma = Input(SCALE);
    const auto& beta = Input(BIAS);
    const int ndim = X.dim();
    // Both storage orders need a batch axis and a channel axis.
    CAFFE_ENFORCE_GE(
        ndim, 2, "InstanceNorm expects X to have at least 2 dimensions.");
    const int64_t N = X.dim(0);
    const int64_t C = order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(ndim - 1);
    const int64_t NxC = N * C;
    // An empty batch (N == 0) or C == 0 would otherwise be an integer
    // division by zero here.
    const int64_t HxW = NxC == 0 ? 0 : X.numel() / NxC;
    CAFFE_ENFORCE_EQ(gamma.numel(), C);
    CAFFE_ENFORCE_EQ(beta.numel(), C);
    auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
    const T* X_data = X.template data<T>();
    const T* gamma_data = gamma.template data<T>();
    const T* beta_data = beta.template data<T>();
    T* Y_data = Y->template mutable_data<T>();
    T* mean_data = nullptr;
    T* rstd_data = nullptr;
    if (OutputSize() >= 2) {
      auto* mean = Output(MEAN, {N, C}, at::dtype<T>());
      mean_data = mean->template mutable_data<T>();
    } else {
      ReinitializeTensor(
          &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
      mean_data = mean_.template mutable_data<T>();
    }
    if (OutputSize() >= 3) {
      auto* rstd = Output(RSTD, {N, C}, at::dtype<T>());
      rstd_data = rstd->template mutable_data<T>();
    } else {
      ReinitializeTensor(
          &rstd_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
      rstd_data = rstd_.template mutable_data<T>();
    }
    if (NxC == 0) {
      // Y, mean and rstd are all empty (they have 0 elements), so there is
      // nothing to compute.
      return true;
    }
    switch (order_) {
      case StorageOrder::NCHW: {
        return RunOnDeviceWithOrderNCHW(
            N,
            C,
            HxW,
            X_data,
            gamma_data,
            beta_data,
            Y_data,
            mean_data,
            rstd_data);
      }
      case StorageOrder::NHWC: {
        return RunOnDeviceWithOrderNHWC(
            N,
            C,
            HxW,
            X_data,
            gamma_data,
            beta_data,
            Y_data,
            mean_data,
            rstd_data);
      }
      default: {
        CAFFE_THROW("Unknown storage order: ", order_);
      }
    }
  }

 private:
  // Per-order kernels; their (per-Context) definitions are not in this
  // header.
  bool RunOnDeviceWithOrderNCHW(
      int64_t N,
      int64_t C,
      int64_t HxW,
      const T* X,
      const T* gamma,
      const T* beta,
      T* Y,
      T* mean,
      T* rstd);

  bool RunOnDeviceWithOrderNHWC(
      int64_t N,
      int64_t C,
      int64_t HxW,
      const T* X,
      const T* gamma,
      const T* beta,
      T* Y,
      T* mean,
      T* rstd);

  const float epsilon_;
  const StorageOrder order_;

  // Scratch for the {N, C} statistics when they are not requested as outputs.
  Tensor mean_;
  Tensor rstd_;
  // Not referenced in this header; presumably scratch used by the
  // device-specific implementations. Verify before removing.
  Tensor scale_;
  Tensor bias_;

  INPUT_TAGS(INPUT, SCALE, BIAS);
  OUTPUT_TAGS(OUTPUT, MEAN, RSTD);
};

template <typename T, class Context>
class InstanceNormGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Operator arguments:
  //   epsilon: float, default 1e-5; must be nonnegative.
  //   order:   storage order of X, "NCHW" (default) or "NHWC".
  template <class... Args>
  explicit InstanceNormGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<string>("order", "NCHW"))) {
    CAFFE_ENFORCE_GE(epsilon_, 0, "Must pass a nonnegative epsilon.");
    CAFFE_ENFORCE_NE(
        order_,
        StorageOrder::UNKNOWN,
        "order should be either \"NCHW\" or \"NHWC\".");
  }

  // Inputs:  X, gamma (C elements), beta, dY (same size as X), and
  //          optionally the forward op's mean and rstd ({N, C} each).
  //          The saved statistics are used only when all six inputs are
  //          present. With 4 or 5 inputs, both are recomputed from X via
  //          ComputeMoments.
  // Outputs: dX (shape of X), dgamma and dbeta (shape of gamma).
  bool RunOnDevice() {
    // Validate the input count before indexing any optional input.
    CAFFE_ENFORCE_GE(InputSize(), 4);
    CAFFE_ENFORCE_LE(InputSize(), 6);
    const auto& X = Input(INPUT);
    const auto& gamma = Input(SCALE);
    const auto& dY = Input(OUTPUT_GRAD);
    const int ndim = X.dim();
    // Both storage orders need a batch axis and a channel axis.
    CAFFE_ENFORCE_GE(
        ndim,
        2,
        "InstanceNormGradient expects X to have at least 2 dimensions.");
    const int64_t N = X.dim(0);
    const int64_t C = order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(ndim - 1);
    const int64_t NxC = N * C;
    CAFFE_ENFORCE_EQ(gamma.numel(), C);
    // dY is traversed with X's layout, so it must have exactly as many
    // elements; otherwise the kernels would read out of bounds.
    CAFFE_ENFORCE_EQ(dY.numel(), X.numel());

    if (NxC == 0) {
      // Empty batch (N == 0) or no channels (C == 0): dX is empty, and
      // dgamma/dbeta are sums over zero instances, i.e. all zeros. Returning
      // here also avoids the integer division by zero when computing HxW.
      auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
      auto* dgamma = Output(SCALE_GRAD, gamma.sizes(), at::dtype<T>());
      auto* dbeta = Output(BIAS_GRAD, gamma.sizes(), at::dtype<T>());
      dX->template mutable_data<T>();
      math::Set<T, Context>(
          C, T(0), dgamma->template mutable_data<T>(), &context_);
      math::Set<T, Context>(
          C, T(0), dbeta->template mutable_data<T>(), &context_);
      return true;
    }

    const int64_t HxW = X.numel() / NxC;
    const T* dY_data = dY.template data<T>();
    const T* X_data = X.template data<T>();
    const T* gamma_data = gamma.template data<T>();
    const T* mean_data = nullptr;
    const T* rstd_data = nullptr;
    if (InputSize() == 6) {
      const auto& mean = Input(MEAN);
      const auto& rstd = Input(RSTD);
      // The saved statistics are indexed as {N, C}; reject mismatched
      // buffers instead of reading past their end.
      CAFFE_ENFORCE_EQ(mean.numel(), NxC);
      CAFFE_ENFORCE_EQ(rstd.numel(), NxC);
      mean_data = mean.template data<T>();
      rstd_data = rstd.template data<T>();
    } else {
      ReinitializeTensor(
          &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
      ReinitializeTensor(
          &rstd_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
      ComputeMoments(
          N,
          C,
          HxW,
          X_data,
          mean_.template mutable_data<T>(),
          rstd_.template mutable_data<T>());
      mean_data = mean_.template data<T>();
      rstd_data = rstd_.template data<T>();
    }

    auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
    auto* dgamma = Output(SCALE_GRAD, gamma.sizes(), at::dtype<T>());
    auto* dbeta = Output(BIAS_GRAD, gamma.sizes(), at::dtype<T>());
    T* dX_data = dX->template mutable_data<T>();
    T* dgamma_data = dgamma->template mutable_data<T>();
    T* dbeta_data = dbeta->template mutable_data<T>();

    switch (order_) {
      case StorageOrder::NCHW: {
        return RunOnDeviceWithOrderNCHW(
            N,
            C,
            HxW,
            dY_data,
            X_data,
            mean_data,
            rstd_data,
            gamma_data,
            dX_data,
            dgamma_data,
            dbeta_data);
      }
      case StorageOrder::NHWC: {
        return RunOnDeviceWithOrderNHWC(
            N,
            C,
            HxW,
            dY_data,
            X_data,
            mean_data,
            rstd_data,
            gamma_data,
            dX_data,
            dgamma_data,
            dbeta_data);
      }
      default: {
        CAFFE_THROW("Unknown storage order: ", order_);
      }
    }
  }

 private:
  // Fills the {N, C} mean and rstd buffers from X when the forward op's
  // statistics were not passed in. Definition is not in this header.
  void ComputeMoments(
      int64_t N,
      int64_t C,
      int64_t HxW,
      const T* X,
      T* mean,
      T* rstd);

  // Per-order gradient kernels; their (per-Context) definitions are not in
  // this header.
  bool RunOnDeviceWithOrderNCHW(
      int64_t N,
      int64_t C,
      int64_t HxW,
      const T* dY,
      const T* X,
      const T* mean,
      const T* rstd,
      const T* gamma,
      T* dX,
      T* dgamma,
      T* dbeta);

  bool RunOnDeviceWithOrderNHWC(
      int64_t N,
      int64_t C,
      int64_t HxW,
      const T* dY,
      const T* X,
      const T* mean,
      const T* rstd,
      const T* gamma,
      T* dX,
      T* dgamma,
      T* dbeta);

  const float epsilon_;
  const StorageOrder order_;

  // Scratch for recomputed {N, C} statistics when they are not provided.
  Tensor mean_;
  Tensor rstd_;
  // Not referenced in this header; presumably scratch used by the
  // device-specific implementations. Verify before removing.
  Tensor ds_;
  Tensor db_;
  Tensor c1_;
  Tensor c2_;
  Tensor c3_;
  Tensor ones_;

  INPUT_TAGS(INPUT, SCALE, BIAS, OUTPUT_GRAD, MEAN, RSTD);
  OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_