Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / elementwise_ops.h

#ifndef CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_
#define CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_

#include <iterator>
#include <string>
#include <tuple>
#include <vector>

#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/operators/elementwise_ops_utils.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// Named element-type lists built from TensorTypes<...>.
// NOTE(review): presumably intended as the InputTypes template argument of the
// elementwise operators in this header (which hand it to DispatchHelper) —
// confirm against the operator registrations.
using NumericTypes = TensorTypes<int32_t, int64_t, float, double>;
// Integer element types only.
using IntTypes = TensorTypes<int32_t, int64_t>;
// Boolean element type only.
using BoolTypes = TensorTypes<bool>;
using IntBoolTypes = TensorTypes<int32_t, int64_t, bool>; // discrete types

// Type map that leaves the element type untouched:
// SameTypeAsInput::type<T> is T itself. It is the default OutputTypeMap of
// the elementwise operators declared in this header.
struct SameTypeAsInput {
  template <typename TIn>
  using type = TIn;
};

// Type map that ignores the element type it is asked about:
// FixedType<TFixed>::type<T> is TFixed for every T.
template <typename TFixed>
struct FixedType {
  template <typename /* TIn */>
  using type = TFixed;
};

// Elementwise unary operator: reads input 0, creates output 0 with the same
// sizes, and delegates the per-element work to `Functor`.
//
// Template parameters:
//   InputTypes    - type list handed to DispatchHelper to select T.
//   Context       - device context type of the operator.
//   Functor       - constructed from `*this`; invoked as
//                   functor(numel, const T* in, TOut* out, Context*).
//   OutputTypeMap - maps the input element type T to the output element type
//                   via OutputTypeMap::type<T>.
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
class UnaryElementwiseWithArgsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Forwards every argument to Operator<Context>, then builds the functor
  // from the operator itself.
  template <class... Args>
  explicit UnaryElementwiseWithArgsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...), functor_(*this) {}

  // Dispatches on input 0 through DispatchHelper<InputTypes>.
  bool RunOnDevice() override {
    const auto& dispatch_input = Input(0);
    return DispatchHelper<InputTypes>::call(this, dispatch_input);
  }

  // Runs the functor for input element type T and returns its result.
  template <typename T>
  bool DoRunWithType() {
    using TOut = typename OutputTypeMap::template type<T>;

    const auto& input = Input(0);
    // The output is created (with the input's sizes) before any data pointer
    // is requested, matching the original statement order.
    auto* output = Output(0, input.sizes(), at::dtype<TOut>());

    const auto* input_data = input.template data<T>();
    auto* output_data = output->template mutable_data<TOut>();
    return functor_(input.numel(), input_data, output_data, &context_);
  }

 private:
  Functor functor_;
};

// Adapter that lets a functor without constructor arguments serve as the
// functor of an UnaryElementwiseWithArgsOp. The operator reference passed to
// the constructor is accepted and ignored; each call is delegated unchanged to
// the wrapped functor.
template <class Functor>
struct UnaryFunctorWithDefaultCtor {
  // The wrapped functor, value-initialized.
  Functor functor{};

  explicit UnaryFunctorWithDefaultCtor(OperatorBase& /* unused */) {}

  // Delegates to the wrapped functor and returns its result.
  template <typename TIn, typename TOut, class Context>
  bool operator()(const int size, const TIn* X, TOut* Y, Context* context)
      const {
    return functor(size, X, Y, context);
  }
};

// UnaryElementwiseOp is a convenience alias for UnaryElementwiseWithArgsOp.
// It wraps `Functor` in UnaryFunctorWithDefaultCtor, so `Functor` only needs
// to be constructible without arguments and does not have to take operator
// arguments into consideration during operator creation.
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
using UnaryElementwiseOp = UnaryElementwiseWithArgsOp<
    InputTypes,
    Context,
    UnaryFunctorWithDefaultCtor<Functor>,
    OutputTypeMap>;

// Elementwise binary operator: computes output 0 from inputs 0 (A) and 1 (B)
// by calling functor.Forward(A_dims, B_dims, A_data, B_data, C_data, context).
//
// Operator arguments read in the constructor:
//   "broadcast" (bool, default false) - enables legacy broadcasting.
//   "axis"      (int, default -1)     - legacy broadcast axis.
//   "axis_str"  (string, default "")  - single character looked up in "order"
//                                       to derive the axis.
//   "order"     (string, default "NCHW")
//
// Template parameters:
//   InputTypes    - type list handed to DispatchHelper to select T.
//   Context       - device context type of the operator.
//   Functor       - constructed from `*this`; must provide Forward().
//   OutputTypeMap - maps the input element type T to the output element type
//                   via OutputTypeMap::type<T>.
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
class BinaryElementwiseWithArgsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Reads the operator arguments and, in legacy-broadcast mode, validates them
  // and resolves `axis_str` into `axis_`. Enforce failures are raised through
  // the CAFFE_ENFORCE* macros.
  template <class... Args>
  explicit BinaryElementwiseWithArgsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, string("")),
        OP_SINGLE_ARG(string, "order", order_, "NCHW"),
        functor_(*this) {
    if (legacy_broadcast_) {
      if (axis_ != -1) {
        // Get axis from an explicit axis argument.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(),
            0U,
            "Args axis and axis_str cannot be used simultaneously.");
      } else if (!axis_str_.empty()) {
        // Get the axis index semantically: axis_str must be a single
        // character that occurs in the order string.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(), 1U, "Unsupported axis string ", axis_str_);
        const size_t semantic_axis = order_.find(axis_str_);
        CAFFE_ENFORCE_NE(
            semantic_axis,
            string::npos,
            "Unrecognizable axis string ",
            axis_str_,
            " from order string ",
            order_);
        axis_ = static_cast<int>(semantic_axis);
      }
      // Neither axis nor axis_str given: nothing to validate.
    }
    // NOTE(review): when "broadcast" is false, axis / axis_str / order are
    // not validated and are not used by DoRunWithType(). Rejecting them here
    // would change which operator definitions are accepted — confirm with the
    // owners before enforcing.
  }

  // Dispatches on input 0 through DispatchHelper<InputTypes>.
  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(0));
  }

  // Computes the dims passed to the functor, creates the output and runs
  // Functor::Forward for input element type T. Returns the functor's result.
  template <typename T>
  bool DoRunWithType() {
    using TOut = typename OutputTypeMap::template type<T>;

    const auto& A = Input(0);
    const auto& B = Input(1);

    const T* A_data = A.template data<T>();
    const T* B_data = B.template data<T>();
    std::vector<int> A_dims;
    std::vector<int> B_dims;
    std::vector<int64_t> C_dims;

    if (legacy_broadcast_) {
      CAFFE_ENFORCE(
          !IsInputOutputAlias(1, 0),
          "In-place is allowed only with the first tensor when "
          "legacy-broadcasting");
      // In legacy mode the output always has A's sizes.
      C_dims = A.sizes().vec();
      if (B.numel() == 1) {
        // Single-element B: A is passed as a flat 1-D array.
        A_dims = {static_cast<int>(A.numel())};
        B_dims = {1};
      } else {
        // A is passed as {pre, n, post} and B as {n, 1}.
        size_t pre, n, post;
        std::tie(pre, n, post) =
            elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
        A_dims = {
            static_cast<int>(pre), static_cast<int>(n), static_cast<int>(post)};
        B_dims = {static_cast<int>(n), 1};
      }
    } else {
      std::copy(
          A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
      std::copy(
          B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
      // TODO: change the types to vector<int64_t>
      const auto C_dims_int =
          elementwise_ops_utils::ComputeBinaryBroadcastForwardDims(
              A_dims, B_dims);
      std::copy(
          C_dims_int.cbegin(), C_dims_int.cend(), std::back_inserter(C_dims));
      // An input that aliases the output must already have the output's dims.
      if (IsInputOutputAlias(0, 0)) {
        CAFFE_ENFORCE_EQ(C_dims_int, A_dims);
      } else if (IsInputOutputAlias(1, 0)) {
        CAFFE_ENFORCE_EQ(C_dims_int, B_dims);
      }
    }

    auto* C = Output(0, C_dims, at::dtype<TOut>());
    auto* C_data = C->template mutable_data<TOut>();
    return functor_.Forward(A_dims, B_dims, A_data, B_data, C_data, &context_);
  }

 private:
  const bool legacy_broadcast_;
  // Legacy broadcast axis; may be overwritten in the constructor with the
  // index of axis_str_ inside order_.
  int axis_;
  const std::string axis_str_;
  const std::string order_;

  Functor functor_;
};

// Gradient of an elementwise binary operator. Inputs are dC (0), A (1), B (2)
// and optionally C (3); outputs are dA (0, sized like A) and dB (1, sized like
// B). The work is delegated to
// functor.Backward(A_dims, B_dims, dC, A, B, C, dA, dB, context).
//
// Operator arguments read in the constructor:
//   "broadcast" (bool, default false) - enables legacy broadcasting.
//   "axis"      (int, default -1)     - legacy broadcast axis.
//   "axis_str"  (string, default "")  - single character looked up in "order"
//                                       to derive the axis.
//   "order"     (string, default "NCHW")
//
// Template parameters:
//   InputTypes      - type list handed to DispatchHelper to select T.
//   Context         - device context type of the operator.
//   Functor         - constructed from `*this`; must provide Backward().
//   OutputTypeMap   - maps T to the element type of the optional C input.
//   GradientTypeMap - maps T to the element type of dC, dA and dB.
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput,
    class GradientTypeMap = SameTypeAsInput>
class BinaryElementwiseWithArgsGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Reads the operator arguments and, in legacy-broadcast mode, validates them
  // and resolves `axis_str` into `axis_`. Enforce failures are raised through
  // the CAFFE_ENFORCE* macros.
  template <class... Args>
  explicit BinaryElementwiseWithArgsGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
        OP_SINGLE_ARG(string, "order", order_, "NCHW"),
        functor_(*this) {
    if (legacy_broadcast_) {
      if (axis_ != -1) {
        // Get axis from an explicit axis argument.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(),
            0U,
            "Args axis and axis_str cannot be used simultaneously.");
      } else if (!axis_str_.empty()) {
        // Get the axis index semantically: axis_str must be a single
        // character that occurs in the order string.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(), 1U, "Unsupported axis string ", axis_str_);
        const size_t semantic_axis = order_.find(axis_str_);
        CAFFE_ENFORCE_NE(
            semantic_axis,
            string::npos,
            "Unrecognizable axis string ",
            axis_str_,
            " from order string ",
            order_);
        axis_ = static_cast<int>(semantic_axis);
      }
      // Neither axis nor axis_str given: nothing to validate.
    }
    // NOTE(review): when "broadcast" is false, axis / axis_str / order are
    // not validated and are not used by DoRunWithType(). Rejecting them here
    // would change which operator definitions are accepted — confirm with the
    // owners before enforcing.
  }

  // Dispatches on input 1 (A) through DispatchHelper<InputTypes>.
  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(1));
  }

  // Computes the dims passed to the functor, creates dA and dB and runs
  // Functor::Backward for input element type T. Returns the functor's result.
  template <typename T>
  bool DoRunWithType() {
    using TOut = typename OutputTypeMap::template type<T>;
    using TGrad = typename GradientTypeMap::template type<T>;

    const auto& dC = Input(0);
    const auto& A = Input(1);
    const auto& B = Input(2);

    std::vector<int> A_dims;
    std::vector<int> B_dims;
    if (legacy_broadcast_) {
      if (B.numel() == 1) {
        // Single-element B: A is passed as a flat 1-D array.
        A_dims = {static_cast<int>(A.numel())};
        B_dims = {1};
      } else {
        // A is passed as {pre, n, post} and B as {n, 1}.
        size_t pre, n, post;
        std::tie(pre, n, post) =
            elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
        A_dims = {
            static_cast<int>(pre), static_cast<int>(n), static_cast<int>(post)};
        B_dims = {static_cast<int>(n), 1};
      }
    } else {
      std::copy(
          A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
      std::copy(
          B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
    }

    // The forward output C is optional; the functor receives nullptr when
    // only three inputs are provided.
    const TOut* C_data = nullptr;
    if (InputSize() == 4) {
      const auto& C = Input(3);
      C_data = C.template data<TOut>();
    }
    // Input pointers are fetched before the outputs are created, matching the
    // original statement order.
    const auto* dC_data = dC.template data<TGrad>();
    const T* A_data = A.template data<T>();
    const T* B_data = B.template data<T>();

    auto* dA = Output(0, A.sizes(), at::dtype<TGrad>());
    auto* dB = Output(1, B.sizes(), at::dtype<TGrad>());
    auto* dA_data = dA->template mutable_data<TGrad>();
    auto* dB_data = dB->template mutable_data<TGrad>();
    return functor_.Backward(
        A_dims,
        B_dims,
        dC_data,
        A_data,
        B_data,
        C_data,
        dA_data,
        dB_data,
        &context_);
  }

 private:
  const bool legacy_broadcast_;
  // Legacy broadcast axis; may be overwritten in the constructor with the
  // index of axis_str_ inside order_.
  int axis_;
  const std::string axis_str_;
  const std::string order_;

  Functor functor_;
};

template <class Functor>
struct BinaryFunctorWithDefaultCtor {
  explicit BinaryFunctorWithDefaultCtor(OperatorBase& /* op */) {}

  template <typename TIn, typename TOut, class Context>
  bool Forward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TIn* A_data,
      const TIn* B_data,
      TOut* C_data,
      Context* context) const {
    return functor.Forward(A_dims, B_dims, A_data, B_data, C_data, context);
  }

  template <typename TGrad, typename TIn, typename TOut, class Context>
  bool Backward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TGrad* dC_data,
      const TIn* A_data,
Loading ...