#ifndef CAFFE2_OPERATOR_GLU_OP_H_
#define CAFFE2_OPERATOR_GLU_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// Gated Linear Unit (GLU) operator.
//
// Input 0 (X) is split into two equal halves along dimension `dim` (operator
// argument "dim", default -1 = last dimension). Output 0 (Y) has the same shape
// as X except that the split dimension is halved. The element-wise math is
// delegated to ComputeGlu, whose definition is not in this header; presumably
// Y = first_half * sigmoid(second_half) -- confirm against the per-device
// implementation.
//
// Negative `dim` values count from the back (-1 = last, -2 = second to last,
// ...). The split dimension must exist in X and must have an even size.
template <typename T, class Context>
class GluOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit GluOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        dim_(this->template GetSingleArgument<int>("dim", -1)) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Validates the split axis, allocates Y with the halved shape, and calls
  // ComputeGlu on the flattened [M, 2 * split_dim_size, N] view of X.
  bool RunOnDevice() override {
    auto& X = Input(0);
    vector<int64_t> Yshape(X.sizes().begin(), X.sizes().end());
    const int ndim = static_cast<int>(Yshape.size());

    // Accept any negative axis, counting from the back. An axis that is
    // still out of range afterwards (including any axis on a 0-d input) is
    // rejected here instead of indexing Yshape out of bounds.
    const int split_index = dim_ < 0 ? dim_ + ndim : dim_;
    CAFFE_ENFORCE(
        split_index >= 0 && split_index < ndim,
        "Split dimension ",
        dim_,
        " is out of range for an input with ",
        ndim,
        " dimensions");
    CAFFE_ENFORCE(
        Yshape[split_index] % 2 == 0,
        "Split dimension ",
        Yshape[split_index],
        " should be divided by two");

    const int split_dim_size = Yshape[split_index] / 2;
    // M: product of the dimensions before the split axis.
    // N: product of the dimensions after the split axis.
    const int M = X.size_to_dim(split_index);
    const int N = X.size_from_dim(split_index + 1);
    Yshape[split_index] = split_dim_size;
    auto* Y = Output(0, Yshape, at::dtype<T>());
    ComputeGlu(
        M,
        split_dim_size,
        N,
        X.template data<T>(),
        Y->template mutable_data<T>());
    return true;
  }

 protected:
  // Device-specific kernel (defined outside this header). Reads X viewed as
  // [M, 2 * split_dim_size, N] and writes `output` viewed as
  // [M, split_dim_size, N].
  void ComputeGlu(
      const int M,
      const int split_dim_size,
      const int N,
      const T* X,
      T* output);

 private:
  // Split axis as supplied by the "dim" argument; may be negative.
  const int dim_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATOR_GLU_OP_H_