#ifndef CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_
#define CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_

#include <iterator>
#include <string>
#include <tuple>
#include <vector>
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/operators/elementwise_ops_utils.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

using NumericTypes = TensorTypes<int32_t, int64_t, float, double>;
using IntTypes = TensorTypes<int32_t, int64_t>;
using BoolTypes = TensorTypes<bool>;
using IntBoolTypes = TensorTypes<int32_t, int64_t, bool>; // discrete types
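
// Type maps for the OutputTypeMap/GradientTypeMap template parameters below:
// SameTypeAsInput maps the input type T to itself, while FixedType<R> maps
// every input type to the fixed type R (comparison ops, for example, use
// FixedType<bool> to always produce boolean outputs).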
struct SameTypeAsInput {
template <typename T>
using type = T;
};

template <typename R>
struct FixedType {
template <typename T>
using type = R;
};
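
// UnaryElementwiseWithArgsOp dispatches on the dtype of Input(0) and applies
// Functor element-wise to the flattened input. The functor is constructed
// from the operator itself, so it can read operator arguments at creation
// time; the output dtype is derived from the input dtype via OutputTypeMap.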
template <
typename InputTypes,
class Context,
class Functor,
class OutputTypeMap = SameTypeAsInput>
class UnaryElementwiseWithArgsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit UnaryElementwiseWithArgsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...), functor_(*this) {}
bool RunOnDevice() override {
return DispatchHelper<InputTypes>::call(this, Input(0));
}
template <typename T>
bool DoRunWithType() {
const auto& X = Input(0);
auto* Y = Output(
0, X.sizes(), at::dtype<typename OutputTypeMap::template type<T>>());
return functor_(
X.numel(),
X.template data<T>(),
Y->template mutable_data<typename OutputTypeMap::template type<T>>(),
&context_);
}
private:
Functor functor_;
};
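
// Usage sketch for UnaryElementwiseWithArgsOp (ClipFunctor and its "min"/
// "max" arguments are hypothetical, for illustration only): the functor
// caches its arguments in the constructor and applies them in operator().
//
//   struct ClipFunctor {
//     explicit ClipFunctor(OperatorBase& op)
//         : min_(op.GetSingleArgument<float>("min", 0.0f)),
//           max_(op.GetSingleArgument<float>("max", 1.0f)) {}
//     bool operator()(
//         const int N, const float* X, float* Y, CPUContext* /* context */)
//         const {
//       for (int i = 0; i < N; ++i) {
//         Y[i] = std::min(std::max(X[i], min_), max_);
//       }
//       return true;
//     }
//     const float min_;
//     const float max_;
//   };
//
//   REGISTER_CPU_OPERATOR(
//       MyClip,
//       UnaryElementwiseWithArgsOp<TensorTypes<float>, CPUContext, ClipFunctor>);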

// UnaryFunctorWithDefaultCtor adapts a default-constructible functor to the
// functor interface expected by UnaryElementwiseWithArgsOp: it accepts (and
// ignores) the OperatorBase& constructor argument and forwards operator()
// calls to the wrapped functor.
template <class Functor>
struct UnaryFunctorWithDefaultCtor {
explicit UnaryFunctorWithDefaultCtor(OperatorBase& /* op */) {}
template <typename TIn, typename TOut, class Context>
bool operator()(const int size, const TIn* X, TOut* Y, Context* context)
const {
return functor(size, X, Y, context);
}
Functor functor{};
};

// UnaryElementwiseOp is a wrapper around UnaryElementwiseWithArgsOp for
// functors with a default constructor, i.e. functors that do not need to
// inspect any operator arguments at creation time.
template <
typename InputTypes,
class Context,
class Functor,
class OutputTypeMap = SameTypeAsInput>
using UnaryElementwiseOp = UnaryElementwiseWithArgsOp<
InputTypes,
Context,
UnaryFunctorWithDefaultCtor<Functor>,
OutputTypeMap>;
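
// Usage sketch for UnaryElementwiseOp (a minimal Sqr-style functor; the
// SqrFunctor shown here is illustrative, patterned after the functors in
// caffe2/operators/):
//
//   struct SqrFunctor {
//     template <typename T, class Context>
//     bool operator()(const int N, const T* X, T* Y, Context* context) const {
//       math::Sqr<T, Context>(N, X, Y, context);
//       return true;
//     }
//   };
//
//   REGISTER_CPU_OPERATOR(
//       Sqr, UnaryElementwiseOp<NumericTypes, CPUContext, SqrFunctor>);

// BinaryElementwiseWithArgsOp dispatches on the dtype of Input(0), computes
// the broadcast shape of its two inputs (either legacy axis-based
// broadcasting or full NumPy-style broadcasting), and forwards the shapes
// and data pointers to Functor::Forward.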
template <
typename InputTypes,
class Context,
class Functor,
class OutputTypeMap = SameTypeAsInput>
class BinaryElementwiseWithArgsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit BinaryElementwiseWithArgsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false),
OP_SINGLE_ARG(int, "axis", axis_, -1),
OP_SINGLE_ARG(string, "axis_str", axis_str_, string("")),
OP_SINGLE_ARG(string, "order", order_, "NCHW"),
functor_(*this) {
    if (legacy_broadcast_) {
      if (axis_ != -1) {
        // Get axis from an explicit axis argument.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(),
            0U,
            "Args axis and axis_str cannot be used simultaneously.");
      } else if (axis_str_.size()) {
        // Get the axis index semantically.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(), 1U, "Unsupported axis string ", axis_str_);
        const size_t semantic_axis_ = order_.find(axis_str_);
        CAFFE_ENFORCE_NE(
            semantic_axis_,
            string::npos,
            "Unrecognizable axis string ",
            axis_str_,
            " from order string ",
            order_);
        axis_ = semantic_axis_;
      }
    } else {
      CAFFE_ENFORCE(
          axis_ == -1 && axis_str_.empty(),
          "Do not specify axis or axis_str if broadcast is not enabled.");
    }
}
bool RunOnDevice() override {
return DispatchHelper<InputTypes>::call(this, Input(0));
}
template <typename T>
bool DoRunWithType() {
const auto& A = Input(0);
const auto& B = Input(1);
const T* A_data = A.template data<T>();
const T* B_data = B.template data<T>();
std::vector<int> A_dims;
std::vector<int> B_dims;
std::vector<int64_t> C_dims;
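    // Legacy broadcasting matches B against a contiguous run of A's
    // dimensions starting at axis_ (the trailing dimensions by default), so
    // the output always has A's shape.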
if (legacy_broadcast_) {
CAFFE_ENFORCE(
!IsInputOutputAlias(1, 0),
"In-place is allowed only with the first tensor when "
"legacy-broadcasting");
C_dims = A.sizes().vec();
if (B.numel() == 1) {
A_dims = {static_cast<int>(A.numel())};
B_dims = {1};
} else {
size_t pre, n, post;
std::tie(pre, n, post) =
elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
A_dims = {
static_cast<int>(pre), static_cast<int>(n), static_cast<int>(post)};
B_dims = {static_cast<int>(n), 1};
}
} else {
std::copy(
A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
std::copy(
B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
// TODO: change the types to vector<int64_t>
auto C_dims_int =
elementwise_ops_utils::ComputeBinaryBroadcastForwardDims(
A_dims, B_dims);
std::copy(
C_dims_int.cbegin(), C_dims_int.cend(), std::back_inserter(C_dims));
if (IsInputOutputAlias(0, 0)) {
CAFFE_ENFORCE_EQ(C_dims_int, A_dims);
} else if (IsInputOutputAlias(1, 0)) {
CAFFE_ENFORCE_EQ(C_dims_int, B_dims);
}
}
auto* C = Output(
0, C_dims, at::dtype<typename OutputTypeMap::template type<T>>());
auto* C_data =
C->template mutable_data<typename OutputTypeMap::template type<T>>();
return functor_.Forward(A_dims, B_dims, A_data, B_data, C_data, &context_);
}
private:
const bool legacy_broadcast_;
int axis_;
const std::string axis_str_;
const std::string order_;
Functor functor_;
};
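
// Usage sketch for BinaryElementwiseWithArgsOp (an Add-style functor; this
// AddFunctor is illustrative, patterned after caffe2's own element-wise
// ops). Forward receives the already-computed broadcast shapes and delegates
// to the broadcasting math kernels:
//
//   template <class Context>
//   struct AddFunctor {
//     template <typename TIn, typename TOut>
//     bool Forward(
//         const std::vector<int>& A_dims,
//         const std::vector<int>& B_dims,
//         const TIn* A,
//         const TIn* B,
//         TOut* C,
//         Context* context) const {
//       math::Add(
//           A_dims.size(),
//           A_dims.data(),
//           B_dims.size(),
//           B_dims.data(),
//           A,
//           B,
//           C,
//           context);
//       return true;
//     }
//   };

// BinaryElementwiseWithArgsGradientOp expects inputs (dC, A, B) and
// optionally the forward output C as a fourth input; it dispatches on the
// dtype of A and delegates to Functor::Backward to fill dA and dB.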
template <
typename InputTypes,
class Context,
class Functor,
class OutputTypeMap = SameTypeAsInput,
class GradientTypeMap = SameTypeAsInput>
class BinaryElementwiseWithArgsGradientOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit BinaryElementwiseWithArgsGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false),
OP_SINGLE_ARG(int, "axis", axis_, -1),
OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
OP_SINGLE_ARG(string, "order", order_, "NCHW"),
functor_(*this) {
    if (legacy_broadcast_) {
      if (axis_ != -1) {
        // Get axis from an explicit axis argument.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(),
            0U,
            "Args axis and axis_str cannot be used simultaneously.");
      } else if (axis_str_.size()) {
        // Get the axis index semantically.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(), 1U, "Unsupported axis string ", axis_str_);
        const size_t semantic_axis_ = order_.find(axis_str_);
        CAFFE_ENFORCE_NE(
            semantic_axis_,
            string::npos,
            "Unrecognizable axis string ",
            axis_str_,
            " from order string ",
            order_);
        axis_ = semantic_axis_;
      }
    } else {
      CAFFE_ENFORCE(
          axis_ == -1 && axis_str_.empty(),
          "Do not specify axis or axis_str if broadcast is not enabled.");
    }
}
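  // Dispatch on the dtype of A (Input(1)); dC (Input(0)) may have a
  // different dtype when GradientTypeMap maps T to another type.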
bool RunOnDevice() override {
return DispatchHelper<InputTypes>::call(this, Input(1));
}
template <typename T>
bool DoRunWithType() {
const auto& dC = Input(0);
const auto& A = Input(1);
const auto& B = Input(2);
    std::vector<int> A_dims;
    std::vector<int> B_dims;
if (legacy_broadcast_) {
if (B.numel() == 1) {
A_dims = {static_cast<int>(A.numel())};
B_dims = {1};
} else {
size_t pre, n, post;
std::tie(pre, n, post) =
elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
A_dims = {
static_cast<int>(pre), static_cast<int>(n), static_cast<int>(post)};
B_dims = {static_cast<int>(n), 1};
}
} else {
std::copy(
A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
std::copy(
B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
}
const typename OutputTypeMap::template type<T>* C_data = nullptr;
if (InputSize() == 4) {
const auto& C = Input(3);
C_data = C.template data<typename OutputTypeMap::template type<T>>();
}
const auto* dC_data =
dC.template data<typename GradientTypeMap::template type<T>>();
const T* A_data = A.template data<T>();
const T* B_data = B.template data<T>();
auto* dA = Output(
0, A.sizes(), at::dtype<typename GradientTypeMap::template type<T>>());
auto* dB = Output(
1, B.sizes(), at::dtype<typename GradientTypeMap::template type<T>>());
auto* dA_data =
dA->template mutable_data<typename GradientTypeMap::template type<T>>();
auto* dB_data =
dB->template mutable_data<typename GradientTypeMap::template type<T>>();
return functor_.Backward(
A_dims,
B_dims,
dC_data,
A_data,
B_data,
C_data,
dA_data,
dB_data,
&context_);
}
private:
const bool legacy_broadcast_;
int axis_;
const std::string axis_str_;
const std::string order_;
Functor functor_;
};
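
// BinaryFunctorWithDefaultCtor is the binary counterpart of
// UnaryFunctorWithDefaultCtor: it adapts a default-constructible functor to
// the constructor interface expected by BinaryElementwiseWithArgsOp and
// BinaryElementwiseWithArgsGradientOp, forwarding Forward and Backward calls.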
template <class Functor>
struct BinaryFunctorWithDefaultCtor {
explicit BinaryFunctorWithDefaultCtor(OperatorBase& /* op */) {}
template <typename TIn, typename TOut, class Context>
bool Forward(
const std::vector<int>& A_dims,
const std::vector<int>& B_dims,
const TIn* A_data,
const TIn* B_data,
TOut* C_data,
Context* context) const {
return functor.Forward(A_dims, B_dims, A_data, B_data, C_data, context);
}
template <typename TGrad, typename TIn, typename TOut, class Context>
bool Backward(
const std::vector<int>& A_dims,
const std::vector<int>& B_dims,
const TGrad* dC_data,
const TIn* A_data,
      const TIn* B_data,
      const TOut* C_data,
      TGrad* dA_data,
      TGrad* dB_data,
      Context* context) const {
    return functor.Backward(
        A_dims,
        B_dims,
        dC_data,
        A_data,
        B_data,
        C_data,
        dA_data,
        dB_data,
        context);
  }

  Functor functor{};
};