// Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages
//
// neilisaac / torch   python
//
// Repository URL to install this package:
//
// Version: 1.8.0
//
// / include / caffe2 / operators / activation_ops_cudnn.h

#ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_
#define CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_

#include <cstdint>
#include <limits>

#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"

namespace caffe2 {

// Common base for the cuDNN activation operators in this file.
//
// Holds one cuDNN tensor descriptor, which subclasses use for every tensor
// argument of the activation call, and one cuDNN activation descriptor, which
// subclasses configure in their constructors.
class CuDNNActivationOpBase : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  // Forwards all arguments to Operator<CUDAContext>, then creates the tensor
  // and activation descriptors.
  template <class... Args>
  explicit CuDNNActivationOpBase(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
    CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&act_desc_));
  }

  // Releases both descriptors.
  // NOTE(review): if CUDNN_ENFORCE reports failure by throwing, doing so from
  // a destructor terminates the program -- confirm that is the intended
  // failure mode.
  virtual ~CuDNNActivationOpBase() {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
    CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(act_desc_));
  }

 protected:
  // Configures data_desc_ as a 1 x 1 x 1 x data_size NCHW tensor of
  // data_type.
  //
  // The descriptor is only rebuilt when the requested size differs from the
  // cached one, or when it was configured before with a different data type.
  // Tracking the data type matters because the same operator instance can be
  // dispatched for more than one element type; keying the cache on the size
  // alone would reuse a descriptor that describes the wrong type.
  //
  // data_size is taken as a 64-bit value so an element count is not silently
  // truncated at the call site; it must fit in an int because that is what
  // the descriptor stores.
  void SetTensorDescriptor(
      const cudnnDataType_t data_type,
      const int64_t data_size) {
    const bool same_size = (data_size == input_size_);
    const bool same_type = !desc_configured_ || data_type == data_type_;
    if (same_size && same_type) {
      return;
    }
    CAFFE_ENFORCE_LE(
        data_size,
        static_cast<int64_t>(std::numeric_limits<int>::max()),
        "Tensor is too large for a cuDNN tensor descriptor: ",
        data_size,
        " elements.");
    const int width = static_cast<int>(data_size);
    // Since the best performance is obtained when the tensor is HW-packed, we
    // put X.size() to W.
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        data_desc_,
        GetCudnnTensorFormat(StorageOrder::NCHW),
        data_type,
        1,
        1,
        1,
        width));
    // Update the cache only after cuDNN accepted the new configuration, so a
    // failed call cannot leave the cache claiming the descriptor is current.
    input_size_ = width;
    data_type_ = data_type;
    desc_configured_ = true;
  }

  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t data_desc_;
  cudnnActivationDescriptor_t act_desc_;

  // Element count data_desc_ was last configured for (0 until first use).
  int input_size_ = 0;
  // Data type data_desc_ was last configured for; only meaningful once
  // desc_configured_ is true.
  cudnnDataType_t data_type_ = CUDNN_DATA_FLOAT;
  // Whether SetTensorDescriptor has successfully configured data_desc_.
  bool desc_configured_ = false;
};

// Forward activation operator backed by cudnnActivationForward.
//
// kCuDNNActivationMode selects the cuDNN activation applied to Input(0); the
// result is written to Output(0), which is given the same sizes as the input.
template <cudnnActivationMode_t kCuDNNActivationMode>
class CuDNNActivationOp final : public CuDNNActivationOpBase {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  // Forwards all arguments to the base class, then sets up the activation
  // descriptor for kCuDNNActivationMode with NaN propagation enabled.
  template <class... Args>
  explicit CuDNNActivationOp(Args&&... args)
      : CuDNNActivationOpBase(std::forward<Args>(args)...) {
    CUDNN_ENFORCE(cudnnSetActivationDescriptor(
        act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
  }

  // Dispatches to DoRunWithType<float> or DoRunWithType<at::Half> based on
  // Input(0).
  bool RunOnDevice() override {
    using SupportedTypes = TensorTypes<float, at::Half>;
    return DispatchHelper<SupportedTypes>::call(this, Input(0));
  }

  // Runs the activation for element type T. Always returns true; failures
  // are reported through CUDNN_ENFORCE.
  template <typename T>
  bool DoRunWithType() {
    const auto& input = Input(0);
    auto* output = Output(0, input.sizes(), at::dtype<T>());

    const auto count = input.numel();
    if (count == 0) {
      // Nothing to compute; still request the output buffer before returning.
      output->template mutable_data<T>();
      return true;
    }

    SetTensorDescriptor(cudnnTypeWrapper<T>::type, count);

    const T* src = input.template data<T>();
    T* dst = output->template mutable_data<T>();
    // The same descriptor describes both the input and the output tensor.
    CUDNN_ENFORCE(cudnnActivationForward(
        cudnn_wrapper_.inline_cudnn_handle(),
        act_desc_,
        cudnnTypeWrapper<T>::kOne(),
        data_desc_,
        src,
        cudnnTypeWrapper<T>::kZero(),
        data_desc_,
        dst));
    return true;
  }
};

// Backward activation operator backed by cudnnActivationBackward.
//
// Inputs:  0 = Y (forward output), 1 = dY (gradient w.r.t. Y).
// Output:  0 = dX, given the same sizes as Y.
template <cudnnActivationMode_t kCuDNNActivationMode>
class CuDNNActivationGradientOp final : public CuDNNActivationOpBase {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  // Forwards all arguments to the base class, then sets up the activation
  // descriptor for kCuDNNActivationMode with NaN propagation enabled.
  template <class... Args>
  explicit CuDNNActivationGradientOp(Args&&... args)
      : CuDNNActivationOpBase(std::forward<Args>(args)...) {
    CUDNN_ENFORCE(cudnnSetActivationDescriptor(
        act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
  }

  // Dispatches to DoRunWithType<float> or DoRunWithType<at::Half> based on
  // Input(0).
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

  // Computes dX for element type T. Always returns true; failures are
  // reported through the enforce macros.
  template <typename T>
  bool DoRunWithType() {
    const auto& Y = Input(0);
    const auto& dY = Input(1);

    // One descriptor, sized from Y, is used for Y, dY and dX below. A dY with
    // a different element count would make cuDNN access memory outside dY, so
    // reject it up front.
    CAFFE_ENFORCE_EQ(
        dY.numel(),
        Y.numel(),
        "dY must have the same number of elements as Y.");

    auto* dX = Output(0, Y.sizes(), at::dtype<T>());
    if (Y.numel() == 0) {
      // Nothing to compute; still request the output buffer before returning.
      dX->template mutable_data<T>();
      return true;
    }
    this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, Y.numel());
    CUDNN_ENFORCE(cudnnActivationBackward(
        this->cudnn_wrapper_.inline_cudnn_handle(),
        this->act_desc_,
        cudnnTypeWrapper<T>::kOne(),
        this->data_desc_,
        Y.template data<T>(),
        this->data_desc_,
        dY.template data<T>(),
        this->data_desc_,
        // This operator has no X input, so Y is passed as a placeholder for
        // the forward-input argument.
        Y.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        this->data_desc_,
        dX->template mutable_data<T>()));
    return true;
  }
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_