/*
 * Learn more » Push, build, and install RubyGems, npm packages, Python
 * packages, Maven artifacts, PHP packages, Go modules, Bower components,
 * Debian packages, RPM packages, and NuGet packages.
 *
 * neilisaac / torch (python)
 *
 * Repository URL to install this package:
 *
 * Version: 1.8.0
 *
 * include/caffe2/operators/distance_op.h
 */

#ifndef CAFFE2_OPERATORS_DISTANCE_OP_H_
#define CAFFE2_OPERATORS_DISTANCE_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// Squared L2 distance operator. Inputs: X, Y. Output: Distance.
// RunOnDevice() is only declared here; its definition is not part of this
// header.
template <typename T, class Context>
class SquaredL2DistanceOp : public Operator<Context> {
 public:
  template <typename... CtorArgs>
  explicit SquaredL2DistanceOp(CtorArgs&&... ctor_args)
      : Operator<Context>(std::forward<CtorArgs>(ctor_args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;
};

template <typename T, class Context>
class SquaredL2DistanceGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit SquaredL2DistanceGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Gradient of SquaredL2Distance with respect to both inputs.
  //
  // Inputs:  X, Y (identical shapes), dDistance (1-D, length N).
  // Outputs: dX (shape of X), dY (shape of Y).
  //
  // N is the leading dimension of X (1 for a 0-d X) and D = numel(X) / N.
  // Row by row this writes
  //   dX[i, :] = dDistance[i] * (X[i, :] - Y[i, :])
  //   dY       = -dX
  bool RunOnDevice() override {
    auto& X = Input(0);
    auto& Y = Input(1);
    auto& dDistance = Input(2);

    // Validate first; the _EQ forms report both offending values on failure.
    CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
    for (int i = 0; i < X.dim(); ++i) {
      CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i));
    }
    const int N = X.dim() > 0 ? X.dim32(0) : 1;
    const int D = N > 0 ? X.numel() / N : 0;
    CAFFE_ENFORCE_EQ(dDistance.dim(), 1);
    CAFFE_ENFORCE_EQ(dDistance.dim32(0), N);

    auto* dX = Output(0, X.sizes(), at::dtype<T>());
    auto* dY = Output(1, Y.sizes(), at::dtype<T>());
    const T* dDistance_data = dDistance.template data<T>();
    T* dX_data = dX->template mutable_data<T>();

    // dX = X - Y elementwise, then scale each row in place by dDistance[i]
    // (passed by pointer to the pointer-alpha Scale overload).
    math::Sub<T, Context>(
        X.numel(),
        X.template data<T>(),
        Y.template data<T>(),
        dX_data,
        &context_);
    for (int i = 0; i < N; ++i) {
      math::Scale<T, T, Context>(
          D, dDistance_data + i, dX_data + i * D, dX_data + i * D, &context_);
    }
    // The gradient with respect to Y is the negation of the one for X.
    math::Scale<T, T, Context>(
        X.numel(), -1, dX_data, dY->template mutable_data<T>(), &context_);
    return true;
  }

 protected:
  // Input: X, Y, dDistance; Output: dX, dY
};

// L1 distance operator. Inputs: X, Y. Output: Distance.
// RunOnDevice() is only declared here; its definition is not part of this
// header.
template <typename T, class Context>
class L1DistanceOp : public Operator<Context> {
 public:
  template <typename... CtorArgs>
  explicit L1DistanceOp(CtorArgs&&... ctor_args)
      : Operator<Context>(std::forward<CtorArgs>(ctor_args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;
};

// Gradient of the L1 distance operator.
// Inputs: X, Y, dDistance. Outputs: dX, dY.
// RunOnDevice() is only declared here; its definition is not part of this
// header.
template <typename T, class Context>
class L1DistanceGradientOp : public Operator<Context> {
 public:
  template <typename... CtorArgs>
  explicit L1DistanceGradientOp(CtorArgs&&... ctor_args)
      : Operator<Context>(std::forward<CtorArgs>(ctor_args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;
};

// Dot product operator. Inputs are addressed by the X_IN / Y_IN tags and the
// result by DOT_OUT. RunOnDevice() is only declared here; its definition is
// not part of this header.
template <typename T, class Context>
class DotProductOp : public Operator<Context> {
 public:
  template <typename... CtorArgs>
  explicit DotProductOp(CtorArgs&&... ctor_args)
      : Operator<Context>(std::forward<CtorArgs>(ctor_args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Named input / output slots.
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(DOT_OUT);
};

// Gradient of the dot product operator. Consumes X_IN, Y_IN and the upstream
// gradient DER_DOT_IN; produces DER_X_OUT and DER_Y_OUT. RunOnDevice() is only
// declared here; its definition is not part of this header.
template <typename T, class Context>
class DotProductGradientOp final : public Operator<Context> {
 public:
  template <typename... CtorArgs>
  explicit DotProductGradientOp(CtorArgs&&... ctor_args)
      : Operator<Context>(std::forward<CtorArgs>(ctor_args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Named input / output slots.
  INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
};

// Dot product of X_IN and Y_IN written to DOT_OUT, configured by two operator
// arguments read at construction:
//   "pad_value" (float, default 0) and "replicate" (bool, default false).
// How they are applied is decided by RunOnDevice(), which is only declared
// here; its definition is not part of this header.
template <typename T, class Context>
class DotProductWithPaddingOp : public Operator<Context> {
 public:
  template <typename... CtorArgs>
  explicit DotProductWithPaddingOp(CtorArgs&&... ctor_args)
      : Operator<Context>(std::forward<CtorArgs>(ctor_args)...),
        pad_value_(
            this->template GetSingleArgument<float>("pad_value", 0.0f)),
        replicate_(
            this->template GetSingleArgument<bool>("replicate", false)) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  float pad_value_; // value of the "pad_value" argument
  bool replicate_; // value of the "replicate" argument
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(DOT_OUT);
};

// Cosine similarity operator: inputs X_IN, Y_IN; output COS_OUT.
// RunOnDevice() is only declared here; its definition is not part of this
// header.
template <typename T, class Context>
class CosineSimilarityOp : public Operator<Context> {
 public:
  template <typename... CtorArgs>
  explicit CosineSimilarityOp(CtorArgs&&... ctor_args)
      : Operator<Context>(std::forward<CtorArgs>(ctor_args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Named input / output slots.
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(COS_OUT);

 private:
  // Auxiliary tensor kept on the operator; its use is defined by
  // RunOnDevice() (presumably scratch storage — confirm in the .cc/.cu).
  Tensor aux_;
};

// Gradient of the cosine similarity operator. Consumes X_IN, Y_IN and the
// upstream gradient DER_COS_IN; produces DER_X_OUT and DER_Y_OUT.
// RunOnDevice() is only declared here; its definition is not part of this
// header.
template <typename T, class Context>
class CosineSimilarityGradientOp final : public Operator<Context> {
 public:
  template <typename... CtorArgs>
  explicit CosineSimilarityGradientOp(CtorArgs&&... ctor_args)
      : Operator<Context>(std::forward<CtorArgs>(ctor_args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Named input / output slots.
  INPUT_TAGS(X_IN, Y_IN, DER_COS_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);

 private:
  // Auxiliary tensor kept on the operator; its use is defined by
  // RunOnDevice() (presumably scratch storage — confirm in the .cc/.cu).
  Tensor aux_;
};

template <typename T, class Context>
class DotProductWithPaddingGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductWithPaddingGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pad_value_(this->template GetSingleArgument<float>("pad_value", 0.0)),
        replicate_(this->template GetSingleArgument<bool>("replicate", false)) {
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Gradient of DotProductWithPadding with respect to X and Y.
  //
  // Inputs:  X_IN, Y_IN, DER_DOT_IN (dDot, 1-D of length N).
  // Outputs: DER_X_OUT, DER_Y_OUT with the shapes of X and Y.
  //
  // Rows run along the leading dimension: N rows, DX elements per row of X and
  // DY per row of Y. For each row i:
  //  * replicate_ == false: the first D = min(DX, DY) elements get
  //        dX = dDot[i] * Y,   dY = dDot[i] * X,
  //    and the remaining |DX - DY| elements of the longer row are set to
  //    dDot[i] * pad_value_.
  //  * replicate_ == true: with S the shorter row (length DS) and L the longer
  //    one, each DS-sized chunk j of L gives
  //        dL[chunk j] = dDot[i] * S,   dS = dDot[i] * sum_j L[chunk j].
  bool RunOnDevice() override {
    auto& X = Input(X_IN);
    auto& Y = Input(Y_IN);
    auto& dDot = Input(DER_DOT_IN);

    int N, DX, DY;
    if (X.numel() > 0) {
      N = X.dim() > 0 ? X.dim32(0) : 1;
      DX = X.numel() / N;
      DY = Y.numel() / N;
    } else {
      N = 0;
      DX = 0;
      DY = 0;
    }
    CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
    CAFFE_ENFORCE_EQ(X.dim32(0), Y.dim32(0));
    CAFFE_ENFORCE_EQ(dDot.dim(), 1);
    CAFFE_ENFORCE_EQ(dDot.dim32(0), N);
    if (replicate_ && N > 0) {
      // Both row lengths are used as divisors (`%` here, `DL / DS` below).
      // Reject a zero length explicitly: integer division by zero is
      // undefined behavior, not a reportable error.
      CAFFE_ENFORCE(
          DX > 0 && DY > 0,
          "replicate requires non-empty rows in both inputs, got ",
          DX,
          " and ",
          DY);
      CAFFE_ENFORCE(
          DX % DY == 0 || DY % DX == 0,
          "replicate requires one row length to be a multiple of the other, got ",
          DX,
          " and ",
          DY);
    }
    const int D = std::min(DX, DY);
    const int restD = std::max(DX, DY) - D;

    auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<T>());
    auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<T>());

    const auto* X_data = X.template data<T>();
    const auto* Y_data = Y.template data<T>();
    const auto* dDot_data = dDot.template data<T>();
    auto* dX_data = dX->template mutable_data<T>();
    auto* dY_data = dY->template mutable_data<T>();

    if (N == 0) {
      // X is empty, so there are no rows and dDot is empty too (enforced
      // above): no gradient flows back. Y may still hold elements (e.g. X of
      // shape (n, 0)), so zero dY rather than leave the freshly allocated
      // buffer uninitialized. dX has no elements.
      math::Set<T, Context>(dY->numel(), 0.0, dY_data, &context_);
      return true;
    }

    // Scratch row for the replicate path, allocated once for every row. Its
    // length is the shorter row length, which equals D.
    // TODO: fold the Scale + Axpy pair into one Axpy and drop this buffer.
    std::vector<T> tmp_data(replicate_ ? D : 0);

    for (int i = 0; i < N; ++i) { // TODO: multithreading
      const auto offsetX = i * DX;
      const auto offsetY = i * DY;
      if (replicate_) {
        // L_ for longer vector and S_ for shorter vector
        const T *L_data, *S_data;
        T *dL_data, *dS_data;
        int DL, DS;
        if (DX > DY) {
          L_data = X_data + offsetX;
          S_data = Y_data + offsetY;
          dL_data = dX_data + offsetX;
          dS_data = dY_data + offsetY;
          DL = DX;
          DS = DY;
        } else {
          L_data = Y_data + offsetY;
          S_data = X_data + offsetX;
          dL_data = dY_data + offsetY;
          dS_data = dX_data + offsetX;
          DL = DY;
          DS = DX;
        }

        // dS accumulates dDot[i] * (chunk j of L) over all chunks, while each
        // chunk of dL receives dDot[i] * S.
        math::Set<T, Context>(DS, 0.0, dS_data, &context_);
        for (int j = 0; j < DL / DS; j++) {
          math::Scale<T, T, Context>(
              DS, dDot_data[i], S_data, dL_data + j * DS, &context_);
          math::Scale<T, T, Context>(
              DS, dDot_data[i], L_data + j * DS, tmp_data.data(), &context_);
          math::Axpy<float, T, Context>(
              DS, 1.0, tmp_data.data(), dS_data, &context_);
        }
      } else {
        // Overlapping prefix of length D: each side's gradient is the other
        // side scaled by dDot[i].
        math::Scale<T, T, Context>(
            D, dDot_data[i], X_data + offsetX, dY_data + offsetY, &context_);
        math::Scale<T, T, Context>(
            D, dDot_data[i], Y_data + offsetY, dX_data + offsetX, &context_);
      }

      if (!replicate_ && DX != DY) {
        // Tail of the longer row (restD elements past the prefix).
        T* rest_data;
        if (DX > DY) {
          rest_data = dX_data + offsetX + D;
        } else {
          rest_data = dY_data + offsetY + D;
        }
        auto pad_gradient = dDot_data[i] * pad_value_;
        math::Set<T, Context>(restD, pad_gradient, rest_data, &context_);
      }
    }

    return true;
  }

 protected:
  float pad_value_;
  bool replicate_;
  INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_DISTANCE_OP_H_