#ifndef CAFFE2_OPERATORS_FILLER_OP_H_
#define CAFFE2_OPERATORS_FILLER_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// FillerOp takes in either zero or one input.
//
// If the number of inputs is 1, the shape will be identical to that of the input
// at run time with optional additional dimensions appended at the end as
// specified by "extra_shape" argument. In that case the "shape" parameter
// should not be set.
//
// If the number of inputs is 0, the full shape must be provided via "shape"
// argument
// Abstract base for all filler operators. It resolves the output shape
// (per the rules in the comment above) and resizes output 0; subclasses
// implement Fill() to populate the tensor.
template <class Context>
class FillerOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit FillerOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        // Static output shape; only valid when the op has no input.
        shape_(this->template GetRepeatedArgument<int64_t>("shape")),
        // Extra trailing dimensions appended to an input-derived shape.
        extra_shape_(ToVectorint64_t(
            this->template GetRepeatedArgument<int>("extra_shape"))),
        // When true, input 0 carries the desired shape as int64_t data
        // instead of serving as a shape template.
        input_as_shape_(
            this->template GetSingleArgument<bool>("input_as_shape", false)) {
    if (InputSize()) {
      // Shape is derived from the input at run time, so a static "shape"
      // argument would be ambiguous.
      if (shape_.size() != 0) {
        CAFFE_THROW(
            "Cannot set the shape argument and pass in an input at "
            "the same time");
      }
    } else {
      if (!extra_shape_.empty()) {
        CAFFE_THROW("Cannot set extra_shape when there is no input");
      }
      if (input_as_shape_) {
        CAFFE_THROW("An input must be given if input_as_shape is true");
      }
      // Catch "shape" mistakenly passed as a scalar int: the repeated-arg
      // getter above would have silently produced an empty list.
      if (shape_.size() == 0 &&
          this->template HasSingleArgumentOfType<int>("shape")) {
        CAFFE_THROW("Fill 'shape' argument was a scalar, list expected");
      }
    }
  }

  virtual ~FillerOp() {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Resolves the output shape, resizes output 0, and delegates the actual
  // filling to the subclass via Fill().
  bool RunOnDevice() override {
    auto* output = Operator<Context>::Output(0);
    if (InputSize()) {
      auto shape = vector<int64_t>{};
      if (input_as_shape_) {
        if (this->InputIsTensorType(0, CPU)) {
          // originally, shape input must be in CPU context
          auto& input = this->template Input<Tensor>(0, CPU);
          CAFFE_ENFORCE_EQ(
              input.dim(),
              1,
              "When input_as_shape is true, the input must be a 1D tensor of "
              "data type int64_t");
          CAFFE_ENFORCE(input.numel() > 0);
          auto* shape_data = input.template data<int64_t>();
          shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
        } else {
          // in ONNX case, we allow shape to be in CUDA context
          auto& input = Input(0);
          CAFFE_ENFORCE_EQ(
              input.dim(),
              1,
              "When input_as_shape is true, the input must be a 1D tensor of "
              "data type int64_t");
          CAFFE_ENFORCE(input.numel() > 0);
          auto* shape_data = input.template data<int64_t>();
          // Shape data lives on the device; copy it to host before reading.
          std::unique_ptr<int64_t[]> shape_data_copy =
              std::make_unique<int64_t[]>(input.dim32(0));
          context_.template CopyToCPU<int64_t>(
              input.dim32(0), shape_data, shape_data_copy.get());
          shape.insert(
              shape.end(),
              shape_data_copy.get(),
              shape_data_copy.get() + input.dim32(0));
        }
      } else {
        // Mirror the input tensor's shape.
        auto& input = Input(0);
        shape.insert(shape.end(), input.sizes().begin(), input.sizes().end());
      }
      shape.insert(shape.end(), extra_shape_.begin(), extra_shape_.end());
      output->Resize(shape);
      // Keep shape_ in sync with the shape actually used for this run.
      shape_ = shape;
    } else {
      output->Resize(shape_);
    }
    return Fill(output);
  }

  // Fills the already-resized output tensor; implemented by subclasses.
  virtual bool Fill(Tensor* output) = 0;

 protected:
  vector<int64_t> shape_;       // static or last-resolved output shape
  vector<int64_t> extra_shape_; // appended to input-derived shapes
  bool input_as_shape_;         // input 0 holds the shape itself
};
// Fills the output with values drawn uniformly between a lower and an upper
// bound. The bounds come either from the "min"/"max" arguments or, when the
// op has exactly three inputs, from scalar blobs at inputs 1 and 2 — never
// from both at once.
template <typename T, class Context>
class UniformFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit UniformFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...),
        min_(this->template GetSingleArgument<T>("min", 0)),
        max_(this->template GetSingleArgument<T>("max", 1)) {
    if (InputSize() == 3) {
      // Bounds arrive via input blobs; conflicting arguments are rejected.
      CAFFE_ENFORCE(
          !this->template HasSingleArgumentOfType<T>("min"),
          "Cannot set both min arg and min input blob");
      CAFFE_ENFORCE(
          !this->template HasSingleArgumentOfType<T>("max"),
          "Cannot set both max arg and max input blob");
    } else {
      // Argument-supplied bounds can be validated up front.
      CAFFE_ENFORCE_LT(
          min_, max_, "Max value should be bigger than min value.");
    }
  }

  bool Fill(Tensor* output) override {
    T range_lo = min_;
    T range_hi = max_;
    if (InputSize() == 3) {
      CAFFE_ENFORCE_EQ(1, Input(1).numel(), "min blob must be scalar");
      CAFFE_ENFORCE_EQ(1, Input(2).numel(), "max blob must be scalar");
      range_lo = *Input(1).template data<T>();
      range_hi = *Input(2).template data<T>();
      if (range_lo > range_hi) {
        // Inverted bounds from the blobs: collapse the leading dimension to
        // zero and emit an empty tensor instead of sampling.
        auto dims = output->sizes().vec();
        dims[0] = 0;
        output->Resize(dims);
        output->template mutable_data<T>();
        return true;
      }
    }
    math::RandUniform<T, Context>(
        output->numel(),
        range_lo,
        range_hi,
        output->template mutable_data<T>(),
        &context_);
    return true;
  }

 private:
  T min_; // lower bound from the "min" argument (default 0)
  T max_; // upper bound from the "max" argument (default 1)
};
// Fills the output with unique values sampled uniformly from [min, max],
// optionally excluding the values listed in a second ("avoid") input blob.
// Only int32 and int64 element types are supported.
template <class Context>
class UniqueUniformFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit UniqueUniformFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {
    // Resolve the element type once and bind the matching typed filler.
    const auto dtype =
        static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
            "dtype", TensorProto_DataType_INT32));
    if (dtype == TensorProto_DataType_INT32) {
      CheckRange<int>();
      body_ = &UniqueUniformFillOp::FillWithType<int>;
    } else if (dtype == TensorProto_DataType_INT64) {
      CheckRange<int64_t>();
      body_ = &UniqueUniformFillOp::FillWithType<int64_t>;
    } else if (dtype == TensorProto_DataType_UNDEFINED) {
      CAFFE_THROW(
          "UniqueUniformFill op cannot have undefined 'dtype' argument");
    } else {
      CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
    }
  }

  // Dispatch to the typed filler chosen at construction time.
  bool Fill(Tensor* output) override {
    return (this->*body_)(output);
  }

 private:
  // Validates that "min" and "max" are both present with type T and form a
  // non-empty range.
  template <typename T>
  void CheckRange() {
    CAFFE_ENFORCE(this->template HasSingleArgumentOfType<T>("min"));
    CAFFE_ENFORCE(this->template HasSingleArgumentOfType<T>("max"));
    CAFFE_ENFORCE_LT(
        this->template GetSingleArgument<T>("min", 0),
        this->template GetSingleArgument<T>("max", 0),
        "Max value should be bigger than min value.");
  }

  // Draws unique values of type T, excluding the optional avoid list.
  template <typename T>
  bool FillWithType(Tensor* output) {
    const T range_min = this->template GetSingleArgument<T>("min", 0);
    const T range_max = this->template GetSingleArgument<T>("max", 0);
    const T* avoid_values = nullptr;
    size_t num_avoid = 0;
    if (InputSize() >= 2) {
      auto& avoid_blob = Input(1);
      avoid_values = avoid_blob.template data<T>();
      num_avoid = avoid_blob.numel();
    }
    math::RandUniformUnique<T, Context>(
        output->numel(),
        range_min,
        range_max,
        output->template mutable_data<T>(),
        num_avoid,
        avoid_values,
        &context_);
    return true;
  }

  bool (UniqueUniformFillOp::*body_)(Tensor* output);
};
// Fills the output with a single constant value. The element type is taken
// from the "dtype" argument, or inferred from the type of the "value"
// argument when "dtype" is absent.
template <class Context>
class ConstantFillOp final : public FillerOp<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit ConstantFillOp(Args&&... args)
      : FillerOp<Context>(std::forward<Args>(args)...) {
    TensorProto_DataType dtype =
        static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
            "dtype", TensorProto_DataType_FLOAT));

    if (!OperatorBase::HasArgument("dtype") &&
        OperatorBase::HasArgument("value")) {
      // If 'dtype' is not provided, infer type based on the type of 'value'
      // Currently, single argument contains either float, int64 or bytes
      if (this->template HasSingleArgumentOfType<float>("value")) {
        dtype = TensorProto_DataType_FLOAT;
      } else if (this->template HasSingleArgumentOfType<int64_t>("value")) {
        dtype = TensorProto_DataType_INT64;
      } else {
        CAFFE_THROW("Argument 'value' is of unexpected type");
      }
      VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is "
              << "the same as that of argument 'value': " << dtype;
    }

    // Bind body_ to the fill routine for the resolved element type.
    switch (dtype) {
      case TensorProto_DataType_FLOAT:
        body_ = &ConstantFillOp::FillWithType<float>;
        break;
      case TensorProto_DataType_DOUBLE:
        body_ = &ConstantFillOp::FillWithType<double>;
        break;
      case TensorProto_DataType_BOOL:
        body_ = &ConstantFillOp::FillWithType<bool>;
        break;
      case TensorProto_DataType_INT8:
        body_ = &ConstantFillOp::FillWithType<int8_t>;
        break;
      case TensorProto_DataType_INT16:
        body_ = &ConstantFillOp::FillWithType<int16_t>;
        break;
      case TensorProto_DataType_INT32:
        body_ = &ConstantFillOp::FillWithType<int>;
        break;
      case TensorProto_DataType_INT64:
        body_ = &ConstantFillOp::FillWithType<int64_t>;
        break;
      case TensorProto_DataType_UINT8:
        body_ = &ConstantFillOp::FillWithType<uint8_t>;
        break;
      case TensorProto_DataType_UINT16:
        body_ = &ConstantFillOp::FillWithType<uint16_t>;
        break;
      case TensorProto_DataType_STRING:
        body_ = &ConstantFillOp::FillWithString;
        break;
      case TensorProto_DataType_UNDEFINED:
        CAFFE_THROW("ConstantFill op cannot have undefined 'dtype' argument");
        // break;
      default:
        CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
    }
  }

  // Dispatches to the typed fill routine chosen in the constructor.
  bool Fill(Tensor* output) override {
    return (this->*body_)(output);
  }

  // Fills output with a single value of type T. The value comes from the
  // "value" argument, or — when a second input is given and is defined —
  // from that input's single element.
  template <typename T>
  bool FillWithType(Tensor* output) {
    T value = this->template GetSingleArgument<T>("value", 0);
    if (InputSize() == 2) {
      auto& value_vec = Input(1);
      // NOTE(review): an undefined value input silently falls back to the
      // argument value rather than erroring — presumably intentional; confirm.
      if (value_vec) {
        CAFFE_ENFORCE_EQ(
            value_vec.size(), 1, "value vector must have 1 element");
        value = value_vec.template data<T>()[0];
      }
    }
    auto* data = output->template mutable_data<T>();
    if (output->numel()) {
      math::Set<T, Context>(output->numel(), value, data, &context_);
    }
    return true;
  }

  // Fills output with copies of the "value" string argument; string values
  // cannot be supplied via an input tensor.
  bool FillWithString(Tensor* output) {
    CAFFE_ENFORCE_LT(
        InputSize(), 2, "constant fill string from tensor is not supported");
    auto value = this->template GetSingleArgument<std::string>("value", "");
    auto* data = output->template mutable_data<std::string>();
    for (int i = 0; i < output->numel(); ++i) {
      data[i] = value;
    }
    return true;
  }

 private:
  // Typed fill routine selected by dtype at construction.
  bool (ConstantFillOp::*body_)(Tensor* output);
};
template <class Context>
class DiagonalFillOp final : public FillerOp<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit DiagonalFillOp(Args&&... args)
: FillerOp<Context>(std::forward<Args>(args)...) {
TensorProto_DataType dtype =
static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
"dtype", TensorProto_DataType_FLOAT));
Loading ...