#ifndef CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
#define CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
/**
 * Operator that inserts dimensions of size 1 into the shape of its input.
 *
 * The repeated integer argument `dims` lists the indices, in the output
 * shape, at which the new size-1 dimensions appear. The output receives a
 * copy of the input and is then reshaped to the expanded shape.
 */
template <class Context>
class ExpandDimsOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /**
   * Reads the `dims` argument and normalizes it.
   *
   * After construction `dims_` is non-empty, sorted ascending, free of
   * duplicates and contains only non-negative values. A warning is logged
   * when duplicates had to be dropped; a missing `dims` argument or a
   * negative entry fails the enforce checks.
   */
  template <class... Args>
  explicit ExpandDimsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        dims_(this->template GetRepeatedArgument<int>("dims")) {
    const auto originalSize = dims_.size();
    CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
    std::sort(dims_.begin(), dims_.end());
    dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
    if (dims_.size() < originalSize) {
      LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
    }
    // dims_ is sorted, so checking the smallest entry covers all of them.
    CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
  }

  /**
   * Copies input 0 into output 0 and reshapes the output to the expanded
   * shape.
   *
   * The rank requirement is validated before anything is copied, so an
   * invalid input fails without touching the output.
   *
   * @return true on success; enforce failures are reported by the
   *         CAFFE_ENFORCE machinery.
   */
  bool RunOnDevice() override {
    auto& input = Input(0);
    auto* output = Output(0);

    // The constructor guarantees dims_ is non-empty, sorted and
    // non-negative, so back() is the largest requested index and the
    // conversion to an unsigned type is lossless.
    const size_t inputRank = input.sizes().size();
    const size_t insertCount = dims_.size();
    const size_t largestDim = static_cast<size_t>(dims_.back());
    CAFFE_ENFORCE_GE(
        inputRank + insertCount,
        largestDim + 1,
        "Input needs at least ",
        (1 + largestDim - insertCount),
        " dimensions given `dims`.");

    auto newDims = input.sizes().vec();
    newDims.reserve(newDims.size() + insertCount);
    // Inserting in ascending order keeps every already-inserted 1 at its
    // requested index, because later insertions happen at larger indices.
    for (const auto dim : dims_) {
      newDims.insert(newDims.begin() + dim, 1);
    }

    output->CopyFrom(input, true /*async*/);
    output->Reshape(newDims);
    return true;
  }

 private:
  // Normalized `dims` argument: sorted ascending, unique, non-negative.
  vector<int> dims_;
};
/**
 * Operator that removes dimensions of size 1 from the shape of its input.
 *
 * The repeated integer argument `dims` lists the indices, in the input
 * shape, of the dimensions to drop; each of them must have size 1. The
 * output receives a copy of the input and is then reshaped to the squeezed
 * shape.
 */
template <class Context>
class SqueezeOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /**
   * Reads the `dims` argument and normalizes it.
   *
   * After construction `dims_` is non-empty, sorted ascending, free of
   * duplicates and contains only non-negative values. A warning is logged
   * when duplicates had to be dropped; a missing `dims` argument or a
   * negative entry fails the enforce checks.
   */
  template <class... Args>
  explicit SqueezeOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        dims_(this->template GetRepeatedArgument<int>("dims")) {
    const auto originalSize = dims_.size();
    CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
    std::sort(dims_.begin(), dims_.end());
    dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
    if (dims_.size() < originalSize) {
      LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
    }
    // dims_ is sorted, so checking the smallest entry covers all of them.
    CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
  }

  /**
   * Copies input 0 into output 0 and reshapes the output to the squeezed
   * shape.
   *
   * Both the rank check and the size-1 checks run before anything is
   * copied, so an invalid input fails without touching the output.
   *
   * @return true on success; enforce failures are reported by the
   *         CAFFE_ENFORCE machinery.
   */
  bool RunOnDevice() override {
    auto& input = Input(0);
    auto* output = Output(0);

    // dims_ is sorted, so back() is the largest index that must exist.
    CAFFE_ENFORCE_GT(
        input.dim(),
        dims_.back(),
        "Input needs at least ",
        (dims_.back() + 1),
        " dimensions.");
    // Computing the new shape also enforces that every listed dimension
    // has size 1.
    std::vector<int> newDims = ComputeDims(input.sizes(), dims_);

    output->CopyFrom(input, true /*async*/);
    output->Reshape(newDims);
    return true;
  }

  /**
   * Computes the shape that remains after dropping the listed dimensions.
   *
   * @param inputDims Shape to squeeze.
   * @param dims      Indices into `inputDims` to drop. The single forward
   *                  scan below matches entries in order, so the list must
   *                  be sorted ascending without duplicates. Entries that
   *                  never match an index of `inputDims` are ignored.
   * @return The sizes of all dimensions that were not dropped, in order.
   *
   * Fails an enforce check when a listed dimension does not have size 1.
   */
  static std::vector<int> ComputeDims(
      at::IntArrayRef inputDims,
      std::vector<int> dims) {
    std::vector<int> newDims;
    newDims.reserve(inputDims.size());
    size_t j = 0; // position of the next not-yet-matched entry of `dims`
    for (size_t i = 0; i < inputDims.size(); ++i) {
      // The explicit cast spells out the conversion the comparison against
      // the unsigned index performs anyway.
      if (j < dims.size() && static_cast<size_t>(dims[j]) == i) {
        CAFFE_ENFORCE_EQ(
            inputDims[i],
            1,
            "Dimension ",
            i,
            " of input must be 1",
            " instead of ",
            inputDims[i],
            ".");
        ++j;
        continue;
      }
      newDims.push_back(inputDims.at(i));
    }
    return newDims;
  }

 private:
  // Normalized `dims` argument: sorted ascending, unique, non-negative.
  vector<int> dims_;

 public:
  C10_DISABLE_COPY_AND_ASSIGN(SqueezeOp);
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_