#ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_
#define CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"
namespace caffe2 {
// Common base for the cuDNN activation operators in this file.
//
// Owns one cuDNN tensor descriptor and one cuDNN activation descriptor, both
// created in the constructor and destroyed in the destructor, and offers
// SetTensorDescriptor() so derived operators can (re)configure the tensor
// descriptor for the data they are about to process.
class CuDNNActivationOpBase : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  // Forwards all arguments to Operator<CUDAContext>, binds the cuDNN wrapper
  // to this operator's context and creates both descriptors.
  template <class... Args>
  explicit CuDNNActivationOpBase(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
    CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&act_desc_));
  }

  // Releases both descriptors.
  // NOTE(review): if CUDNN_ENFORCE reports failures by throwing, a failure
  // here would escape a destructor -- confirm this is acceptable.
  virtual ~CuDNNActivationOpBase() {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
    CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(act_desc_));
  }

 protected:
  // Configures data_desc_ as a 1 x 1 x 1 x data_size tensor of `data_type`.
  //
  // The descriptor is only rewritten when the requested (type, size) pair
  // differs from the one that was last applied successfully. Both values are
  // part of the cache key: the same operator instance may be run with
  // different element types of equal element count, and reusing a descriptor
  // set up for another type would describe the buffers incorrectly.
  //
  // data_type: cuDNN element type of the tensors to be described.
  // data_size: total number of elements.
  void SetTensorDescriptor(
      const cudnnDataType_t data_type,
      const int data_size) {
    const bool already_configured =
        data_size == input_size_ && data_type == data_type_;
    if (already_configured) {
      return;
    }
    // Since the best performance is obtained when the tensor is HW-packed, we
    // put X.size() to W.
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        data_desc_,
        GetCudnnTensorFormat(StorageOrder::NCHW),
        data_type,
        1,
        1,
        1,
        data_size));
    // Record the configuration only after cuDNN accepted it, so that a failed
    // attempt is retried on the next call instead of being treated as cached.
    input_size_ = data_size;
    data_type_ = data_type;
  }

  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t data_desc_;
  cudnnActivationDescriptor_t act_desc_;
  // Element count currently described by data_desc_; 0 until first configured.
  int input_size_ = 0;
  // Element type currently described by data_desc_ (value-initialized until
  // the first successful SetTensorDescriptor call).
  cudnnDataType_t data_type_{};
};
// Forward pass of an activation function computed through cuDNN.
//
// kCuDNNActivationMode selects which cuDNN activation is applied. Input(0) is
// the tensor to transform; Output(0) receives the result and is given the
// same sizes as the input.
template <cudnnActivationMode_t kCuDNNActivationMode>
class CuDNNActivationOp final : public CuDNNActivationOpBase {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  // Forwards all arguments to the base class, then fixes the activation
  // descriptor to kCuDNNActivationMode with NaN propagation and a
  // coefficient of 0.0.
  template <class... Args>
  explicit CuDNNActivationOp(Args&&... args)
      : CuDNNActivationOpBase(std::forward<Args>(args)...) {
    CUDNN_ENFORCE(cudnnSetActivationDescriptor(
        act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
  }

  // Hands Input(0) to DispatchHelper with float and at::Half as the
  // candidate element types (presumably ending up in DoRunWithType<T>).
  bool RunOnDevice() override {
    using SupportedTypes = TensorTypes<float, at::Half>;
    return DispatchHelper<SupportedTypes>::call(this, Input(0));
  }

  // Applies the activation for element type T. Always returns true; cuDNN
  // failures are reported through CUDNN_ENFORCE.
  template <typename T>
  bool DoRunWithType() {
    const auto& input = Input(0);
    auto* output = Output(0, input.sizes(), at::dtype<T>());
    const auto count = input.numel();

    if (count != 0) {
      // NOTE(review): SetTensorDescriptor takes an int; if numel() is wider
      // than int the count is narrowed here -- confirm sizes stay in range.
      SetTensorDescriptor(cudnnTypeWrapper<T>::type, count);
      const T* src = input.template data<T>();
      T* dst = output->template mutable_data<T>();
      // Input and output share the same descriptor.
      CUDNN_ENFORCE(cudnnActivationForward(
          cudnn_wrapper_.inline_cudnn_handle(),
          act_desc_,
          cudnnTypeWrapper<T>::kOne(),
          data_desc_,
          src,
          cudnnTypeWrapper<T>::kZero(),
          data_desc_,
          dst));
    } else {
      // Nothing to compute for an empty input; still request the output
      // buffer as type T before returning.
      output->template mutable_data<T>();
    }
    return true;
  }
};
// Backward pass of an activation function computed through cuDNN.
//
// kCuDNNActivationMode selects which cuDNN activation is differentiated.
//   Input(0):  Y,  the output of the forward activation.
//   Input(1):  dY, the gradient with respect to Y.
//   Output(0): dX, the gradient with respect to the forward input; it is
//              given the same sizes as Y.
template <cudnnActivationMode_t kCuDNNActivationMode>
class CuDNNActivationGradientOp final : public CuDNNActivationOpBase {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  // Forwards all arguments to the base class, then fixes the activation
  // descriptor to kCuDNNActivationMode with NaN propagation and a
  // coefficient of 0.0.
  template <class... Args>
  explicit CuDNNActivationGradientOp(Args&&... args)
      : CuDNNActivationOpBase(std::forward<Args>(args)...) {
    CUDNN_ENFORCE(cudnnSetActivationDescriptor(
        act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
  }

  // Hands Input(0) to DispatchHelper with float and at::Half as the
  // candidate element types (presumably ending up in DoRunWithType<T>).
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

  // Computes dX for element type T. Always returns true; failures are
  // reported through the ENFORCE macros.
  template <typename T>
  bool DoRunWithType() {
    const auto& Y = Input(0);
    const auto& dY = Input(1);
    // All buffers passed to cuDNN below are described by one descriptor that
    // is sized from Y, so dY must hold exactly as many elements as Y;
    // otherwise cuDNN would read past the end of dY (or ignore part of it).
    CAFFE_ENFORCE_EQ(
        Y.numel(),
        dY.numel(),
        "Y and dY must have the same number of elements.");
    auto* dX = Output(0, Y.sizes(), at::dtype<T>());
    if (Y.numel() == 0) {
      // Nothing to compute for empty inputs; still request the output buffer
      // as type T before returning.
      dX->template mutable_data<T>();
      return true;
    }
    this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, Y.numel());
    CUDNN_ENFORCE(cudnnActivationBackward(
        this->cudnn_wrapper_.inline_cudnn_handle(),
        this->act_desc_,
        cudnnTypeWrapper<T>::kOne(),
        this->data_desc_,
        Y.template data<T>(),
        this->data_desc_,
        dY.template data<T>(),
        this->data_desc_,
        Y.template data<T>(), // Use Y_data as placeholder here.
        cudnnTypeWrapper<T>::kZero(),
        this->data_desc_,
        dX->template mutable_data<T>()));
    return true;
  }
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_