Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / conv_transpose_unpool_op_base.h

#ifndef CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_
#define CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_op_shared.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/proto/caffe2_legacy.pb.h"
#include "caffe2/utils/math.h"

// Global boolean flag, only declared here (its definition lives elsewhere).
// When true, the ConvTransposeUnpoolBase constructor creates the shared
// column buffer even if the operator's "shared_buffer" argument is 0.
C10_DECLARE_bool(caffe2_force_shared_col_buffer);

namespace caffe2 {

// Common base for Caffe2's transposed-convolution and unpooling operators.
//
// The constructor parses and validates the spatial arguments: kernel, stride,
// pads, adj (an extra amount added to each spatial output size), group,
// storage order and the shared-buffer switch. These are accepted either in the
// repeated form ("kernels", "strides", "pads", "adjs") or in the older scalar
// forms ("kernel" or "kernel_h"/"kernel_w", "stride" or "stride_h"/"stride_w",
// "adj" or "adj_h"/"adj_w", "pad" or "pad_t"/"pad_l"/"pad_b"/"pad_r").
// GetOutputSize() computes the output shape of a 2-D operation. Derived
// classes override RunOnDeviceWithOrderNCHW() and/or
// RunOnDeviceWithOrderNHWC(); the default implementations throw.
template <class Context>
class ConvTransposeUnpoolBase : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Reads and validates the operator arguments.
  //
  // Throws (via CAFFE_ENFORCE*) if explicit padding is combined with VALID or
  // SAME legacy padding, if the per-dimension argument lists disagree in
  // length, if any kernel or stride is not positive, or if any adj is
  // outside [0, stride].
  explicit ConvTransposeUnpoolBase(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<Context>(operator_def, ws),
        legacy_pad_(
            static_cast<LegacyPadding>(this->template GetSingleArgument<int>(
                "legacy_pad",
                LegacyPadding::NOTSET))),
        kernel_(this->template GetRepeatedArgument<int>("kernels")),
        stride_(this->template GetRepeatedArgument<int>("strides")),
        pads_(this->template GetRepeatedArgument<int>("pads")),
        adj_(this->template GetRepeatedArgument<int>("adjs")),
        group_(this->template GetSingleArgument<int>("group", 1)),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<string>("order", "NCHW"))),
        shared_buffer_(
            this->template GetSingleArgument<int>("shared_buffer", 0)),
        ws_(ws) {
    // For the padding, they should either be the legacy padding strategy
    // (VALID or SAME), or an explicit, non-negative value.
    if (legacy_pad_ == LegacyPadding::VALID ||
        legacy_pad_ == LegacyPadding::SAME) {
      CAFFE_ENFORCE(
          !OperatorBase::HasArgument("pads"),
          "If you use legacy padding VALID or SAME, you should not specify "
          "any specific padding values.");
    }
    // Get old arguments values. The scalar forms are merged into the vectors
    // read from the repeated arguments above.
    if (OperatorBase::HasArgument("kernel")) {
      kernel_.resize(2, this->template GetSingleArgument<int>("kernel", 0));
    } else if (
        OperatorBase::HasArgument("kernel_h") &&
        OperatorBase::HasArgument("kernel_w")) {
      kernel_.push_back(this->template GetSingleArgument<int>("kernel_h", 0));
      kernel_.push_back(this->template GetSingleArgument<int>("kernel_w", 0));
    }

    if (OperatorBase::HasArgument("stride")) {
      stride_.resize(2, this->template GetSingleArgument<int>("stride", 0));
    } else if (
        OperatorBase::HasArgument("stride_h") &&
        OperatorBase::HasArgument("stride_w")) {
      stride_.push_back(this->template GetSingleArgument<int>("stride_h", 0));
      stride_.push_back(this->template GetSingleArgument<int>("stride_w", 0));
    }

    if (OperatorBase::HasArgument("adj")) {
      adj_.resize(2, this->template GetSingleArgument<int>("adj", 0));
    } else if (
        OperatorBase::HasArgument("adj_h") &&
        OperatorBase::HasArgument("adj_w")) {
      adj_.push_back(this->template GetSingleArgument<int>("adj_h", 0));
      adj_.push_back(this->template GetSingleArgument<int>("adj_w", 0));
    }

    // pads_ layout is {top, left, bottom, right}.
    if (OperatorBase::HasArgument("pad")) {
      CAFFE_ENFORCE(
          legacy_pad_ != LegacyPadding::VALID &&
              legacy_pad_ != LegacyPadding::SAME,
          "If you use legacy padding VALID or SAME, you should not specify "
          "any specific padding values.");
      pads_.resize(4, this->template GetSingleArgument<int>("pad", 0));
    } else if (
        OperatorBase::HasArgument("pad_t") &&
        OperatorBase::HasArgument("pad_l") &&
        OperatorBase::HasArgument("pad_b") &&
        OperatorBase::HasArgument("pad_r")) {
      CAFFE_ENFORCE(
          legacy_pad_ != LegacyPadding::VALID &&
              legacy_pad_ != LegacyPadding::SAME,
          "If you use legacy padding VALID or SAME, you should not specify "
          "any specific padding values.");
      pads_.push_back(this->template GetSingleArgument<int>("pad_t", 0));
      pads_.push_back(this->template GetSingleArgument<int>("pad_l", 0));
      pads_.push_back(this->template GetSingleArgument<int>("pad_b", 0));
      pads_.push_back(this->template GetSingleArgument<int>("pad_r", 0));
    }

    // Fill default values. A kernel of {0, 0} is rejected by the positivity
    // check below, so a kernel must effectively always be supplied.
    if (kernel_.size() == 0) {
      kernel_.assign({0, 0});
    }

    if (stride_.size() == 0) {
      stride_.resize(kernel_.size(), 1);
    }

    if (pads_.size() == 0) {
      pads_.resize(kernel_.size() * 2, 0);
    }

    if (adj_.size() == 0) {
      adj_.resize(kernel_.size(), 0);
    }

    CAFFE_ENFORCE_EQ(stride_.size(), kernel_.size());
    CAFFE_ENFORCE_EQ(adj_.size(), kernel_.size());

    if (legacy_pad_ != LegacyPadding::VALID &&
        legacy_pad_ != LegacyPadding::SAME) {
      CAFFE_ENFORCE_EQ(pads_.size(), 2 * kernel_.size());
    }

    // Unsigned index to match kernel_.size() (avoids a signed/unsigned
    // comparison).
    for (std::size_t dim = 0; dim < kernel_.size(); ++dim) {
      CAFFE_ENFORCE_GT(kernel_[dim], 0);
      CAFFE_ENFORCE_GT(stride_[dim], 0);
      CAFFE_ENFORCE_GE(adj_[dim], 0);
      CAFFE_ENFORCE_LE(adj_[dim], stride_[dim]);
    }

    // Create shared buffer mutex in the constructor
    // to avoid race-condition in DAGNet.
    if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
      createSharedBuffer<Context>(ws_);
    }
  }

  // Gets the output size. The output channel is manually specified.
  //
  // `input` must be 4-D (NCHW or NHWC according to order_) with a non-empty
  // non-batch part. Returns {N, output_channel, H_out, W_out} for NCHW or
  // {N, H_out, W_out, output_channel} for NHWC. The spatial sizes come from
  // ComputeSizeAndPad(), which may rewrite pads_ (it zeroes them for VALID
  // and SAME legacy padding). Throws unless the operator was configured with
  // a 2-D kernel.
  std::vector<int64_t> GetOutputSize(const Tensor& input, int output_channel) {
    CAFFE_ENFORCE(4 == input.dim());
    CAFFE_ENFORCE_GT(input.size_from_dim(1), 0);
    // The code below reads kernel_/stride_/adj_ at [0] and [1] and pads_ at
    // [0..3]. The constructor accepts kernels of any length, so guard here to
    // avoid out-of-bounds access for non-2-D configurations. (With a 2-D
    // kernel the constructor already guarantees pads_ has 4 entries.)
    CAFFE_ENFORCE_EQ(
        kernel_.size(),
        2U,
        "ConvTransposeUnpoolBase::GetOutputSize only supports 2-D kernels.");
    int N = input.dim32(0);
    bool channel_first = false; // initialized to suppress compiler warning.
    int H = 0, W = 0; // initialized to suppress compiler warning.
    int M = 0;
    switch (order_) {
      case StorageOrder::NHWC:
        channel_first = false;
        H = input.dim32(1);
        W = input.dim32(2);
        M = input.dim32(3);
        break;
      case StorageOrder::NCHW:
        channel_first = true;
        M = input.dim32(1);
        H = input.dim32(2);
        W = input.dim32(3);
        break;
      default:
        LOG(FATAL) << "Unknown Storage order: " << order_;
    }
    int output_height = 0, output_width = 0;
    // Height uses pads_[0] (top) and pads_[2] (bottom).
    ComputeSizeAndPad(
        H,
        stride_[0],
        kernel_[0],
        adj_[0],
        &pads_[0],
        &pads_[2],
        &output_height);
    // Width uses pads_[1] (left) and pads_[3] (right).
    ComputeSizeAndPad(
        W,
        stride_[1],
        kernel_[1],
        adj_[1],
        &pads_[1],
        &pads_[3],
        &output_width);
    std::vector<int64_t> sizes;
    if (channel_first) {
      sizes = {N, output_channel, output_height, output_width};
    } else {
      sizes = {N, output_height, output_width, output_channel};
    }
    VLOG(2) << "In: N " << N << " M " << M << " H " << H << " W " << W;
    VLOG(2) << "Out: output_channel " << output_channel << " H "
            << output_height << " W " << output_width;
    return sizes;
  }

  // Dispatches to the order-specific implementation selected by order_.
  bool RunOnDevice() override {
    switch (order_) {
      case StorageOrder::NHWC:
        return RunOnDeviceWithOrderNHWC();
      case StorageOrder::NCHW:
        return RunOnDeviceWithOrderNCHW();
      default:
        LOG(FATAL) << "Unknown storage order: " << order_;
    }
    // To suppress old compiler warnings
    return true;
  }

  // Override in derived classes that support NCHW; the default throws.
  virtual bool RunOnDeviceWithOrderNCHW() {
    CAFFE_THROW("Not implemented");
  }

  // Override in derived classes that support NHWC; the default throws.
  virtual bool RunOnDeviceWithOrderNHWC() {
    CAFFE_THROW("Not implemented");
  }

  virtual ~ConvTransposeUnpoolBase() {}

 protected:
  // Accessors for 2D conv params. Index 0 is height and index 1 is width;
  // pads_ is laid out as {top, left, bottom, right}.

  inline int pad_t() const {
    return pads_[0];
  }

  inline int pad_l() const {
    return pads_[1];
  }

  inline int pad_b() const {
    return pads_[2];
  }

  inline int pad_r() const {
    return pads_[3];
  }

  inline int kernel_h() const {
    return kernel_[0];
  }

  inline int kernel_w() const {
    return kernel_[1];
  }

  inline int stride_h() const {
    return stride_[0];
  }

  inline int stride_w() const {
    return stride_[1];
  }

  inline int adj_h() const {
    return adj_[0];
  }

  inline int adj_w() const {
    return adj_[1];
  }

  // Computes the output size of one spatial dimension and, for VALID/SAME
  // legacy padding, resets that dimension's pads to 0.
  //
  //   NOTSET:     out = (in - 1) * stride + kernel + adj - pad_head - pad_tail
  //               (pads must be non-negative)
  //   VALID/SAME: pads set to 0, out = (in - 1) * stride + kernel + adj
  //   CAFFE_LEGACY_POOLING: unsupported, fatal.
  inline void ComputeSizeAndPad(
      const int in_size,
      const int stride,
      const int kernel,
      const int adj,
      int* pad_head,
      int* pad_tail,
      int* out_size) {
    switch (legacy_pad_) {
      case LegacyPadding::NOTSET:
        CAFFE_ENFORCE(*pad_head >= 0);
        CAFFE_ENFORCE(*pad_tail >= 0);
        *out_size =
            (in_size - 1) * stride + kernel + adj - *pad_head - *pad_tail;
        break;
      // We handle cases of LegacyPadding::VALID and LegacyPadding::SAME
      // the same way
      case LegacyPadding::VALID:
      case LegacyPadding::SAME:
        *pad_head = 0;
        *pad_tail = 0;
        *out_size = (in_size - 1) * stride + kernel + adj;
        break;
      case LegacyPadding::CAFFE_LEGACY_POOLING:
        LOG(FATAL) << "CAFFE_LEGACY_POOLING is no longer supported.";
        break;
    }
  }

  LegacyPadding legacy_pad_;
  // Not read or written anywhere in this class; kept for source
  // compatibility with derived classes. Initialized so it never holds an
  // indeterminate value.
  int pad_ = 0;

  std::vector<int> kernel_; // per spatial dim; [0] = h, [1] = w
  std::vector<int> stride_; // same length as kernel_
  std::vector<int> pads_; // {top, left, bottom, right} in the 2-D case
  std::vector<int> adj_; // same length as kernel_, each in [0, stride]
  int group_;
  StorageOrder order_;
  bool shared_buffer_;
  Workspace* ws_;
};

// Intended for use inside the body of a class derived from
// ConvTransposeUnpoolBase<Context>. It brings the base's protected members
// and 2-D accessors into the derived class's scope, so derived class
// templates can use them unqualified (names from a dependent base class are
// otherwise not found without `this->`). It includes adj_h/adj_w, matching
// the kernel_*/stride_*/pad_* accessors.
#define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context) \
  USE_OPERATOR_FUNCTIONS(Context);                        \
  using ConvTransposeUnpoolBase<Context>::kernel_;        \
  using ConvTransposeUnpoolBase<Context>::kernel_h;       \
  using ConvTransposeUnpoolBase<Context>::kernel_w;       \
  using ConvTransposeUnpoolBase<Context>::stride_;        \
  using ConvTransposeUnpoolBase<Context>::stride_h;       \
  using ConvTransposeUnpoolBase<Context>::stride_w;       \
  using ConvTransposeUnpoolBase<Context>::pads_;          \
  using ConvTransposeUnpoolBase<Context>::pad_t;          \
  using ConvTransposeUnpoolBase<Context>::pad_l;          \
  using ConvTransposeUnpoolBase<Context>::pad_b;          \
  using ConvTransposeUnpoolBase<Context>::pad_r;          \
  using ConvTransposeUnpoolBase<Context>::adj_;           \
  using ConvTransposeUnpoolBase<Context>::adj_h;          \
  using ConvTransposeUnpoolBase<Context>::adj_w;          \
  using ConvTransposeUnpoolBase<Context>::group_;         \
  using ConvTransposeUnpoolBase<Context>::order_;         \
  using ConvTransposeUnpoolBase<Context>::shared_buffer_; \
  using ConvTransposeUnpoolBase<Context>::ws_

} // namespace caffe2

#endif // CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_