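// NOTE: the lines below are package-index page text that was captured together
// with this header; they are kept as comments so the file remains valid C++.
//
// Learn more >> Push, build, and install: RubyGems, npm packages, Python
// packages, Maven artifacts, PHP packages, Go Modules, Bower components,
// Debian packages, RPM packages, NuGet packages
//
// neilisaac / torch (python)
//
// Repository URL to install this package:
//
// Version: 1.8.0
//
// Path: include/caffe2/operators/pool_op.h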

#ifndef CAFFE2_OPERATORS_POOL_OP_H_
#define CAFFE2_OPERATORS_POOL_OP_H_

#include <vector>

#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_pool_op_base.h"

namespace caffe2 {

// Forward pooling operator.
//
// Template parameters:
//   T       - element type of the input and output tensors.
//   Context - device context type, forwarded to ConvPoolOpBase<Context>.
//   Functor - supplies the pooling kernels. It is constructed from this
//             operator and must provide the member templates
//             GlobalPoolingForward<T, StorageOrder> and
//             Forward<T, StorageOrder> (see AveragePoolFunctor and
//             MaxPoolFunctor below for the expected signatures).
//
// Input(0) is the tensor to pool and Output(0) receives the result.
template <typename T, class Context, class Functor>
class PoolOp final : public ConvPoolOpBase<Context> {
 public:
  USE_CONV_POOL_BASE_FUNCTIONS(Context);

  // Forwards all arguments to ConvPoolOpBase and validates the configuration:
  //   * every dilation must be 1 (dilated pooling is rejected);
  //   * unless global pooling is requested, both pad entries of each spatial
  //     dimension (pads_[i] and pads_[i + kernel_size]) must be strictly
  //     smaller than that dimension's kernel size.
  // A violated condition is reported through CAFFE_ENFORCE*.
  template <class... Args>
  explicit PoolOp(Args&&... args)
      : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) {
    // kernel_.size() is unsigned; make the narrowing to int explicit.
    const int kernel_size = static_cast<int>(kernel_.size());
    for (int i = 0; i < kernel_size; ++i) {
      CAFFE_ENFORCE_EQ(
          dilation_[i], 1, "Pooling op does not support dilation right now.");
    }
    if (!global_pooling_) {
      for (int i = 0; i < kernel_size; ++i) {
        CAFFE_ENFORCE(
            pads_[i] < kernel_[i] && pads_[i + kernel_size] < kernel_[i],
            "Pad should be smaller than kernel.");
      }
    }
  }

  ~PoolOp() = default;

  // Runs forward pooling with the channel count read from dimension 1 of the
  // input. Returns the functor's result, or true without invoking the functor
  // when the input has no batches or no channels.
  bool RunOnDeviceWithOrderNCHW() override {
    const auto& X = Input(0);
    auto* Y = Output(0);
    const int N = X.dim32(0);
    const int C = X.dim32(1);
    ConvPoolOpBase<Context>::SetOutputSize(X, Y, C);
    // The data pointers are requested before the early return, so
    // mutable_data<T>() is still invoked on Y for empty inputs.
    const T* X_data = X.template data<T>();
    T* Y_data = Y->template mutable_data<T>();
    // Nothing to compute for an empty batch or zero channels. The C == 0
    // check also keeps the division by (N * C) below from dividing by zero.
    if (N == 0 || C == 0) {
      return true;
    }
    if (global_pooling_) {
      // Number of elements per (batch, channel) pair.
      const int HxW = X.numel() / (N * C);
      return functor_.template GlobalPoolingForward<T, StorageOrder::NCHW>(
          N, C, HxW, X_data, Y_data, &context_);
    }
    const std::vector<int> X_HW_dims = GetDims(X);
    const std::vector<int> Y_HW_dims = GetDims(*Y);
    return functor_.template Forward<T, StorageOrder::NCHW>(
        N,
        C,
        X_HW_dims,
        Y_HW_dims,
        kernel_,
        dilation_,
        stride_,
        pads_,
        X_data,
        Y_data,
        &context_);
  }

  // Runs forward pooling with the channel count read from the last dimension
  // of the input. Returns the functor's result, or true without invoking the
  // functor when the input has no batches or no channels.
  bool RunOnDeviceWithOrderNHWC() override {
    const auto& X = Input(0);
    auto* Y = Output(0);
    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = X.dim32(ndim - 1);
    ConvPoolOpBase<Context>::SetOutputSize(X, Y, C);
    // The data pointers are requested before the early return, so
    // mutable_data<T>() is still invoked on Y for empty inputs.
    const T* X_data = X.template data<T>();
    T* Y_data = Y->template mutable_data<T>();
    // Nothing to compute for an empty batch or zero channels. The C == 0
    // check also keeps the division by (N * C) below from dividing by zero.
    if (N == 0 || C == 0) {
      return true;
    }
    if (global_pooling_) {
      // Number of elements per (batch, channel) pair.
      const int HxW = X.numel() / (N * C);
      return functor_.template GlobalPoolingForward<T, StorageOrder::NHWC>(
          N, C, HxW, X_data, Y_data, &context_);
    }
    const std::vector<int> X_HW_dims = GetDims(X);
    const std::vector<int> Y_HW_dims = GetDims(*Y);
    return functor_.template Forward<T, StorageOrder::NHWC>(
        N,
        C,
        X_HW_dims,
        Y_HW_dims,
        kernel_,
        dilation_,
        stride_,
        pads_,
        X_data,
        Y_data,
        &context_);
  }

 private:
  // Pooling kernels; constructed from *this in the constructor.
  const Functor functor_;
};

// Backward (gradient) pooling operator.
//
// Template parameters:
//   T       - element type of all tensors involved.
//   Context - device context type, forwarded to ConvPoolOpBase<Context>.
//   Functor - supplies the pooling kernels. It is constructed from this
//             operator and must provide the member templates
//             GlobalPoolingBackward<T, StorageOrder> and
//             Backward<T, StorageOrder> (see AveragePoolFunctor and
//             MaxPoolFunctor below for the expected signatures).
//
// Inputs:  0 = X, 1 = Y, 2 = dY (passed to the functor in the order
//          dY, X, Y).
// Output:  0 = dX, created with the same sizes as X and element type T.
template <typename T, class Context, class Functor>
class PoolGradientOp final : public ConvPoolOpBase<Context> {
 public:
  USE_CONV_POOL_BASE_FUNCTIONS(Context);

  // Forwards all arguments to ConvPoolOpBase and builds the functor from
  // this operator. No additional validation is performed here.
  template <class... Args>
  explicit PoolGradientOp(Args&&... args)
      : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) {}

  ~PoolGradientOp() = default;

  // Runs the backward pass with the channel count read from dimension 1 of X.
  // Returns the functor's result, or true without invoking the functor when X
  // has no batches or no channels.
  bool RunOnDeviceWithOrderNCHW() override {
    const auto& X = Input(0);
    const auto& Y = Input(1);
    const auto& dY = Input(2);
    auto* dX = Output(0, X.sizes(), at::dtype<T>());
    const int N = X.dim32(0);
    const int C = X.dim32(1);
    const std::vector<int> X_HW_dims = GetDims(X);
    const std::vector<int> Y_HW_dims = GetDims(Y);
    ConvPoolOpBase<Context>::ComputePads(X_HW_dims);
    // The data pointers are requested before the early return, so
    // mutable_data<T>() is still invoked on dX for empty inputs.
    const T* dY_data = dY.template data<T>();
    const T* X_data = X.template data<T>();
    const T* Y_data = Y.template data<T>();
    T* dX_data = dX->template mutable_data<T>();
    // Nothing to compute for an empty batch or zero channels. The C == 0
    // check also keeps the division by (N * C) below from dividing by zero.
    if (N == 0 || C == 0) {
      return true;
    }
    if (global_pooling_) {
      // Number of elements per (batch, channel) pair.
      const int HxW = X.numel() / (N * C);
      return functor_.template GlobalPoolingBackward<T, StorageOrder::NCHW>(
          N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
    }
    return functor_.template Backward<T, StorageOrder::NCHW>(
        N,
        C,
        X_HW_dims,
        Y_HW_dims,
        kernel_,
        dilation_,
        stride_,
        pads_,
        dY_data,
        X_data,
        Y_data,
        dX_data,
        &context_);
  }

  // Runs the backward pass with the channel count read from the last
  // dimension of X. Returns the functor's result, or true without invoking
  // the functor when X has no batches or no channels.
  bool RunOnDeviceWithOrderNHWC() override {
    const auto& X = Input(0);
    const auto& Y = Input(1);
    const auto& dY = Input(2);
    auto* dX = Output(0, X.sizes(), at::dtype<T>());
    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = X.dim32(ndim - 1);
    const std::vector<int> X_HW_dims = GetDims(X);
    const std::vector<int> Y_HW_dims = GetDims(Y);
    ConvPoolOpBase<Context>::ComputePads(X_HW_dims);
    // The data pointers are requested before the early return, so
    // mutable_data<T>() is still invoked on dX for empty inputs.
    const T* dY_data = dY.template data<T>();
    const T* X_data = X.template data<T>();
    const T* Y_data = Y.template data<T>();
    T* dX_data = dX->template mutable_data<T>();
    // Nothing to compute for an empty batch or zero channels. The C == 0
    // check also keeps the division by (N * C) below from dividing by zero.
    if (N == 0 || C == 0) {
      return true;
    }
    if (global_pooling_) {
      // Number of elements per (batch, channel) pair.
      const int HxW = X.numel() / (N * C);
      return functor_.template GlobalPoolingBackward<T, StorageOrder::NHWC>(
          N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
    }
    return functor_.template Backward<T, StorageOrder::NHWC>(
        N,
        C,
        X_HW_dims,
        Y_HW_dims,
        kernel_,
        dilation_,
        stride_,
        pads_,
        dY_data,
        X_data,
        Y_data,
        dX_data,
        &context_);
  }

 private:
  // Pooling kernels; constructed from *this in the constructor.
  const Functor functor_;
};

// Average-pooling kernel set, usable as the `Functor` template argument of
// PoolOp and PoolGradientOp in this header.
//
// Only declarations appear here; the member templates are not defined in this
// header. Every method returns a bool that the calling operator returns
// unchanged from its RunOnDevice* method.
//
// Argument meaning, as supplied by PoolOp / PoolGradientOp in this file:
//   N, C          - X.dim32(0) and the channel dimension of X.
//   HxW           - X.numel() / (N * C).
//   X_dims/Y_dims - GetDims(X) and GetDims(Y).
//   kernel, dilation, stride, pads
//                 - the operator's kernel_, dilation_, stride_ and pads_.
//   X, Y, dY, dX  - raw data pointers of the corresponding tensors; Y (forward)
//                   and dX (backward) are the writable outputs.
//   context       - pointer to the operator's context_.
//   kOrder        - StorageOrder::NCHW or StorageOrder::NHWC, matching the
//                   RunOnDeviceWithOrder* method that made the call.
template <class Context>
struct AveragePoolFunctor {
  // Reads the boolean operator argument "count_include_pad" from `op`,
  // defaulting to false when it is absent.
  explicit AveragePoolFunctor(const OperatorBase& op)
      : count_include_pad(
            op.template GetSingleArgument<bool>("count_include_pad", false)) {}

  // Forward pass used when the operator's global_pooling_ flag is set.
  template <typename T, StorageOrder kOrder>
  bool GlobalPoolingForward(
      int N,
      int C,
      int HxW,
      const T* X,
      T* Y,
      Context* context) const;

  // Forward pass for windowed (non-global) pooling.
  template <typename T, StorageOrder kOrder>
  bool Forward(
      int N,
      int C,
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const std::vector<int>& kernel,
      const std::vector<int>& dilation,
      const std::vector<int>& stride,
      const std::vector<int>& pads,
      const T* X,
      T* Y,
      Context* context) const;

  // Backward pass used when the operator's global_pooling_ flag is set.
  template <typename T, StorageOrder kOrder>
  bool GlobalPoolingBackward(
      int N,
      int C,
      int HxW,
      const T* dY,
      const T* X,
      const T* Y,
      T* dX,
      Context* context) const;

  // Backward pass for windowed (non-global) pooling.
  template <typename T, StorageOrder kOrder>
  bool Backward(
      int N,
      int C,
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const std::vector<int>& kernel,
      const std::vector<int>& dilation,
      const std::vector<int>& stride,
      const std::vector<int>& pads,
      const T* dY,
      const T* X,
      const T* Y,
      T* dX,
      Context* context) const;

  // Value of the "count_include_pad" operator argument.
  // NOTE(review): presumably selects whether padded positions count toward the
  // averaging divisor -- confirm against the kernel implementations.
  const bool count_include_pad;
  // Tensor created on the device type reported by Context::GetDeviceType().
  // NOTE(review): its use is not visible in this header; presumably scratch
  // storage for the kernel implementations -- confirm.
  Tensor ones{Context::GetDeviceType()};
};

// Max-pooling kernel set, usable as the `Functor` template argument of PoolOp
// and PoolGradientOp in this header.
//
// Only declarations appear here; the member templates are not defined in this
// header. Every method returns a bool that the calling operator returns
// unchanged from its RunOnDevice* method.
//
// Argument meaning, as supplied by PoolOp / PoolGradientOp in this file:
//   N, C          - X.dim32(0) and the channel dimension of X.
//   HxW           - X.numel() / (N * C).
//   X_dims/Y_dims - GetDims(X) and GetDims(Y).
//   kernel, dilation, stride, pads
//                 - the operator's kernel_, dilation_, stride_ and pads_.
//   X, Y, dY, dX  - raw data pointers of the corresponding tensors; Y (forward)
//                   and dX (backward) are the writable outputs.
//   context       - pointer to the operator's context_.
//   kOrder        - StorageOrder::NCHW or StorageOrder::NHWC, matching the
//                   RunOnDeviceWithOrder* method that made the call.
template <class Context>
struct MaxPoolFunctor {
  // The operator is accepted so this type can be constructed the same way as
  // AveragePoolFunctor, but it is not used: this functor holds no state.
  explicit MaxPoolFunctor(const OperatorBase& /* op */) {}

  // Forward pass used when the operator's global_pooling_ flag is set.
  template <typename T, StorageOrder kOrder>
  bool GlobalPoolingForward(
      int N,
      int C,
      int HxW,
      const T* X,
      T* Y,
      Context* context) const;

  // Forward pass for windowed (non-global) pooling.
  template <typename T, StorageOrder kOrder>
  bool Forward(
      int N,
      int C,
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const std::vector<int>& kernel,
      const std::vector<int>& dilation,
      const std::vector<int>& stride,
      const std::vector<int>& pads,
      const T* X,
      T* Y,
      Context* context) const;

  // Backward pass used when the operator's global_pooling_ flag is set.
  template <typename T, StorageOrder kOrder>
  bool GlobalPoolingBackward(
      int N,
      int C,
      int HxW,
      const T* dY,
      const T* X,
      const T* Y,
      T* dX,
      Context* context) const;

  // Backward pass for windowed (non-global) pooling.
  template <typename T, StorageOrder kOrder>
  bool Backward(
      int N,
      int C,
      const std::vector<int>& X_dims,
      const std::vector<int>& Y_dims,
      const std::vector<int>& kernel,
      const std::vector<int>& dilation,
      const std::vector<int>& stride,
      const std::vector<int>& pads,
      const T* dY,
      const T* X,
      const T* Y,
      T* dX,
      Context* context) const;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_POOL_OP_H_