#ifndef CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
#define CAFFE2_OPERATORS_LSTM_UNIT_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/perfkernels/lstm_unit_cpu.h"
#include "caffe2/utils/conversions.h"

namespace caffe2 {
namespace detail {
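
// A hedged sketch of the forward math the LstmUnitCpu kernel is expected to
// implement. X packs four gate pre-activations of size D per batch item,
// assumed here to be in the order [input, forget, output, cell]:
//
//   i = sigmoid(X[0*D + d])
//   f = sigmoid(X[1*D + d] + forget_bias)
//   o = sigmoid(X[2*D + d])
//   g = tanh(X[3*D + d])
//   C[d] = f * C_prev[d] + i * g
//   H[d] = o * tanh(C[d])
//
// For batch items past their sequence length (t >= seqLengths[n]), the state
// is either carried over from (H_prev, C_prev) or zeroed when drop_states is
// set.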
template <typename T, typename Context>
inline void LSTMUnit(
    const int N,
    const int D,
    const int t,
    const T* H_prev,
    const T* C_prev,
    const T* X,
    const int32_t* seqLengths,
    const bool drop_states,
    T* C,
    T* H,
    const float forget_bias,
    Context* /*context*/) {
  LstmUnitCpu<T>(
      N, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H, forget_bias);
}
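
// Backward counterpart of LSTMUnit. A hedged sketch of the gradients the
// LstmUnitGradientCpu kernel is expected to produce, with s = tanh(C[d]) and
// i, f, o, g as in the forward comment:
//
//   dC_total        = C_diff[d] + H_diff[d] * o * (1 - s^2)
//   C_prev_diff[d]  = dC_total * f
//   X_diff[i-gate]  = dC_total * g * i * (1 - i)
//   X_diff[f-gate]  = dC_total * C_prev[d] * f * (1 - f)
//   X_diff[o-gate]  = H_diff[d] * s * o * (1 - o)
//   X_diff[g-gate]  = dC_total * i * (1 - g^2)
//
// H_prev_diff is expected to be zero for valid steps, since H_prev reaches
// this unit only through the externally computed gate pre-activations; for
// invalid steps (t >= seqLengths[n]) the incoming gradients are passed
// through or zeroed according to drop_states.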
template <typename T, typename Context>
inline void LSTMUnitGradient(
    int N,
    int D,
    int t,
    const T* C_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* C,
    const T* H,
    const T* C_diff,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* C_prev_diff,
    T* X_diff,
    const float forget_bias,
    Context* /*context*/) {
  LstmUnitGradientCpu<T>(
      N,
      D,
      t,
      C_prev,
      X,
      seqLengths,
      C,
      H,
      C_diff,
      H_diff,
      drop_states,
      H_prev_diff,
      C_prev_diff,
      X_diff,
      forget_bias);
}
} // namespace detail
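
// LSTMUnitOp computes one LSTM timestep. With sequence_lengths (the
// default), the inputs are HIDDEN_T_M_1 (1 x N x D), CELL_T_M_1 (1 x N x D),
// GATES (1 x N x 4D), SEQ_LENGTHS (N), and a scalar int32 TIMESTEP tensor
// read from CPU; the outputs are HIDDEN_T and CELL_T, each 1 x N x D. A
// minimal usage sketch from Python (the blob names are illustrative only):
//
//   from caffe2.python import core
//   op = core.CreateOperator(
//       "LSTMUnit",
//       ["hidden_t_prev", "cell_t_prev", "gates", "seq_lengths", "timestep"],
//       ["hidden_t", "cell_t"],
//       forget_bias=0.0,
//       drop_states=False,
//   )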
template <typename Context>
class LSTMUnitOp : public Operator<Context> {
 public:
  explicit LSTMUnitOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        forget_bias_(static_cast<float>(
            this->template GetSingleArgument<float>("forget_bias", 0.0))),
        sequence_lengths_(
            this->template GetSingleArgument<bool>("sequence_lengths", true)),
        drop_states_(
            this->template GetSingleArgument<bool>("drop_states", false)) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using Operator<Context>::Operator;

  template <typename T>
  bool DoRunWithType() {
    // Handle the potentially-missing sequence lengths input: when it is
    // absent, TIMESTEP takes the slot that SEQ_LENGTHS would otherwise use.
    const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);

    // Extract N
    const auto N = Input(CELL_T_M_1).size(1);

    // Gates: 1xNxG
    const auto G = Input(GATES).size(2);
    const auto D = Input(CELL_T_M_1).size(2);

    CAFFE_ENFORCE_EQ(4 * D, G);
    const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
    const auto* C_prev = Input(CELL_T_M_1).template data<T>();
    const auto* X = Input(GATES).template data<T>();

    const int32_t* seqLengths = nullptr;
    if (sequence_lengths_) {
      CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
      seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
    }

    const auto t = static_cast<OperatorBase*>(this)
                       ->Input<Tensor>(TIMESTEP, CPU)
                       .template data<int32_t>()[0];
    Output(CELL_T)->ResizeLike(Input(CELL_T_M_1));
    auto* C = Output(CELL_T)->template mutable_data<T>();
    Output(HIDDEN_T)->ResizeLike(Input(CELL_T_M_1));
    auto* H = Output(HIDDEN_T)->template mutable_data<T>();
    detail::LSTMUnit<T, Context>(
        N,
        D,
        t,
        H_prev,
        C_prev,
        X,
        seqLengths,
        drop_states_,
        C,
        H,
        forget_bias_,
        &context_);
    return true;
  }

  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }

 protected:
  INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
  // Additional input tags (TIMESTEP) are determined dynamically based on
  // whether sequence_lengths is present.
  OUTPUT_TAGS(HIDDEN_T, CELL_T);
  float forget_bias_;
  bool sequence_lengths_;

 private:
  bool drop_states_;
};
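
// LSTMUnitGradientOp consumes the forward inputs plus the forward outputs
// and their incoming gradients, and emits gradients for the previous hidden
// state, previous cell state, and the packed gate pre-activations. The input
// order when sequence_lengths is true, recovered from the tag arithmetic in
// DoRunWithType below:
//
//   0: HIDDEN_T_M_1   1: CELL_T_M_1    2: GATES        3: SEQ_LENGTHS
//   4: TIMESTEP       5: HIDDEN_T      6: CELL_T
//   7: HIDDEN_T_GRAD  8: CELL_T_GRAD
//
// Without sequence_lengths, SEQ_LENGTHS is omitted and tags 4..8 shift down
// by one. This op is normally wired up by the registered gradient maker
// rather than constructed by hand.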
template <typename Context>
class LSTMUnitGradientOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit LSTMUnitGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        forget_bias_(static_cast<float>(
            this->template GetSingleArgument<float>("forget_bias", 0.0))),
        sequence_lengths_(
            this->template GetSingleArgument<bool>("sequence_lengths", true)),
        drop_states_(
            this->template GetSingleArgument<bool>("drop_states", false)) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <typename T>
  bool DoRunWithType() {
    // Handle the potentially-missing sequence lengths input: every tag after
    // SEQ_LENGTHS shifts down by one when it is absent.
    const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
    const size_t TIMESTEP = inputOffset;
    const size_t HIDDEN_T = inputOffset + 1;
    const size_t CELL_T = inputOffset + 2;
    const size_t HIDDEN_T_GRAD = inputOffset + 3;
    const size_t CELL_T_GRAD = inputOffset + 4;

    // Extract N
    const auto N = Input(CELL_T_M_1).size(1);

    // Gates: 1xNxG
    const auto G = Input(GATES).size(2);
    const auto D = Input(CELL_T_M_1).size(2);

    CAFFE_ENFORCE_EQ(4 * D, G);
    const auto* C_prev = Input(CELL_T_M_1).template data<T>();
    const auto* X = Input(GATES).template data<T>();
    const auto t = static_cast<OperatorBase*>(this)
                       ->Input<Tensor>(TIMESTEP, CPU)
                       .template data<int32_t>()[0];
    const auto* C = Input(CELL_T).template data<T>();
    const auto* H = Input(HIDDEN_T).template data<T>();
    const auto* C_diff = Input(CELL_T_GRAD).template data<T>();
    const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();

    const int32_t* seqLengths = nullptr;
    if (sequence_lengths_) {
      CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).numel(), N);
      seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
    }

    Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
    auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
    Output(CELL_T_M_1_GRAD)->ResizeLike(Input(CELL_T_M_1));
    auto* C_prev_diff = Output(CELL_T_M_1_GRAD)->template mutable_data<T>();
    Output(GATES_GRAD)->ResizeLike(Input(GATES));
    auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
    detail::LSTMUnitGradient<T, Context>(
        N,
        D,
        t,
        C_prev,
        X,
        seqLengths,
        C,
        H,
        C_diff,
        H_diff,
        drop_states_,
        H_prev_diff,
        C_prev_diff,
        X_diff,
        forget_bias_,
        &context_);
    return true;
  }

  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }

 protected:
  INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
  // Additional input tags (TIMESTEP, HIDDEN_T, CELL_T, and the incoming
  // gradients) are determined dynamically based on whether sequence_lengths
  // is present.
  OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, CELL_T_M_1_GRAD, GATES_GRAD);
  float forget_bias_;
  bool sequence_lengths_;

 private:
  bool drop_states_;
};
} // namespace caffe2

#endif // CAFFE2_OPERATORS_LSTM_UNIT_OP_H_