#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
#include <c10/util/Optional.h>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/conversions.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
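// Computes a fully connected layer: Y = X * W^T + b when TransposeWeight is
// true (the default, with W stored as an N x K matrix), or Y = X * W + b when
// it is false (W stored as K x N). X is viewed as an M x K matrix around
// `axis`, b has length N, and Y is M x N.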
template <
class Context,
class Engine = DefaultEngine,
bool TransposeWeight = true>
class FullyConnectedOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit FullyConnectedOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
axis_(this->template GetSingleArgument<int32_t>("axis", 1)),
axis_w_(this->template GetSingleArgument<int32_t>("axis_w", 1)),
float16_compute_(
this->template GetSingleArgument<bool>("float16_compute", false)) {}
~FullyConnectedOp() {}
template <
typename T_X,
typename T_W,
typename T_B,
typename T_Y,
typename MATH>
bool DoRunWithType() {
const auto& X = Input(0);
const auto& W = Input(1);
const auto& b = Input(2);
CAFFE_ENFORCE(b.dim() == 1, b.dim());
// Treat X as an M x K matrix: M (the batch size) is the product of the
// dimensions before `axis`, and K (the input feature size) is the product of
// the dimensions from `axis` onward.
const auto canonical_axis = X.canonical_axis_index(axis_);
const auto M = X.size_to_dim(canonical_axis);
const auto K = X.size_from_dim(canonical_axis);
const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
: W.size_from_dim(canonical_axis_w);
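// N is the number of output features: the leading dimensions of W when
// TransposeWeight is true, the trailing dimensions otherwise.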
auto dimErrorString = [&]() {
return c10::str(
"Dimension mismatch: ",
"X: ",
X.sizes(),
", W: ",
W.sizes(),
", b: ",
b.sizes(),
", axis: ",
axis_,
", M: ",
M,
", N: ",
N,
", K: ",
K);
};
// Error checking
CAFFE_ENFORCE(M == X.numel() / K, dimErrorString());
CAFFE_ENFORCE(K == W.numel() / N, dimErrorString());
CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
CAFFE_ENFORCE(N == b.numel(), dimErrorString());
Y_shape_cache_ = X.sizes().vec();
// This is an invariant of canonical_axis, so we can DCHECK.
DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
Y_shape_cache_.resize(canonical_axis + 1);
Y_shape_cache_[canonical_axis] = N;
auto* Y = Output(0, Y_shape_cache_, at::dtype<T_Y>());
CAFFE_ENFORCE(M * N == Y->numel(), dimErrorString());
if (X.numel() == 0) {
// skip the rest of the computation if X is empty
Y->template mutable_data<T_Y>();
return true;
}
// default to FLOAT as math.h does.
TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
if (fp16_type<MATH>()) {
math_type = TensorProto_DataType_FLOAT16;
}
// Y = X * W^T when TransposeWeight is true, otherwise Y = X * W.
math::Gemm<T_X, Context, Engine>(
CblasNoTrans,
TransposeWeight ? CblasTrans : CblasNoTrans,
M,
N,
K,
1,
X.template data<T_X>(),
W.template data<T_W>(),
0,
Y->template mutable_data<T_Y>(),
&context_,
math_type);
// Add the bias term. A length-M vector of ones (bias_multiplier_) is lazily
// created and cached so the bias can be broadcast over the batch with a
// single GEMM.
if (!bias_multiplier_.has_value()) {
bias_multiplier_ =
caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
math::Set<T_B, Context>(
M,
convert::To<float, T_B>(1),
bias_multiplier_->template mutable_data<T_B>(),
&context_);
} else if (bias_multiplier_->numel() != M) {
bias_multiplier_->Resize(M);
math::Set<T_B, Context>(
M,
convert::To<float, T_B>(1),
bias_multiplier_->template mutable_data<T_B>(),
&context_);
}
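// Y += bias_multiplier_ (M x 1) * b (1 x N): a rank-1 update that adds b to
// every row of Y (beta == 1, so the GEMM accumulates into Y).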
math::Gemm<T_B, Context, Engine>(
CblasNoTrans,
CblasNoTrans,
M,
N,
1,
1,
bias_multiplier_->template data<T_B>(),
b.template data<T_B>(),
1,
Y->template mutable_data<T_Y>(),
&context_,
math_type);
return true;
}
bool RunOnDevice() override {
return DoRunWithType<
float, // X
float, // W
float, // B
float, // Y
float>(); // Math
}
protected:
size_t axis_{1};
size_t axis_w_{1};
// A local vector that caches the output shape so we don't have to rebuild a
// vector object on every call to Run().
vector<int64_t> Y_shape_cache_;
c10::optional<Tensor> bias_multiplier_;
bool float16_compute_;
};
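// Gradient of FullyConnectedOp. Given X, W and dY, it computes dW and db and,
// when a third output is requested, dX as well.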
template <
class Context,
class Engine = DefaultEngine,
bool TransposeWeight = true>
class FullyConnectedGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit FullyConnectedGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
axis_(this->template GetSingleArgument<int32_t>("axis", 1)),
axis_w_(this->template GetSingleArgument<int32_t>("axis_w", 1)),
float16_compute_(
this->template GetSingleArgument<bool>("float16_compute", false)) {}
~FullyConnectedGradientOp() {}
template <
typename T_X,
typename T_W,
typename T_DY,
typename T_B,
typename T_DX,
typename T_DW,
typename T_DB,
typename MATH>
bool DoRunWithType() {
const auto& X = Input(0);
const auto& W = Input(1);
const auto& dY = Input(2);
// Treat X as an M x K matrix, as in the forward pass: M is the batch size and
// K is the input feature size.
const auto canonical_axis = X.canonical_axis_index(axis_);
const int M = X.size_to_dim(canonical_axis);
const int K = X.size_from_dim(canonical_axis);
const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
: W.size_from_dim(canonical_axis_w);
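// As in the forward pass, N is the number of output features.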
auto dimErrorString = [&]() {
return c10::str(
"Dimension mismatch: ",
"X: ",
X.sizes(),
", W: ",
W.sizes(),
", dY: ",
dY.sizes(),
", axis: ",
axis_,
", M: ",
M,
", N: ",
N,
", K: ",
K);
};
CAFFE_ENFORCE(M * K == X.numel(), dimErrorString());
CAFFE_ENFORCE(K * N == W.numel(), dimErrorString());
auto* dW = Output(0, W.sizes(), at::dtype<T_DW>());
auto* db = Output(1, {N}, at::dtype<T_DB>());
if (X.numel() == 0) {
// X is empty: fill dW and db with zeros and, if requested, allocate an empty dX.
math::Set<T_DB, Context>(
db->numel(),
convert::To<float, T_DB>(0),
db->template mutable_data<T_DB>(),
&context_);
math::Set<T_DW, Context>(
dW->numel(),
convert::To<float, T_DW>(0),
dW->template mutable_data<T_DW>(),
&context_);
if (OutputSize() == 3) {
Output(2, X.sizes(), at::dtype<T_DX>());
}
return true;
}
// default to FLOAT as math.h does.
TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
if (fp16_type<MATH>()) {
math_type = TensorProto_DataType_FLOAT16;
}
// Compute dW: dW = dY^T * X when TransposeWeight is true (dW is N x K),
// otherwise dW = X^T * dY (dW is K x N).
math::Gemm<T_DY, Context, Engine>(
CblasTrans,
CblasNoTrans,
TransposeWeight ? N : K,
TransposeWeight ? K : N,
M,
1,
TransposeWeight ? dY.template data<T_DY>() : X.template data<T_X>(),
TransposeWeight ? X.template data<T_X>() : dY.template data<T_DY>(),
0,
dW->template mutable_data<T_DW>(),
&context_,
math_type);
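// bias_multiplier_ is a cached length-M vector of ones, used below to reduce
// dY over the batch dimension when computing db.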
if (!bias_multiplier_.has_value()) {
bias_multiplier_ =
caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
math::Set<T_B, Context>(
M,
convert::To<float, T_B>(1),
bias_multiplier_->template mutable_data<T_B>(),
&context_);
} else if (bias_multiplier_->numel() != M) {
bias_multiplier_->Resize(M);
math::Set<T_B, Context>(
M,
convert::To<float, T_B>(1),
bias_multiplier_->template mutable_data<T_B>(),
&context_);
}
// Compute db = dY^T * bias_multiplier_, i.e. the column sums of dY (one entry
// per output feature).
math::Gemv<T_DY, Context>(
CblasTrans,
M,
N,
1,
dY.template data<T_DY>(),
bias_multiplier_->template data<T_B>(),
0,
db->template mutable_data<T_DB>(),
&context_);
// Compute dX = dY * W when TransposeWeight is true, otherwise dY * W^T; dX is
// an M x K matrix.
if (OutputSize() == 3) {
auto* dX = Output(2, X.sizes(), at::dtype<T_DX>());
math::Gemm<T_DX, Context, Engine>(
CblasNoTrans,
TransposeWeight ? CblasNoTrans : CblasTrans,
M,
K,
N,
1,
dY.template data<T_DY>(),
W.template data<T_W>(),
0,
dX->template mutable_data<T_DX>(),
&context_,
math_type);
}
return true;
}
bool RunOnDevice() override {
return DoRunWithType<
float, // X
float, // W
float, // dY
float, // B
float, // dX
float, // dW
float, // dB
float>(); // Math
}
protected:
size_t axis_{1};
size_t axis_w_{1};
c10::optional<Tensor> bias_multiplier_;
bool float16_compute_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_