#ifndef CAFFE2_OPERATORS_POOL_OP_H_
#define CAFFE2_OPERATORS_POOL_OP_H_
#include <vector>
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_pool_op_base.h"
namespace caffe2 {
// Pooling operator (average/max selected by Functor) supporting both NCHW
// and NHWC layouts, with optional global pooling over the entire spatial
// extent. Dilation is validated to be 1; pads must be smaller than the
// kernel in each spatial dimension.
template <typename T, class Context, class Functor>
class PoolOp final : public ConvPoolOpBase<Context> {
 public:
  USE_CONV_POOL_BASE_FUNCTIONS(Context);

  template <class... Args>
  explicit PoolOp(Args&&... args)
      : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) {
    const int kernel_size = kernel_.size();
    // The pooling functors do not implement dilated windows.
    for (int i = 0; i < kernel_size; ++i) {
      CAFFE_ENFORCE_EQ(
          dilation_[i], 1, "Pooling op does not support dilation right now.");
    }
    if (!global_pooling_) {
      // pads_ stores leading pads in [0, kernel_size) and trailing pads in
      // [kernel_size, 2 * kernel_size); each must be smaller than the kernel
      // so every window overlaps at least one real input element.
      for (int i = 0; i < kernel_size; ++i) {
        CAFFE_ENFORCE(
            pads_[i] < kernel_[i] && pads_[i + kernel_size] < kernel_[i],
            "Pad should be smaller than kernel.");
      }
    }
  }

  ~PoolOp() = default;

  // Forward pass for NCHW input: Y = Pool(X).
  bool RunOnDeviceWithOrderNCHW() override {
    const auto& X = Input(0);
    auto* Y = Output(0);
    const int N = X.dim32(0);
    const int C = X.dim32(1);
    ConvPoolOpBase<Context>::SetOutputSize(X, Y, C);
    const T* X_data = X.template data<T>();
    T* Y_data = Y->template mutable_data<T>();
    if (N == 0) {
      // Output has already been sized (and allocated) above; nothing to
      // compute for an empty batch.
      return true;
    }
    if (global_pooling_) {
      const int HxW = X.numel() / (N * C);
      return functor_.template GlobalPoolingForward<T, StorageOrder::NCHW>(
          N, C, HxW, X_data, Y_data, &context_);
    }
    const std::vector<int> X_HW_dims = GetDims(X);
    const std::vector<int> Y_HW_dims = GetDims(*Y);
    // Reuse the pointers fetched above rather than calling data() /
    // mutable_data() a second time.
    return functor_.template Forward<T, StorageOrder::NCHW>(
        N,
        C,
        X_HW_dims,
        Y_HW_dims,
        kernel_,
        dilation_,
        stride_,
        pads_,
        X_data,
        Y_data,
        &context_);
  }

  // Forward pass for NHWC input: Y = Pool(X). Channels are the last axis.
  bool RunOnDeviceWithOrderNHWC() override {
    const auto& X = Input(0);
    auto* Y = Output(0);
    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = X.dim32(ndim - 1);
    ConvPoolOpBase<Context>::SetOutputSize(X, Y, C);
    const T* X_data = X.template data<T>();
    T* Y_data = Y->template mutable_data<T>();
    if (N == 0) {
      // Output has already been sized (and allocated) above; nothing to
      // compute for an empty batch.
      return true;
    }
    if (global_pooling_) {
      const int HxW = X.numel() / (N * C);
      return functor_.template GlobalPoolingForward<T, StorageOrder::NHWC>(
          N, C, HxW, X_data, Y_data, &context_);
    }
    const std::vector<int> X_HW_dims = GetDims(X);
    const std::vector<int> Y_HW_dims = GetDims(*Y);
    // Reuse the pointers fetched above rather than calling data() /
    // mutable_data() a second time.
    return functor_.template Forward<T, StorageOrder::NHWC>(
        N,
        C,
        X_HW_dims,
        Y_HW_dims,
        kernel_,
        dilation_,
        stride_,
        pads_,
        X_data,
        Y_data,
        &context_);
  }

 private:
  const Functor functor_;
};
// Gradient of the pooling operator. Inputs: X (forward input), Y (forward
// output), dY (gradient w.r.t. Y). Output: dX, shaped like X.
template <typename T, class Context, class Functor>
class PoolGradientOp final : public ConvPoolOpBase<Context> {
 public:
  USE_CONV_POOL_BASE_FUNCTIONS(Context);

  template <class... Args>
  explicit PoolGradientOp(Args&&... args)
      : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) {}

  ~PoolGradientOp() = default;

  // Backward pass for NCHW layout: dX = PoolBackward(dY, X, Y).
  bool RunOnDeviceWithOrderNCHW() override {
    const auto& X = Input(0);
    const auto& Y = Input(1);
    const auto& dY = Input(2);
    auto* dX = Output(0, X.sizes(), at::dtype<T>());
    const std::vector<int> x_hw_dims = GetDims(X);
    const std::vector<int> y_hw_dims = GetDims(Y);
    ConvPoolOpBase<Context>::ComputePads(x_hw_dims);
    const int batch = X.dim32(0);
    const int channels = X.dim32(1);
    const T* dY_ptr = dY.template data<T>();
    const T* X_ptr = X.template data<T>();
    const T* Y_ptr = Y.template data<T>();
    // Allocate dX even for an empty batch so the output is well-formed.
    T* dX_ptr = dX->template mutable_data<T>();
    if (batch == 0) {
      return true;
    }
    if (global_pooling_) {
      const int spatial = X.numel() / (batch * channels);
      return functor_.template GlobalPoolingBackward<T, StorageOrder::NCHW>(
          batch, channels, spatial, dY_ptr, X_ptr, Y_ptr, dX_ptr, &context_);
    }
    return functor_.template Backward<T, StorageOrder::NCHW>(
        batch,
        channels,
        x_hw_dims,
        y_hw_dims,
        kernel_,
        dilation_,
        stride_,
        pads_,
        dY_ptr,
        X_ptr,
        Y_ptr,
        dX_ptr,
        &context_);
  }

  // Backward pass for NHWC layout: channels are the last axis of X.
  bool RunOnDeviceWithOrderNHWC() override {
    const auto& X = Input(0);
    const auto& Y = Input(1);
    const auto& dY = Input(2);
    auto* dX = Output(0, X.sizes(), at::dtype<T>());
    const std::vector<int> x_hw_dims = GetDims(X);
    const std::vector<int> y_hw_dims = GetDims(Y);
    ConvPoolOpBase<Context>::ComputePads(x_hw_dims);
    const int batch = X.dim32(0);
    const int channels = X.dim32(X.dim() - 1);
    const T* dY_ptr = dY.template data<T>();
    const T* X_ptr = X.template data<T>();
    const T* Y_ptr = Y.template data<T>();
    // Allocate dX even for an empty batch so the output is well-formed.
    T* dX_ptr = dX->template mutable_data<T>();
    if (batch == 0) {
      return true;
    }
    if (global_pooling_) {
      const int spatial = X.numel() / (batch * channels);
      return functor_.template GlobalPoolingBackward<T, StorageOrder::NHWC>(
          batch, channels, spatial, dY_ptr, X_ptr, Y_ptr, dX_ptr, &context_);
    }
    return functor_.template Backward<T, StorageOrder::NHWC>(
        batch,
        channels,
        x_hw_dims,
        y_hw_dims,
        kernel_,
        dilation_,
        stride_,
        pads_,
        dY_ptr,
        X_ptr,
        Y_ptr,
        dX_ptr,
        &context_);
  }

 private:
  const Functor functor_;
};
// Average-pooling kernels for PoolOp / PoolGradientOp. Only declarations
// live here; the definitions are presumably provided per Context in the
// corresponding .cc/.cu files — confirm against the implementation files.
template <class Context>
struct AveragePoolFunctor {
// Reads the "count_include_pad" operator argument (default false): whether
// padded positions are counted in the averaging denominator.
explicit AveragePoolFunctor(const OperatorBase& op)
: count_include_pad(
op.template GetSingleArgument<bool>("count_include_pad", false)) {}
// Forward pass when pooling over the whole HxW spatial extent of each
// (n, c) plane. Returns true on success.
template <typename T, StorageOrder kOrder>
bool GlobalPoolingForward(
int N,
int C,
int HxW,
const T* X,
T* Y,
Context* context) const;
// Windowed forward pass. X_dims/Y_dims are the spatial dims only; pads
// holds leading then trailing pads (2 * kernel.size() entries).
template <typename T, StorageOrder kOrder>
bool Forward(
int N,
int C,
const std::vector<int>& X_dims,
const std::vector<int>& Y_dims,
const std::vector<int>& kernel,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& pads,
const T* X,
T* Y,
Context* context) const;
// Backward pass of the global-pooling forward: fills dX from dY, X, Y.
template <typename T, StorageOrder kOrder>
bool GlobalPoolingBackward(
int N,
int C,
int HxW,
const T* dY,
const T* X,
const T* Y,
T* dX,
Context* context) const;
// Backward pass of the windowed forward: fills dX from dY, X, Y.
template <typename T, StorageOrder kOrder>
bool Backward(
int N,
int C,
const std::vector<int>& X_dims,
const std::vector<int>& Y_dims,
const std::vector<int>& kernel,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& pads,
const T* dY,
const T* X,
const T* Y,
T* dX,
Context* context) const;
// Whether padded positions contribute to the averaging denominator.
const bool count_include_pad;
// Scratch tensor on the functor's device; usage is defined in the
// out-of-line kernels — see the implementation files.
Tensor ones{Context::GetDeviceType()};
};
// Max-pooling kernels for PoolOp / PoolGradientOp. Stateless (the operator
// argument is ignored). Only declarations live here; definitions are
// presumably per-Context in the implementation files — confirm there.
template <class Context>
struct MaxPoolFunctor {
explicit MaxPoolFunctor(const OperatorBase& /* op */) {}
// Forward pass when pooling over the whole HxW spatial extent of each
// (n, c) plane. Returns true on success.
template <typename T, StorageOrder kOrder>
bool GlobalPoolingForward(
int N,
int C,
int HxW,
const T* X,
T* Y,
Context* context) const;
// Windowed forward pass. X_dims/Y_dims are the spatial dims only; pads
// holds leading then trailing pads (2 * kernel.size() entries).
template <typename T, StorageOrder kOrder>
bool Forward(
int N,
int C,
const std::vector<int>& X_dims,
const std::vector<int>& Y_dims,
const std::vector<int>& kernel,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& pads,
const T* X,
T* Y,
Context* context) const;
// Backward pass of the global-pooling forward: fills dX from dY, X, Y
// (both X and Y are needed to locate the max positions).
template <typename T, StorageOrder kOrder>
bool GlobalPoolingBackward(
int N,
int C,
int HxW,
const T* dY,
const T* X,
const T* Y,
T* dX,
Context* context) const;
// Backward pass of the windowed forward: fills dX from dY, X, Y.
template <typename T, StorageOrder kOrder>
bool Backward(
int N,
int C,
const std::vector<int>& X_dims,
const std::vector<int>& Y_dims,
const std::vector<int>& kernel,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& pads,
const T* dY,
const T* X,
const T* Y,
T* dX,
Context* context) const;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_POOL_OP_H_