Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ include / caffe2 / operators / filler_op.h

#ifndef CAFFE2_OPERATORS_FILLER_OP_H_
#define CAFFE2_OPERATORS_FILLER_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// FillerOp takes in either zero or one input.
//
// If the number of input is 1, the shape will be identical to that of the input
// at run time with optional additional dimensions appended at the end as
// specified by "extra_shape" argument. In that case the "shape" parameter
// should not be set.
//
// If the number of inputs is 0, the full shape must be provided via "shape"
// argument
template <class Context>
class FillerOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit FillerOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        shape_(this->template GetRepeatedArgument<int64_t>("shape")),
        extra_shape_(ToVectorint64_t(
            this->template GetRepeatedArgument<int>("extra_shape"))),
        input_as_shape_(
            this->template GetSingleArgument<bool>("input_as_shape", false)) {
    // The "shape" argument and a shape-bearing input are mutually exclusive;
    // validate that callers did not combine them.
    if (InputSize()) {
      if (shape_.size() != 0) {
        CAFFE_THROW(
            "Cannot set the shape argument and pass in an input at "
            "the same time");
      }
    } else {
      // With no input, the output shape must come entirely from "shape";
      // the input-dependent options make no sense here.
      if (!extra_shape_.empty()) {
        CAFFE_THROW("Cannot set extra_shape when there is no input");
      }
      if (input_as_shape_) {
        CAFFE_THROW("An input must be given if input_as_shape is true");
      }
      // Catch the common mistake of passing "shape" as a scalar int: it would
      // parse as an empty repeated argument and silently produce a 0-d output.
      if (shape_.size() == 0 &&
          this->template HasSingleArgumentOfType<int>("shape")) {
        CAFFE_THROW("Fill 'shape' argument was a scalar, list expected");
      }
    }
  }

  virtual ~FillerOp() {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Derives the output shape — from the "shape" argument, from the input
  // tensor's sizes, or from the input tensor's *contents* when
  // input_as_shape_ is set — resizes the output, then delegates the actual
  // filling to the subclass via Fill().
  bool RunOnDevice() override {
    auto* output = Operator<Context>::Output(0);
    if (InputSize()) {
      auto shape = vector<int64_t>{};
      if (input_as_shape_) {
        if (this->InputIsTensorType(0, CPU)) {
          // originally, shape input must be in CPU context
          auto& input = this->template Input<Tensor>(0, CPU);
          CAFFE_ENFORCE_EQ(
              input.dim(),
              1,
              "When input_as_shape is true, the input must be a 1D tensor of "
              "data type int64_t");
          CAFFE_ENFORCE(input.numel() > 0);
          auto* shape_data = input.template data<int64_t>();
          shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
        } else {
          // in ONNX case, we allow shape to be in CUDA context
          auto& input = Input(0);
          CAFFE_ENFORCE_EQ(
              input.dim(),
              1,
              "When input_as_shape is true, the input must be a 1D tensor of "
              "data type int64_t");
          CAFFE_ENFORCE(input.numel() > 0);
          auto* shape_data = input.template data<int64_t>();
          // The shape values may live on a device; copy them to host memory
          // before reading.
          std::unique_ptr<int64_t[]> shape_data_copy =
              std::make_unique<int64_t[]>(input.dim32(0));
          context_.template CopyToCPU<int64_t>(
              input.dim32(0), shape_data, shape_data_copy.get());
          shape.insert(
              shape.end(),
              shape_data_copy.get(),
              shape_data_copy.get() + input.dim32(0));
        }
      } else {
        // Default: mirror the input tensor's own shape.
        auto& input = Input(0);
        shape.insert(shape.end(), input.sizes().begin(), input.sizes().end());
      }
      // Optional trailing dimensions appended to the input-derived shape.
      shape.insert(shape.end(), extra_shape_.begin(), extra_shape_.end());
      output->Resize(shape);
      shape_ = shape;  // remember the shape most recently used
    } else {
      output->Resize(shape_);
    }
    return Fill(output);
  }

  // Subclasses fill the (already resized) output tensor here.
  virtual bool Fill(Tensor* output) = 0;

 protected:
  vector<int64_t> shape_;        // output shape when there is no input
  vector<int64_t> extra_shape_;  // extra dims appended to input-derived shape
  bool input_as_shape_;          // treat input *values* as the output shape
};

// Fills the output with values drawn uniformly from [min, max). The bounds
// come either from the "min"/"max" arguments (defaults 0 and 1) or, when the
// op has exactly three inputs, from scalar tensors in inputs 1 and 2.
template <typename T, class Context>
class UniformFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit UniformFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...),
        min_(this->template GetSingleArgument<T>("min", 0)),
        max_(this->template GetSingleArgument<T>("max", 1)) {
    const bool bounds_from_inputs = (InputSize() == 3);
    if (!bounds_from_inputs) {
      // Static bounds: validate them once, at construction time.
      CAFFE_ENFORCE_LT(
          min_, max_, "Max value should be bigger than min value.");
    } else {
      // Bounds arrive as input blobs; forbid conflicting arguments.
      CAFFE_ENFORCE(
          !this->template HasSingleArgumentOfType<T>("min"),
          "Cannot set both min arg and min input blob");
      CAFFE_ENFORCE(
          !this->template HasSingleArgumentOfType<T>("max"),
          "Cannot set both max arg and max input blob");
    }
  }

  // Samples uniform values into `output`; returns true on success.
  bool Fill(Tensor* output) override {
    T lo = min_;
    T hi = max_;
    if (InputSize() == 3) {
      // Runtime bounds: read scalar min/max from inputs 1 and 2.
      CAFFE_ENFORCE_EQ(1, Input(1).numel(), "min blob must be scalar");
      CAFFE_ENFORCE_EQ(1, Input(2).numel(), "max blob must be scalar");
      lo = *Input(1).template data<T>();
      hi = *Input(2).template data<T>();
      if (lo > hi) {
        // Empty range: produce an empty tensor (first dim = 0) rather than
        // sampling from an invalid interval.
        auto dims = output->sizes().vec();
        dims[0] = 0;
        output->Resize(dims);
        output->template mutable_data<T>();
        return true;
      }
    }
    math::RandUniform<T, Context>(
        output->numel(),
        lo,
        hi,
        output->template mutable_data<T>(),
        &context_);
    return true;
  }

 private:
  T min_;  // lower bound from the "min" argument
  T max_;  // upper bound from the "max" argument
};

// Fills the output with *unique* values sampled uniformly between the "min"
// and "max" arguments, optionally avoiding any values listed in a second
// input tensor. Only int32 and int64 dtypes are supported.
template <class Context>
class UniqueUniformFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit UniqueUniformFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {
    const auto dtype =
        static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
            "dtype", TensorProto_DataType_INT32));

    // Validate the range and bind the type-specific fill routine once, at
    // construction time.
    switch (dtype) {
      case TensorProto_DataType_INT32:
        CheckRange<int>();
        body_ = &UniqueUniformFillOp::FillWithType<int>;
        break;
      case TensorProto_DataType_INT64:
        CheckRange<int64_t>();
        body_ = &UniqueUniformFillOp::FillWithType<int64_t>;
        break;
      case TensorProto_DataType_UNDEFINED:
        CAFFE_THROW(
            "UniqueUniformFill op cannot have undefined 'dtype' argument");
      default:
        CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
    }
  }

  // Dispatch to the routine chosen in the constructor.
  bool Fill(Tensor* output) override {
    return (this->*body_)(output);
  }

 private:
  // Ensures "min"/"max" are present with the expected type and min < max.
  template <typename T>
  void CheckRange() {
    CAFFE_ENFORCE(this->template HasSingleArgumentOfType<T>("min"));
    CAFFE_ENFORCE(this->template HasSingleArgumentOfType<T>("max"));
    CAFFE_ENFORCE_LT(
        this->template GetSingleArgument<T>("min", 0),
        this->template GetSingleArgument<T>("max", 0),
        "Max value should be bigger than min value.");
  }

  // Generates unique uniform samples of type T into `output`.
  template <typename T>
  bool FillWithType(Tensor* output) {
    const T lower = this->template GetSingleArgument<T>("min", 0);
    const T upper = this->template GetSingleArgument<T>("max", 0);

    // Optional second input lists values that must never be generated.
    const T* avoid_data = nullptr;
    size_t avoid_size = 0;
    if (InputSize() >= 2) {
      auto& avoid = Input(1);
      avoid_data = avoid.template data<T>();
      avoid_size = avoid.numel();
    }
    math::RandUniformUnique<T, Context>(
        output->numel(),
        lower,
        upper,
        output->template mutable_data<T>(),
        avoid_size,
        avoid_data,
        &context_);
    return true;
  }

  // Pointer to the dtype-specific fill implementation selected at construction.
  bool (UniqueUniformFillOp::*body_)(Tensor* output);
};

// Fills the output tensor with a single constant "value". The element type
// is taken from the "dtype" argument; when "dtype" is absent it is inferred
// from the C++ type of the "value" argument (float or int64 only). The value
// may alternatively be supplied at run time as a one-element second input.
template <class Context>
class ConstantFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit ConstantFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {
    TensorProto_DataType dtype =
        static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
            "dtype", TensorProto_DataType_FLOAT));

    if (!OperatorBase::HasArgument("dtype") &&
        OperatorBase::HasArgument("value")) {
      // If 'dtype' is not provided, infer type based on the type of 'value'
      // Currently, single argument contains either float, int64 or bytes
      if (this->template HasSingleArgumentOfType<float>("value")) {
        dtype = TensorProto_DataType_FLOAT;
      } else if (this->template HasSingleArgumentOfType<int64_t>("value")) {
        dtype = TensorProto_DataType_INT64;
      } else {
        CAFFE_THROW("Argument 'value' is of unexpected type");
      }
      VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is "
              << "the same as that of argument 'value': " << dtype;
    }

    // Bind the type-specific fill routine once, at construction time, so
    // Fill() is a single indirect call with no per-run dispatch.
    switch (dtype) {
      case TensorProto_DataType_FLOAT:
        body_ = &ConstantFillOp::FillWithType<float>;
        break;
      case TensorProto_DataType_DOUBLE:
        body_ = &ConstantFillOp::FillWithType<double>;
        break;
      case TensorProto_DataType_BOOL:
        body_ = &ConstantFillOp::FillWithType<bool>;
        break;
      case TensorProto_DataType_INT8:
        body_ = &ConstantFillOp::FillWithType<int8_t>;
        break;
      case TensorProto_DataType_INT16:
        body_ = &ConstantFillOp::FillWithType<int16_t>;
        break;
      case TensorProto_DataType_INT32:
        body_ = &ConstantFillOp::FillWithType<int>;
        break;
      case TensorProto_DataType_INT64:
        body_ = &ConstantFillOp::FillWithType<int64_t>;
        break;
      case TensorProto_DataType_UINT8:
        body_ = &ConstantFillOp::FillWithType<uint8_t>;
        break;
      case TensorProto_DataType_UINT16:
        body_ = &ConstantFillOp::FillWithType<uint16_t>;
        break;
      case TensorProto_DataType_STRING:
        body_ = &ConstantFillOp::FillWithString;
        break;
      case TensorProto_DataType_UNDEFINED:
        CAFFE_THROW("ConstantFill op cannot have undefined 'dtype' argument");
      // break;
      default:
        CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
    }
  }

  // Dispatch to the routine chosen in the constructor.
  bool Fill(Tensor* output) override {
    return (this->*body_)(output);
  }

  // Fills `output` with the constant, taken from the "value" argument or
  // (when a second input is present and defined) from that scalar tensor.
  template <typename T>
  bool FillWithType(Tensor* output) {
    T value = this->template GetSingleArgument<T>("value", 0);
    if (InputSize() == 2) {
      auto& value_vec = Input(1);
      // Tensor's boolean conversion checks that the tensor is defined; an
      // undefined second input falls back to the "value" argument.
      if (value_vec) {
        CAFFE_ENFORCE_EQ(
            value_vec.size(), 1, "value vector must have 1 element");
        value = value_vec.template data<T>()[0];
      }
    }

    auto* data = output->template mutable_data<T>();
    if (output->numel()) {
      math::Set<T, Context>(output->numel(), value, data, &context_);
    }
    return true;
  }

  // String specialization: assigns the "value" argument to every element.
  // Reading the value from an input tensor is not supported for strings.
  bool FillWithString(Tensor* output) {
    CAFFE_ENFORCE_LT(
        InputSize(), 2, "constant fill string from tensor is not supported");
    auto value = this->template GetSingleArgument<std::string>("value", "");
    auto* data = output->template mutable_data<std::string>();
    for (int i = 0; i < output->numel(); ++i) {
      data[i] = value;
    }
    return true;
  }

 private:
  // Pointer to the dtype-specific fill implementation selected at construction.
  bool (ConstantFillOp::*body_)(Tensor* output);
};

template <class Context>
class DiagonalFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit DiagonalFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {
    TensorProto_DataType dtype =
        static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
            "dtype", TensorProto_DataType_FLOAT));
Loading ...