#ifndef CAFFE2_OPERATORS_IM2COL_OP_H_
#define CAFFE2_OPERATORS_IM2COL_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// Im2ColOp: rearranges sliding-window patches of a 4-D image tensor into a
// column buffer via math::Im2Col, processing one image of the batch at a time.
//
// Arguments (read in the constructor):
//   "pad"                   - padding applied to all four sides (default 0).
//   "kernel_h" / "kernel_w" - window size; each falls back to "kernel"
//                             (default 0, so a kernel size must be given).
//   "dilation_h" / "dilation_w" - fall back to "dilation" (default 1).
//   "stride_h" / "stride_w" - fall back to "stride" (default 1).
//   "order"                 - "NCHW" (default) or "NHWC".
//
// Output shape, with dkernel = dilation * (kernel - 1) + 1 and
// out = (in + 2 * pad - dkernel) / stride + 1 per spatial dimension:
//   NCHW: {N, C * kernel_h * kernel_w, out_h, out_w}
//   NHWC: {N, out_h, out_w, kernel_h * kernel_w * C}
template <typename T, class Context>
class Im2ColOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit Im2ColOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pad_(this->template GetSingleArgument<int>("pad", 0)),
        kernel_h_(this->template GetSingleArgument<int>(
            "kernel_h",
            this->template GetSingleArgument<int>("kernel", 0))),
        kernel_w_(this->template GetSingleArgument<int>(
            "kernel_w",
            this->template GetSingleArgument<int>("kernel", 0))),
        dilation_h_(this->template GetSingleArgument<int>(
            "dilation_h",
            this->template GetSingleArgument<int>("dilation", 1))),
        dilation_w_(this->template GetSingleArgument<int>(
            "dilation_w",
            this->template GetSingleArgument<int>("dilation", 1))),
        stride_h_(this->template GetSingleArgument<int>(
            "stride_h",
            this->template GetSingleArgument<int>("stride", 1))),
        stride_w_(this->template GetSingleArgument<int>(
            "stride_w",
            this->template GetSingleArgument<int>("stride", 1))),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<string>("order", "NCHW"))) {
    CAFFE_ENFORCE(kernel_h_ > 0);
    CAFFE_ENFORCE(kernel_w_ > 0);
    CAFFE_ENFORCE(dilation_h_ > 0);
    CAFFE_ENFORCE(dilation_w_ > 0);
    CAFFE_ENFORCE(stride_h_ > 0);
    CAFFE_ENFORCE(stride_w_ > 0);
    CAFFE_ENFORCE(pad_ >= 0);
  }

  bool RunOnDevice() override {
    const auto& X = Input(0);
    CAFFE_ENFORCE(4 == X.dim());
    int N = 0;
    int C = 0;
    int H = 0;
    int W = 0;
    switch (order_) {
      case StorageOrder::NCHW:
        N = X.dim32(0);
        C = X.dim32(1);
        H = X.dim32(2);
        W = X.dim32(3);
        break;
      case StorageOrder::NHWC:
        N = X.dim32(0);
        H = X.dim32(1);
        W = X.dim32(2);
        C = X.dim32(3);
        break;
      default:
        CAFFE_THROW("Unknown storage order: ", order_);
    }
    const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
    const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
    // The window slides over the *padded* image, so padding counts toward
    // the extent the dilated kernel must fit in. This also guarantees the
    // numerators below are non-negative, hence out_h, out_w >= 1.
    CAFFE_ENFORCE(H + 2 * pad_ >= dkernel_h);
    CAFFE_ENFORCE(W + 2 * pad_ >= dkernel_w);
    const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
    const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
    // Widened before multiplying so large C * kernel sizes cannot overflow.
    const int64_t col_channels =
        static_cast<int64_t>(C) * kernel_h_ * kernel_w_;
    switch (order_) {
      case StorageOrder::NCHW: {
        auto* Y = Output(
            0,
            std::vector<int64_t>{N, col_channels, out_h, out_w},
            at::dtype<T>());
        // Data pointers are taken only after Y has been allocated.
        Im2ColBatch<StorageOrder::NCHW>(
            N,
            C,
            H,
            W,
            out_h,
            out_w,
            X.template data<T>(),
            Y->template mutable_data<T>());
        break;
      }
      case StorageOrder::NHWC: {
        auto* Y = Output(
            0,
            std::vector<int64_t>{N, out_h, out_w, col_channels},
            at::dtype<T>());
        Im2ColBatch<StorageOrder::NHWC>(
            N,
            C,
            H,
            W,
            out_h,
            out_w,
            X.template data<T>(),
            Y->template mutable_data<T>());
        break;
      }
      default:
        CAFFE_THROW("Unknown storage order: ", order_);
    }
    return true;
  }

 private:
  // Calls math::Im2Col once per image. X_data holds N images of C*H*W
  // elements each and Y_data holds N column buffers of
  // C*kernel_h_*kernel_w_*out_h*out_w elements each. The per-image strides
  // are derived from the dimensions rather than numel() / N, so an empty
  // batch (N == 0) is a no-op instead of a division by zero.
  template <StorageOrder kOrder>
  void Im2ColBatch(
      const int N,
      const int C,
      const int H,
      const int W,
      const int out_h,
      const int out_w,
      const T* X_data,
      T* Y_data) {
    const int64_t x_stride = static_cast<int64_t>(C) * H * W;
    const int64_t y_stride =
        static_cast<int64_t>(C) * kernel_h_ * kernel_w_ * out_h * out_w;
    for (int n = 0; n < N; ++n) {
      math::Im2Col<T, Context, kOrder>(
          C,
          H,
          W,
          kernel_h_,
          kernel_w_,
          dilation_h_,
          dilation_w_,
          pad_,
          pad_,
          pad_,
          pad_,
          stride_h_,
          stride_w_,
          X_data + n * x_stride,
          Y_data + n * y_stride,
          &context_);
    }
  }

  int pad_; // same padding on top, left, bottom and right
  int kernel_h_;
  int kernel_w_;
  int dilation_h_;
  int dilation_w_;
  int stride_h_;
  int stride_w_;
  StorageOrder order_;
};
// Col2ImOp: the inverse rearrangement of Im2ColOp. It scatters a column buffer
// (input 0, "X") back into a 4-D image tensor whose shape is copied from
// input 1 ("Z"), processing one image of the batch at a time via math::Col2Im.
//
// Arguments are identical to Im2ColOp: "pad" (all four sides, default 0),
// "kernel_h"/"kernel_w" (fall back to "kernel", default 0, so a kernel size
// must be given), "dilation_h"/"dilation_w" (fall back to "dilation",
// default 1), "stride_h"/"stride_w" (fall back to "stride", default 1), and
// "order" ("NCHW" default, or "NHWC").
//
// X must contain exactly N * C * kernel_h * kernel_w * out_h * out_w
// elements, where out = (in + 2 * pad - dkernel) / stride + 1 and
// dkernel = dilation * (kernel - 1) + 1.
template <typename T, class Context>
class Col2ImOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit Col2ImOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pad_(this->template GetSingleArgument<int>("pad", 0)),
        kernel_h_(this->template GetSingleArgument<int>(
            "kernel_h",
            this->template GetSingleArgument<int>("kernel", 0))),
        kernel_w_(this->template GetSingleArgument<int>(
            "kernel_w",
            this->template GetSingleArgument<int>("kernel", 0))),
        dilation_h_(this->template GetSingleArgument<int>(
            "dilation_h",
            this->template GetSingleArgument<int>("dilation", 1))),
        dilation_w_(this->template GetSingleArgument<int>(
            "dilation_w",
            this->template GetSingleArgument<int>("dilation", 1))),
        stride_h_(this->template GetSingleArgument<int>(
            "stride_h",
            this->template GetSingleArgument<int>("stride", 1))),
        stride_w_(this->template GetSingleArgument<int>(
            "stride_w",
            this->template GetSingleArgument<int>("stride", 1))),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<string>("order", "NCHW"))) {
    CAFFE_ENFORCE(kernel_h_ > 0);
    CAFFE_ENFORCE(kernel_w_ > 0);
    CAFFE_ENFORCE(dilation_h_ > 0);
    CAFFE_ENFORCE(dilation_w_ > 0);
    CAFFE_ENFORCE(stride_h_ > 0);
    CAFFE_ENFORCE(stride_w_ > 0);
    CAFFE_ENFORCE(pad_ >= 0);
  }

  bool RunOnDevice() override {
    const auto& X = Input(0);
    // Z only supplies the output shape; its data is never read here.
    const auto& Z = Input(1);
    // Validate the shape source before allocating anything from it.
    CAFFE_ENFORCE(4 == Z.dim());
    auto* Y = Output(0, Z.sizes(), at::dtype<T>());
    int N = 0;
    int C = 0;
    int H = 0;
    int W = 0;
    switch (order_) {
      case StorageOrder::NCHW:
        N = Y->dim32(0);
        C = Y->dim32(1);
        H = Y->dim32(2);
        W = Y->dim32(3);
        break;
      case StorageOrder::NHWC:
        N = Y->dim32(0);
        H = Y->dim32(1);
        W = Y->dim32(2);
        C = Y->dim32(3);
        break;
      default:
        CAFFE_THROW("Unknown storage order: ", order_);
    }
    const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
    const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
    // The window slides over the *padded* image, so padding counts toward
    // the extent the dilated kernel must fit in. This also guarantees the
    // numerators below are non-negative, hence out_h, out_w >= 1.
    CAFFE_ENFORCE(H + 2 * pad_ >= dkernel_h);
    CAFFE_ENFORCE(W + 2 * pad_ >= dkernel_w);
    const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
    const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
    // Evaluated in 64 bits: the product can overflow int for large inputs.
    const int64_t col_channels =
        static_cast<int64_t>(C) * kernel_h_ * kernel_w_;
    CAFFE_ENFORCE(X.numel() == N * col_channels * out_h * out_w);
    switch (order_) {
      case StorageOrder::NCHW:
        Col2ImBatch<StorageOrder::NCHW>(
            N,
            C,
            H,
            W,
            out_h,
            out_w,
            X.template data<T>(),
            Y->template mutable_data<T>());
        break;
      case StorageOrder::NHWC:
        Col2ImBatch<StorageOrder::NHWC>(
            N,
            C,
            H,
            W,
            out_h,
            out_w,
            X.template data<T>(),
            Y->template mutable_data<T>());
        break;
      default:
        CAFFE_THROW("Unknown storage order: ", order_);
    }
    return true;
  }

 private:
  // Calls math::Col2Im once per image. X_data holds N column buffers of
  // C*kernel_h_*kernel_w_*out_h*out_w elements each and Y_data holds N images
  // of C*H*W elements each. Y is not cleared here; NOTE(review): this relies
  // on math::Col2Im fully initializing each output image. The strides are
  // derived from the dimensions rather than numel() / N, so an empty batch
  // (N == 0) is a no-op instead of a division by zero.
  template <StorageOrder kOrder>
  void Col2ImBatch(
      const int N,
      const int C,
      const int H,
      const int W,
      const int out_h,
      const int out_w,
      const T* X_data,
      T* Y_data) {
    const int64_t x_stride =
        static_cast<int64_t>(C) * kernel_h_ * kernel_w_ * out_h * out_w;
    const int64_t y_stride = static_cast<int64_t>(C) * H * W;
    for (int n = 0; n < N; ++n) {
      math::Col2Im<T, Context, kOrder>(
          C,
          H,
          W,
          kernel_h_,
          kernel_w_,
          dilation_h_,
          dilation_w_,
          pad_,
          pad_,
          pad_,
          pad_,
          stride_h_,
          stride_w_,
          X_data + n * x_stride,
          Y_data + n * y_stride,
          &context_);
    }
  }

  int pad_; // same padding on top, left, bottom and right
  int kernel_h_;
  int kernel_w_;
  int dilation_h_;
  int dilation_w_;
  int stride_h_;
  int stride_w_;
  StorageOrder order_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_IM2COL_OP_H_