// Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages
//
// neilisaac / torch   python
//
// Repository URL to install this package:
//
// Version: 1.8.0
//
// / include / caffe2 / operators / fully_connected_op.h

#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_

#include <c10/util/Optional.h>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/conversions.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// This is Caffe's InnerProductOp, with a name that fits its purpose better.
//
// Computes Y = X * op(W) + b.
//   - X (input 0) is viewed as an (M, K) matrix, split at `axis`.
//   - W (input 1) is split at `axis_w`. With TransposeWeight (the default) it
//     is treated as (N, K) and transposed inside the product; otherwise it is
//     treated as (K, N).
//   - b (input 2) must be 1-D with N elements.
//   - Y (output 0) keeps X's dimensions before `axis`, followed by N.
//
// Operator arguments (read once in the constructor):
//   axis            - split point of X, default 1.
//   axis_w          - split point of W, default 1.
//   float16_compute - stored in float16_compute_; the code in this class does
//                     not consult it.
template <
    class Context,
    class Engine = DefaultEngine,
    bool TransposeWeight = true>
class FullyConnectedOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit FullyConnectedOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        axis_(this->template GetSingleArgument<int32_t>("axis", 1)),
        axis_w_(this->template GetSingleArgument<int32_t>("axis_w", 1)),
        float16_compute_(
            this->template GetSingleArgument<bool>("float16_compute", false)) {}
  ~FullyConnectedOp() {}

  // Runs the layer for one combination of element types.
  //   T_X / T_W / T_B / T_Y - element types of X, W, b and Y.
  //   MATH                  - selects the math_type forwarded to math::Gemm
  //                           (FLOAT16 when fp16_type<MATH>() is true,
  //                           FLOAT otherwise).
  // Returns true on success; shape violations are reported through
  // CAFFE_ENFORCE.
  template <
      typename T_X,
      typename T_W,
      typename T_B,
      typename T_Y,
      typename MATH>
  bool DoRunWithType() {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& b = Input(2);

    CAFFE_ENFORCE(b.dim() == 1, b.dim());
    // batch size
    const auto canonical_axis = X.canonical_axis_index(axis_);
    const auto M = X.size_to_dim(canonical_axis);
    const auto K = X.size_from_dim(canonical_axis);
    const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
    const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                  : W.size_from_dim(canonical_axis_w);

    auto dimErrorString = [&]() {
      return c10::str(
          "Dimension mismatch: ",
          "X: ",
          X.sizes(),
          ", W: ",
          W.sizes(),
          ", b: ",
          b.sizes(),
          ", axis: ",
          axis_,
          ", M: ",
          M,
          ", N: ",
          N,
          ", K: ",
          K);
    };

    // Error checking. The products are compared instead of dividing numel()
    // by K or N: a zero-sized K or N would otherwise be an integer division
    // by zero.
    CAFFE_ENFORCE(M * K == X.numel(), dimErrorString());
    CAFFE_ENFORCE(K * N == W.numel(), dimErrorString());
    CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
    CAFFE_ENFORCE(N == b.numel(), dimErrorString());

    Y_shape_cache_ = X.sizes().vec();
    // This is an invariant of canonical_axis, so we can DCHECK.
    DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
    Y_shape_cache_.resize(canonical_axis + 1);
    Y_shape_cache_[canonical_axis] = N;
    auto* Y = Output(0, Y_shape_cache_, at::dtype<T_Y>());
    CAFFE_ENFORCE(M * N == Y->numel(), dimErrorString());

    if (Y->numel() == 0) {
      // Nothing to compute when the output has no elements (M == 0 or
      // N == 0); just materialize the typed, empty output.
      Y->template mutable_data<T_Y>();
      return true;
    }

    // default to FLOAT as math.h does.
    TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
    if (fp16_type<MATH>()) {
      math_type = TensorProto_DataType_FLOAT16;
    }

    // With K == 0 the product X * op(W) is an empty sum, i.e. all zeros. The
    // GEMM is skipped in that case (no zero-sized dimension is handed to the
    // backend) and the bias GEMM below overwrites Y instead of accumulating.
    const bool has_product = K > 0;
    if (has_product) {
      // W * x
      math::Gemm<T_X, Context, Engine>(
          CblasNoTrans,
          TransposeWeight ? CblasTrans : CblasNoTrans,
          M,
          N,
          K,
          1,
          X.template data<T_X>(),
          W.template data<T_W>(),
          0,
          Y->template mutable_data<T_Y>(),
          &context_,
          math_type);
    }

    // Add bias term: Y = ones(M, 1) * b(1, N) + beta * Y.
    this->template EnsureBiasMultiplier<T_B>(M);
    const float bias_beta = has_product ? 1.0f : 0.0f;
    math::Gemm<T_B, Context, Engine>(
        CblasNoTrans,
        CblasNoTrans,
        M,
        N,
        1,
        1,
        bias_multiplier_->template data<T_B>(),
        b.template data<T_B>(),
        bias_beta,
        Y->template mutable_data<T_Y>(),
        &context_,
        math_type);

    return true;
  }

  // Default entry point: every tensor and the math type are float.
  bool RunOnDevice() override {
    return DoRunWithType<
        float, // X
        float, // W
        float, // B
        float, // Y
        float>(); // Math
  }

 protected:
  size_t axis_{1};
  size_t axis_w_{1};
  // A local vector to cache the output shape so we don't need to recreate
  // a vector object every time we run Run().
  vector<int64_t> Y_shape_cache_;
  // Vector of ones with one entry per row of X; created on first use.
  c10::optional<Tensor> bias_multiplier_;

  bool float16_compute_;

 private:
  // Makes bias_multiplier_ hold `rows` elements of type T_B, all set to one.
  // The fill is only performed when the tensor is first created or when its
  // element count differs from `rows`.
  template <typename T_B>
  void EnsureBiasMultiplier(const int64_t rows) {
    if (!bias_multiplier_.has_value()) {
      bias_multiplier_ = caffe2::empty(
          {rows}, at::dtype<T_B>().device(Context::GetDeviceType()));
    } else if (bias_multiplier_->numel() != rows) {
      bias_multiplier_->Resize(rows);
    } else {
      return;
    }
    math::Set<T_B, Context>(
        rows,
        convert::To<float, T_B>(1),
        bias_multiplier_->template mutable_data<T_B>(),
        &context_);
  }
};

// Gradient of the fully connected layer Y = X * op(W) + b.
//
// Inputs:  X (0), W (1), dY (2).
// Outputs: dW (0, shaped like W), db (1, N elements) and, only when the
//          operator has three outputs, dX (2, shaped like X).
//
// X is viewed as (M, K) split at `axis`; N comes from W split at `axis_w`
// (rows of W when TransposeWeight is true, columns otherwise). dY must hold
// M * N elements.
//
// Operator arguments (read once in the constructor):
//   axis            - split point of X, default 1.
//   axis_w          - split point of W, default 1.
//   float16_compute - stored in float16_compute_; the code in this class does
//                     not consult it.
template <
    class Context,
    class Engine = DefaultEngine,
    bool TransposeWeight = true>
class FullyConnectedGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit FullyConnectedGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        axis_(this->template GetSingleArgument<int32_t>("axis", 1)),
        axis_w_(this->template GetSingleArgument<int32_t>("axis_w", 1)),
        float16_compute_(
            this->template GetSingleArgument<bool>("float16_compute", false)) {}
  ~FullyConnectedGradientOp() {}

  // Computes the gradients for one combination of element types.
  //   T_X / T_W / T_DY    - element types of the inputs X, W and dY.
  //   T_B                 - element type of the internal ones vector.
  //   T_DX / T_DW / T_DB  - element types of the outputs dX, dW and db.
  //   MATH                - selects the math_type forwarded to math::Gemm
  //                         (FLOAT16 when fp16_type<MATH>() is true,
  //                         FLOAT otherwise).
  // Returns true on success; shape violations are reported through
  // CAFFE_ENFORCE.
  template <
      typename T_X,
      typename T_W,
      typename T_DY,
      typename T_B,
      typename T_DX,
      typename T_DW,
      typename T_DB,
      typename MATH>
  bool DoRunWithType() {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& dY = Input(2);
    // batch size
    const auto canonical_axis = X.canonical_axis_index(axis_);
    const int M = X.size_to_dim(canonical_axis);
    const int K = X.size_from_dim(canonical_axis);
    const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
    const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                  : W.size_from_dim(canonical_axis_w);

    auto dimErrorString = [&]() {
      return c10::str(
          "Dimension mismatch: ",
          "X: ",
          X.sizes(),
          ", W: ",
          W.sizes(),
          ", dY: ",
          dY.sizes(),
          ", axis: ",
          axis_,
          ", M: ",
          M,
          ", N: ",
          N,
          ", K: ",
          K);
    };

    // M, K and N are ints, so the products are formed in 64 bits: an int
    // product can overflow for large tensors even when each factor fits.
    CAFFE_ENFORCE(static_cast<int64_t>(M) * K == X.numel(), dimErrorString());
    CAFFE_ENFORCE(static_cast<int64_t>(K) * N == W.numel(), dimErrorString());
    // dY is read below as an (M, N) matrix; reject anything else up front.
    CAFFE_ENFORCE(static_cast<int64_t>(M) * N == dY.numel(), dimErrorString());

    auto* dW = Output(0, W.sizes(), at::dtype<T_DW>());
    auto* db = Output(1, {N}, at::dtype<T_DB>());

    // default to FLOAT as math.h does.
    TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
    if (fp16_type<MATH>()) {
      math_type = TensorProto_DataType_FLOAT16;
    }

    // When any of M, K, N is zero the matrix products below are empty sums.
    // Those results are written directly as zeros so that no zero-sized
    // dimension is handed to the GEMM backend.
    const bool degenerate = (M == 0) || (K == 0) || (N == 0);

    // Compute dW
    if (degenerate) {
      math::Set<T_DW, Context>(
          dW->numel(),
          convert::To<float, T_DW>(0),
          dW->template mutable_data<T_DW>(),
          &context_);
    } else {
      math::Gemm<T_DY, Context, Engine>(
          CblasTrans,
          CblasNoTrans,
          TransposeWeight ? N : K,
          TransposeWeight ? K : N,
          M,
          1,
          TransposeWeight ? dY.template data<T_DY>() : X.template data<T_X>(),
          TransposeWeight ? X.template data<T_X>() : dY.template data<T_DY>(),
          0,
          dW->template mutable_data<T_DW>(),
          &context_,
          math_type);
    }

    // Compute dB = dY^T * ones(M). This depends only on dY, so it is still
    // evaluated when K == 0; it is zero only when dY itself has no elements.
    if (M > 0 && N > 0) {
      this->template EnsureBiasMultiplier<T_B>(M);
      math::Gemv<T_DY, Context>(
          CblasTrans,
          M,
          N,
          1,
          dY.template data<T_DY>(),
          bias_multiplier_->template data<T_B>(),
          0,
          db->template mutable_data<T_DB>(),
          &context_);
    } else {
      math::Set<T_DB, Context>(
          db->numel(),
          convert::To<float, T_DB>(0),
          db->template mutable_data<T_DB>(),
          &context_);
    }

    // Compute dX
    if (OutputSize() == 3) {
      auto* dX = Output(2, X.sizes(), at::dtype<T_DX>());
      if (!degenerate) {
        math::Gemm<T_DX, Context, Engine>(
            CblasNoTrans,
            TransposeWeight ? CblasNoTrans : CblasTrans,
            M,
            K,
            N,
            1,
            dY.template data<T_DY>(),
            W.template data<T_W>(),
            0,
            dX->template mutable_data<T_DX>(),
            &context_,
            math_type);
      } else if (dX->numel() > 0) {
        // Only reachable with N == 0: dX has elements but receives no
        // contribution from an empty dY.
        math::Set<T_DX, Context>(
            dX->numel(),
            convert::To<float, T_DX>(0),
            dX->template mutable_data<T_DX>(),
            &context_);
      }
    }
    return true;
  }

  // Default entry point: every tensor and the math type are float.
  bool RunOnDevice() override {
    return DoRunWithType<
        float, //  X
        float, //  W
        float, // dY
        float, //  B
        float, // dX
        float, // dW
        float, // dB
        float>(); // Math
  }

 protected:
  size_t axis_{1};
  size_t axis_w_{1};
  // Vector of ones with one entry per row of X; created on first use.
  c10::optional<Tensor> bias_multiplier_;
  bool float16_compute_;

 private:
  // Makes bias_multiplier_ hold `rows` elements of type T_B, all set to one.
  // The fill is only performed when the tensor is first created or when its
  // element count differs from `rows`.
  template <typename T_B>
  void EnsureBiasMultiplier(const int64_t rows) {
    if (!bias_multiplier_.has_value()) {
      bias_multiplier_ = caffe2::empty(
          {rows}, at::dtype<T_B>().device(Context::GetDeviceType()));
    } else if (bias_multiplier_->numel() != rows) {
      bias_multiplier_->Resize(rows);
    } else {
      return;
    }
    math::Set<T_B, Context>(
        rows,
        convert::To<float, T_B>(1),
        bias_multiplier_->template mutable_data<T_B>(),
        &context_);
  }
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_