Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / im2col_op.h

#ifndef CAFFE2_OPERATORS_IM2COL_OP_H_
#define CAFFE2_OPERATORS_IM2COL_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// Im2ColOp rearranges a 4-D image batch X into column form with
// math::Im2Col, processing the batch one image at a time.
//
// Arguments:
//   "kernel" or "kernel_h"/"kernel_w": kernel size (default 0; must be > 0).
//   "dilation" or "dilation_h"/"dilation_w": dilation (default 1; > 0).
//   "stride" or "stride_h"/"stride_w": stride (default 1; > 0).
//   "pad": padding on both sides of each spatial axis (default 0; >= 0).
//   "order": storage order of X and Y, "NCHW" (default) or "NHWC".
//
// Per spatial axis, with dkernel = dilation * (kernel - 1) + 1 and
// out = (in + 2 * pad - dkernel) / stride + 1, the output Y has shape
//   NCHW: {N, C * kernel_h * kernel_w, out_h, out_w}
//   NHWC: {N, out_h, out_w, kernel_h * kernel_w * C}
template <typename T, class Context>
class Im2ColOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit Im2ColOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pad_(this->template GetSingleArgument<int>("pad", 0)),
        kernel_h_(this->template GetSingleArgument<int>(
            "kernel_h",
            this->template GetSingleArgument<int>("kernel", 0))),
        kernel_w_(this->template GetSingleArgument<int>(
            "kernel_w",
            this->template GetSingleArgument<int>("kernel", 0))),
        dilation_h_(this->template GetSingleArgument<int>(
            "dilation_h",
            this->template GetSingleArgument<int>("dilation", 1))),
        dilation_w_(this->template GetSingleArgument<int>(
            "dilation_w",
            this->template GetSingleArgument<int>("dilation", 1))),
        stride_h_(this->template GetSingleArgument<int>(
            "stride_h",
            this->template GetSingleArgument<int>("stride", 1))),
        stride_w_(this->template GetSingleArgument<int>(
            "stride_w",
            this->template GetSingleArgument<int>("stride", 1))),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<string>("order", "NCHW"))) {
    CAFFE_ENFORCE(kernel_h_ > 0);
    CAFFE_ENFORCE(kernel_w_ > 0);
    CAFFE_ENFORCE(dilation_h_ > 0);
    CAFFE_ENFORCE(dilation_w_ > 0);
    CAFFE_ENFORCE(stride_h_ > 0);
    CAFFE_ENFORCE(stride_w_ > 0);
    CAFFE_ENFORCE(pad_ >= 0);
  }

  bool RunOnDevice() override {
    const auto& X = Input(0);

    CAFFE_ENFORCE(4 == X.dim());

    int N = 0, C = 0, H = 0, W = 0;
    switch (order_) {
      case StorageOrder::NCHW:
        N = X.dim32(0);
        C = X.dim32(1);
        H = X.dim32(2);
        W = X.dim32(3);
        break;
      case StorageOrder::NHWC:
        N = X.dim32(0);
        H = X.dim32(1);
        W = X.dim32(2);
        C = X.dim32(3);
        break;
      default:
        CAFFE_THROW("Unknown storage order: ", order_);
    }

    const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
    const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
    // Padding is handed to math::Im2Col and enlarges the area the kernel
    // slides over, so only the *padded* extent must fit one dilated kernel
    // window. This also keeps the numerators below non-negative.
    CAFFE_ENFORCE(
        H + 2 * pad_ >= dkernel_h,
        "Padded input height is smaller than the dilated kernel height");
    CAFFE_ENFORCE(
        W + 2 * pad_ >= dkernel_w,
        "Padded input width is smaller than the dilated kernel width");
    const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
    const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
    // Widen before multiplying so a large C * kernel product cannot overflow
    // int.
    const int64_t col_channels =
        static_cast<int64_t>(C) * kernel_h_ * kernel_w_;

    switch (order_) {
      case StorageOrder::NCHW: {
        auto* Y = Output(
            0,
            std::vector<int64_t>{N, col_channels, out_h, out_w},
            at::dtype<T>());
        Im2ColPerImage<StorageOrder::NCHW>(X, Y, N, C, H, W);
        break;
      }
      case StorageOrder::NHWC: {
        auto* Y = Output(
            0,
            std::vector<int64_t>{N, out_h, out_w, col_channels},
            at::dtype<T>());
        Im2ColPerImage<StorageOrder::NHWC>(X, Y, N, C, H, W);
        break;
      }
      default:
        CAFFE_THROW("Unknown storage order: ", order_);
    }

    return true;
  }

 private:
  // Runs math::Im2Col on each of the N images of X, writing image n's
  // columns into the n-th equally sized slice of Y.
  template <StorageOrder kOrder>
  void Im2ColPerImage(
      const Tensor& X,
      Tensor* Y,
      const int N,
      const int C,
      const int H,
      const int W) {
    // An empty batch has nothing to convert. Returning here also avoids the
    // division by N below; integer division by zero is undefined behavior.
    if (N == 0) {
      return;
    }
    const int64_t x_step = X.numel() / N;
    const int64_t y_step = Y->numel() / N;
    // Fetch the base pointers once rather than on every iteration.
    const T* x_data = X.template data<T>();
    T* y_data = Y->template mutable_data<T>();
    for (int n = 0; n < N; ++n) {
      math::Im2Col<T, Context, kOrder>(
          C,
          H,
          W,
          kernel_h_,
          kernel_w_,
          dilation_h_,
          dilation_w_,
          pad_,
          pad_,
          pad_,
          pad_,
          stride_h_,
          stride_w_,
          x_data + n * x_step,
          y_data + n * y_step,
          &context_);
    }
  }

  int pad_;
  int kernel_h_;
  int kernel_w_;
  int dilation_h_;
  int dilation_w_;
  int stride_h_;
  int stride_w_;
  StorageOrder order_;
};

// Col2ImOp maps column data X back to image layout with math::Col2Im,
// processing the batch one image at a time.
//
// Inputs: X holds the column data; Z is a shape reference whose sizes become
// the shape of the output Y (only Z.sizes() is read). Y must be 4-D.
//
// Arguments:
//   "kernel" or "kernel_h"/"kernel_w": kernel size (default 0; must be > 0).
//   "dilation" or "dilation_h"/"dilation_w": dilation (default 1; > 0).
//   "stride" or "stride_h"/"stride_w": stride (default 1; > 0).
//   "pad": padding on both sides of each spatial axis (default 0; >= 0).
//   "order": storage order of Y, "NCHW" (default) or "NHWC".
//
// Per spatial axis, with dkernel = dilation * (kernel - 1) + 1 and
// out = (in + 2 * pad - dkernel) / stride + 1, X must contain exactly
// N * kernel_h * kernel_w * C * out_h * out_w elements.
template <typename T, class Context>
class Col2ImOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit Col2ImOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pad_(this->template GetSingleArgument<int>("pad", 0)),
        kernel_h_(this->template GetSingleArgument<int>(
            "kernel_h",
            this->template GetSingleArgument<int>("kernel", 0))),
        kernel_w_(this->template GetSingleArgument<int>(
            "kernel_w",
            this->template GetSingleArgument<int>("kernel", 0))),
        dilation_h_(this->template GetSingleArgument<int>(
            "dilation_h",
            this->template GetSingleArgument<int>("dilation", 1))),
        dilation_w_(this->template GetSingleArgument<int>(
            "dilation_w",
            this->template GetSingleArgument<int>("dilation", 1))),
        stride_h_(this->template GetSingleArgument<int>(
            "stride_h",
            this->template GetSingleArgument<int>("stride", 1))),
        stride_w_(this->template GetSingleArgument<int>(
            "stride_w",
            this->template GetSingleArgument<int>("stride", 1))),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<string>("order", "NCHW"))) {
    CAFFE_ENFORCE(kernel_h_ > 0);
    CAFFE_ENFORCE(kernel_w_ > 0);
    CAFFE_ENFORCE(dilation_h_ > 0);
    CAFFE_ENFORCE(dilation_w_ > 0);
    CAFFE_ENFORCE(stride_h_ > 0);
    CAFFE_ENFORCE(stride_w_ > 0);
    CAFFE_ENFORCE(pad_ >= 0);
  }

  bool RunOnDevice() override {
    const auto& X = Input(0);
    const auto& Z = Input(1);

    // The output takes its shape from the reference input Z.
    auto* Y = Output(0, Z.sizes(), at::dtype<T>());
    CAFFE_ENFORCE(4 == Y->dim());

    int N = 0, C = 0, H = 0, W = 0;
    switch (order_) {
      case StorageOrder::NCHW:
        N = Y->dim32(0);
        C = Y->dim32(1);
        H = Y->dim32(2);
        W = Y->dim32(3);
        break;
      case StorageOrder::NHWC:
        N = Y->dim32(0);
        H = Y->dim32(1);
        W = Y->dim32(2);
        C = Y->dim32(3);
        break;
      default:
        CAFFE_THROW("Unknown storage order: ", order_);
    }

    const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
    const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
    // Padding is handed to math::Col2Im and enlarges the area the kernel
    // slides over, so only the *padded* extent must fit one dilated kernel
    // window. This also keeps the numerators below non-negative.
    CAFFE_ENFORCE(
        H + 2 * pad_ >= dkernel_h,
        "Padded output height is smaller than the dilated kernel height");
    CAFFE_ENFORCE(
        W + 2 * pad_ >= dkernel_w,
        "Padded output width is smaller than the dilated kernel width");
    const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
    const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
    // Computed in 64 bits: a product of six ints can overflow int, and
    // X.numel() is 64-bit anyway.
    const int64_t expected_col_numel = static_cast<int64_t>(N) * kernel_h_ *
        kernel_w_ * C * out_h * out_w;
    CAFFE_ENFORCE(X.numel() == expected_col_numel);

    switch (order_) {
      case StorageOrder::NCHW:
        Col2ImPerImage<StorageOrder::NCHW>(X, Y, N, C, H, W);
        break;
      case StorageOrder::NHWC:
        Col2ImPerImage<StorageOrder::NHWC>(X, Y, N, C, H, W);
        break;
      default:
        CAFFE_THROW("Unknown storage order: ", order_);
    }

    return true;
  }

 private:
  // Runs math::Col2Im on each of the N images, reading the n-th equally
  // sized slice of X and writing the n-th image of Y.
  template <StorageOrder kOrder>
  void Col2ImPerImage(
      const Tensor& X,
      Tensor* Y,
      const int N,
      const int C,
      const int H,
      const int W) {
    // An empty batch has nothing to convert. Returning here also avoids the
    // division by N below; integer division by zero is undefined behavior.
    if (N == 0) {
      return;
    }
    const int64_t x_step = X.numel() / N;
    const int64_t y_step = Y->numel() / N;
    // Fetch the base pointers once rather than on every iteration.
    const T* x_data = X.template data<T>();
    T* y_data = Y->template mutable_data<T>();
    for (int n = 0; n < N; ++n) {
      math::Col2Im<T, Context, kOrder>(
          C,
          H,
          W,
          kernel_h_,
          kernel_w_,
          dilation_h_,
          dilation_w_,
          pad_,
          pad_,
          pad_,
          pad_,
          stride_h_,
          stride_w_,
          x_data + n * x_step,
          y_data + n * y_step,
          &context_);
    }
  }

  int pad_;
  int kernel_h_;
  int kernel_w_;
  int dilation_h_;
  int dilation_w_;
  int stride_h_;
  int stride_w_;
  StorageOrder order_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_IM2COL_OP_H_