#ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_
#define CAFFE2_OPERATORS_FLATTEN_OP_H_

#include "caffe2/core/operator.h"

namespace caffe2 {
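
// Flattens the input tensor into a 2D matrix: the dimensions before `axis`
// are collapsed into the outer dimension and the remaining dimensions into
// the inner dimension. For example, an input of shape (2, 3, 4) with
// axis = 1 (the default) is flattened to shape (2, 12).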
template <class Context>
class FlattenOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit FlattenOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        axis_(this->template GetSingleArgument<int>("axis", 1)) {}

  bool RunOnDevice() override {
    auto& input = Input(0);
    auto* output = Output(0);
    CAFFE_ENFORCE_GE(
        input.dim(), axis_, "The rank of the tensor must be >= axis.");
    // Collapse the leading axis_ dimensions into the outer dimension and the
    // remaining dimensions into the inner one, then copy the data unchanged.
    output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
    context_.CopyItemsSameDevice(
        input.dtype(),
        input.numel(),
        input.raw_data(),
        output->raw_mutable_data(input.dtype()));
    return true;
  }

 private:
  int axis_;
};
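
// Shape inference for the Flatten operator: the single output is a 2D tensor
// whose outer dimension is the product of the input dimensions before `axis`
// and whose inner dimension is the product of the remaining dimensions.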
inline std::vector<TensorShape> TensorInferenceForFlatten(
    const OperatorDef& def,
    const std::vector<TensorShape>& in) {
  ArgumentHelper helper(def);
  const int axis = helper.GetSingleArgument<int>("axis", 1);
  std::vector<TensorShape> out(1);
  int64_t outer = 1;
  int64_t inner = 1;
  std::size_t index = 0;
  for (auto d : in[0].dims()) {
    // Cast avoids a signed/unsigned comparison between the running index
    // and the (signed) axis argument.
    if (static_cast<int>(index) < axis) {
      outer *= d;
    } else {
      inner *= d;
    }
    ++index;
  }
  out[0].set_data_type(in[0].data_type());
  out[0].add_dims(outer);
  out[0].add_dims(inner);
  return out;
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_FLATTEN_OP_H_