// locally_connected_op_impl.h is the templated implementation of the
// locally_connected_op.h file.
#ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_
#define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_

#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/flags.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/operators/locally_connected_op.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
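// NCHW forward pass. A locally connected layer is a convolution whose filter
// weights are untied across output locations, so the filter carries the
// output image dimensions as extra leading axes.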
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHW() {
const auto& X = Input(INPUT);
const auto& filter = Input(FILTER);
auto* Y = Output(0);
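  // The filter holds a separate set of kernels for every output location:
  // [output_image_dims..., M, C / group, kernel_dims...].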
const int image_ndim = X.dim() - 2;
CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
lc_op_util::ShapeParams shape;
shape.N = X.dim32(0);
shape.C = X.dim32(1);
shape.M = filter.dim32(image_ndim);
CAFFE_ENFORCE(
shape.C == filter.dim32(image_ndim + 1) * group_,
"Locally Connected op: input channels does not match: "
"# of input channels ",
shape.C,
" is not equal to kernel channels * group:",
filter.dim32(image_ndim + 1),
"*",
group_);
CAFFE_ENFORCE_EQ(
shape.M % group_,
0,
"The number of output channels is not divisible by group.");
ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);
shape.input_image_size = GetDimsSize(X);
shape.output_image_size = GetDimsSize(*Y);
const std::vector<int> output_image_dims = GetDims(*Y);
for (int i = 0; i < image_ndim; ++i) {
CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
}
int kernel_dims_size = 1;
for (std::size_t i = 0; i < kernel_.size(); ++i) {
CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
kernel_dims_size *= kernel_[i];
}
shape.X_dims.assign(X.sizes().cbegin() + 1, X.sizes().cend());
shape.kernel_size = shape.C / group_ * kernel_dims_size;
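  // Set up the shapes of the column (im2col) buffer and of the transposed
  // buffers that let the per-location GEMMs run as one strided batched GEMM.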
lc_op_util::SetColumnBufferShape(
shape.N,
shape.kernel_size,
shape.output_image_size,
output_image_dims,
order_,
&shape.column_slice_dims,
&shape.column_dims,
&shape.column_transposed_dims,
&shape.column_axes);
lc_op_util::SetYBufferShape(
shape.N,
shape.M,
shape.output_image_size,
order_,
&shape.Y_dims,
&shape.Y_transposed_dims,
&shape.Y_axes);
const T* X_data = X.template data<T>();
const T* filter_data = filter.template data<T>();
const T* bias_data = nullptr;
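  // The bias, if present, is also untied: one value per output location and
  // output channel, laid out as [output_image_dims..., M].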
if (InputSize() == 3) {
const auto& bias = Input(BIAS);
CAFFE_ENFORCE_EQ(bias.dim(), image_ndim + 1);
for (int i = 0; i < image_ndim; ++i) {
CAFFE_ENFORCE_EQ(bias.dim32(i), output_image_dims[i]);
}
CAFFE_ENFORCE_EQ(bias.dim32(image_ndim), shape.M);
bias_data = bias.template data<T>();
ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
shape.N, &bias_multiplier_);
}
T* Y_data = Y->template mutable_data<T>();
RunOnDeviceWithOrderNCHWImpl(
shape,
X_data,
filter_data,
bias_data,
Y_data,
&column_buffer_,
&column_transposed_buffer_,
&Y_transposed_buffer_);
return true;
}
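
// NHWC forward pass. Only the 2D case is supported; the shape checks mirror
// the NCHW path with the channel dimension last.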
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWC() {
const auto& X = Input(INPUT);
const auto& filter = Input(FILTER);
auto* Y = Output(0);
CAFFE_ENFORCE_EQ(
kernel_.size(),
2,
"Only 2d locally connected op is supported for NHWC storage type.");
const int image_ndim = X.dim() - 2;
CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
lc_op_util::ShapeParams shape;
shape.N = X.dim32(0);
shape.C = X.dim32(3);
shape.X_dims = {X.dim32(1), X.dim32(2), X.dim32(3)};
shape.M = filter.dim32(image_ndim);
CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1), kernel_h());
CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 2), kernel_w());
CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 3), shape.C);
ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);
shape.input_image_size = GetDimsSize(X);
shape.output_image_size = GetDimsSize(*Y);
const std::vector<int> output_image_dims = GetDims(*Y);
for (int i = 0; i < image_ndim; ++i) {
CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
}
shape.kernel_size = kernel_h() * kernel_w() * shape.C;
lc_op_util::SetColumnBufferShape(
shape.N,
shape.kernel_size,
shape.output_image_size,
output_image_dims,
order_,
&shape.column_slice_dims,
&shape.column_dims,
&shape.column_transposed_dims,
&shape.column_axes);
lc_op_util::SetYBufferShape(
shape.N,
shape.M,
shape.output_image_size,
order_,
&shape.Y_dims,
&shape.Y_transposed_dims,
&shape.Y_axes);
const T* X_data = X.template data<T>();
const T* filter_data = filter.template data<T>();
const T* bias_data = nullptr;
if (InputSize() == 3) {
const auto& bias = Input(BIAS);
CAFFE_ENFORCE_EQ(bias.dim(), image_ndim + 1);
for (int i = 0; i < image_ndim; ++i) {
CAFFE_ENFORCE_EQ(bias.dim32(i), output_image_dims[i]);
}
CAFFE_ENFORCE_EQ(bias.dim32(image_ndim), shape.M);
bias_data = bias.template data<T>();
ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
shape.N, &bias_multiplier_);
}
T* Y_data = Y->template mutable_data<T>();
RunOnDeviceWithOrderNHWCImpl(
shape,
X_data,
filter_data,
bias_data,
Y_data,
&column_buffer_,
&column_transposed_buffer_,
&Y_transposed_buffer_);
return true;
}
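
// NCHW forward implementation: im2col each image (per group), transpose the
// column buffer so output locations become the batch dimension, run one
// strided batched GEMM against the per-location filters, add the bias, and
// transpose the result back to the NCHW output layout.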
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* bias_data,
T* Y_data,
Tensor* column_buffer,
Tensor* column_transposed_buffer,
Tensor* Y_transposed_buffer) {
const int input_stride = shape.C / group_ * shape.input_image_size;
const int column_stride = shape.kernel_size * shape.output_image_size;
column_buffer->Resize(shape.column_dims);
column_transposed_buffer->Resize(shape.column_transposed_dims);
Y_transposed_buffer->Resize(shape.Y_transposed_dims);
T* column_buffer_data = column_buffer->template mutable_data<T>();
T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();
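  // Expand every image (and group slice) into columns.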
for (int image_id = 0; image_id < shape.N; ++image_id) {
for (int group_id = 0; group_id < group_; ++group_id) {
if (kernel_.size() == 2) {
math::Im2Col<T, Context, StorageOrder::NCHW>(
shape.C / group_,
shape.X_dims[1],
shape.X_dims[2],
kernel_h(),
kernel_w(),
dilation_h(),
dilation_w(),
pad_t(),
pad_l(),
pad_b(),
pad_r(),
stride_h(),
stride_w(),
X_data + group_id * input_stride,
column_buffer_data + group_id * column_stride,
&context_);
} else {
math::Im2ColNd<T, Context, StorageOrder::NCHW>(
kernel_.size(),
shape.C * shape.input_image_size,
column_stride,
shape.X_dims.data(),
shape.column_slice_dims.data(),
kernel_.data(),
stride_.data(),
dilation_.data(),
pads_.data(),
X_data + group_id * input_stride,
column_buffer_data + group_id * column_stride,
&context_);
}
}
X_data += input_stride * group_;
column_buffer_data += column_stride * group_;
}
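  // Rearrange the columns so that each (output location, group) pair owns a
  // contiguous [kernel_size, N] slice.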
math::Transpose(
shape.column_dims.size(),
shape.column_dims.data(),
shape.column_axes.data(),
column_buffer->template data<T>(),
column_transposed_buffer->template mutable_data<T>(),
&context_);
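  // One GEMM per (output location, group): the [M / group, kernel_size]
  // filter slice times the [kernel_size, N] column slice gives an
  // [M / group, N] output slice.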
math::GemmStridedBatched(
CblasNoTrans,
CblasNoTrans,
shape.output_image_size * group_,
shape.M / group_,
shape.N,
shape.kernel_size,
1.0f,
filter_data,
shape.M / group_ * shape.kernel_size,
column_transposed_buffer->template data<T>(),
shape.kernel_size * shape.N,
0.0f,
Y_transposed_buffer_data,
shape.M / group_ * shape.N,
&context_);
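  // Broadcast the per-location bias across the batch dimension and add it to
  // the transposed output.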
if (bias_data != nullptr) {
math::Gemm<T, Context>(
CblasNoTrans,
CblasNoTrans,
shape.output_image_size * shape.M,
shape.N,
1,
1.0,
bias_data,
bias_multiplier_.template data<T>(),
1.0,
Y_transposed_buffer_data,
&context_);
}
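  // Transpose the result back to the NCHW output layout.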
math::Transpose(
shape.Y_transposed_dims.size(),
shape.Y_transposed_dims.data(),
shape.Y_axes.data(),
Y_transposed_buffer_data,
Y_data,
&context_);
}
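
// NHWC forward implementation: the same im2col + transpose + strided batched
// GEMM pipeline, with channel-last column slices and the filter consumed in
// transposed form.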
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* bias_data,
T* Y_data,
Tensor* column_buffer,
Tensor* column_transposed_buffer,
Tensor* Y_transposed_buffer) {
const int input_stride = shape.C * shape.input_image_size;
const int column_stride = shape.kernel_size * shape.output_image_size;
column_buffer->Resize(shape.column_dims);
column_transposed_buffer->Resize(shape.column_transposed_dims);
Y_transposed_buffer->Resize(shape.Y_transposed_dims);
T* column_buffer_data = column_buffer->template mutable_data<T>();
T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();
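  // Expand every image into columns in channel-last order.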
for (int image_id = 0; image_id < shape.N; ++image_id) {
math::Im2Col<T, Context, StorageOrder::NHWC>(
shape.C,
shape.X_dims[0],
shape.X_dims[1],
kernel_h(),
kernel_w(),
dilation_h(),
dilation_w(),
pad_t(),
pad_l(),
pad_b(),
pad_r(),
stride_h(),
stride_w(),
X_data + image_id * input_stride,
column_buffer_data + image_id * column_stride,
&context_);
}
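  // Rearrange the columns so that each output location owns a contiguous
  // [N, kernel_size] slice.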
math::Transpose(
shape.column_dims.size(),
shape.column_dims.data(),
shape.column_axes.data(),
column_buffer->template data<T>(),
column_transposed_buffer->template mutable_data<T>(),
&context_);
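  // One GEMM per output location: the [N, kernel_size] column slice times the
  // transposed [M, kernel_size] filter slice gives an [N, M] output slice.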
math::GemmStridedBatched(
CblasNoTrans,
CblasTrans,
shape.output_image_size,
shape.N,
shape.M,
shape.kernel_size,
1.0f,
column_transposed_buffer->template data<T>(),
shape.N * shape.kernel_size,
filter_data,
shape.kernel_size * shape.M,
0.0f,
Y_transposed_buffer_data,
shape.N * shape.M,
      &context_);