#ifndef CAFFE2_CORE_COMMON_CUDNN_H_
#define CAFFE2_CORE_COMMON_CUDNN_H_
#include <array>
#include <mutex>
#include "caffe2/core/common.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/types.h"
#ifndef CAFFE2_USE_CUDNN
#error("This Caffe2 install is not built with cudnn, so you should not include this file.");
#endif
#include <cudnn.h>
static_assert(
CUDNN_VERSION >= 5000,
"Caffe2 requires cudnn version 5.0 or above.");
#if CUDNN_VERSION < 6000
#pragma message "CUDNN version under 6.0 is supported at best effort."
#pragma message "We strongly encourage you to move to 6.0 and above."
#pragma message "This message is intended to annoy you enough to update."
#endif // CUDNN_VERSION < 6000
#define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch)))
namespace caffe2 {
namespace internal {
/**
* A helper function to obtain cudnn error strings.
*/
inline const char* cudnnGetErrorString(cudnnStatus_t status) {
switch (status) {
case CUDNN_STATUS_SUCCESS:
return "CUDNN_STATUS_SUCCESS";
case CUDNN_STATUS_NOT_INITIALIZED:
return "CUDNN_STATUS_NOT_INITIALIZED";
case CUDNN_STATUS_ALLOC_FAILED:
return "CUDNN_STATUS_ALLOC_FAILED";
case CUDNN_STATUS_BAD_PARAM:
return "CUDNN_STATUS_BAD_PARAM";
case CUDNN_STATUS_INTERNAL_ERROR:
return "CUDNN_STATUS_INTERNAL_ERROR";
case CUDNN_STATUS_INVALID_VALUE:
return "CUDNN_STATUS_INVALID_VALUE";
case CUDNN_STATUS_ARCH_MISMATCH:
return "CUDNN_STATUS_ARCH_MISMATCH";
case CUDNN_STATUS_MAPPING_ERROR:
return "CUDNN_STATUS_MAPPING_ERROR";
case CUDNN_STATUS_EXECUTION_FAILED:
return "CUDNN_STATUS_EXECUTION_FAILED";
case CUDNN_STATUS_NOT_SUPPORTED:
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
default:
return "Unknown cudnn error number";
}
}
} // namespace internal
// A macro that wraps around a cudnn statement so we can check if the cudnn
// execution finishes or not.
#define CUDNN_ENFORCE(condition) \
do { \
cudnnStatus_t status = condition; \
CAFFE_ENFORCE_EQ( \
status, \
CUDNN_STATUS_SUCCESS, \
", Error at: ", \
__FILE__, \
":", \
__LINE__, \
": ", \
::caffe2::internal::cudnnGetErrorString(status)); \
} while (0)
#define CUDNN_CHECK(condition) \
do { \
cudnnStatus_t status = condition; \
CHECK(status == CUDNN_STATUS_SUCCESS) \
<< ::caffe2::internal::cudnnGetErrorString(status); \
} while (0)
// report the version of cuDNN Caffe2 was compiled with
inline size_t cudnnCompiledVersion() {
return CUDNN_VERSION;
}
// report the runtime version of cuDNN
inline size_t cudnnRuntimeVersion() {
return cudnnGetVersion();
}
// Check compatibility of compiled and runtime cuDNN versions
inline void CheckCuDNNVersions() {
// Version format is major*1000 + minor*100 + patch
// If compiled with version < 7, major, minor and patch must all match
// If compiled with version >= 7, then either
// runtime_version > compiled_version
// major and minor match
bool version_match = cudnnCompiledVersion() == cudnnRuntimeVersion();
bool compiled_with_7 = cudnnCompiledVersion() >= 7000;
bool backwards_compatible_7 = compiled_with_7 && cudnnRuntimeVersion() >= cudnnCompiledVersion();
bool patch_compatible = compiled_with_7 && (cudnnRuntimeVersion() / 100) == (cudnnCompiledVersion() / 100);
CAFFE_ENFORCE(version_match || backwards_compatible_7 || patch_compatible,
"cuDNN compiled (", cudnnCompiledVersion(), ") and "
"runtime (", cudnnRuntimeVersion(), ") versions mismatch");
}
/**
* cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type
* in a template function. The class is specialized explicitly for different
* data types below.
*/
template <typename T>
class cudnnTypeWrapper;
template <>
class cudnnTypeWrapper<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
typedef const float ScalingParamType;
typedef float BNParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static const ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};
#if CUDNN_VERSION_MIN(6, 0, 0)
template <>
class cudnnTypeWrapper<int> {
public:
static const cudnnDataType_t type = CUDNN_DATA_INT32;
typedef const int ScalingParamType;
typedef int BNParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1;
return &v;
}
static const ScalingParamType* kZero() {
static ScalingParamType v = 0;
return &v;
}
};
#endif // CUDNN_VERSION_MIN(6, 0, 0)
template <>
class cudnnTypeWrapper<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
typedef const double ScalingParamType;
typedef double BNParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};
template <>
class cudnnTypeWrapper<at::Half> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
typedef const float ScalingParamType;
typedef float BNParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
return &v;
}
static ScalingParamType* kZero() {
static ScalingParamType v = 0.0;
return &v;
}
};
/**
* A wrapper function to convert the Caffe storage order to cudnn storage order
* enum values.
*/
inline cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder& order) {
switch (order) {
case StorageOrder::NHWC:
return CUDNN_TENSOR_NHWC;
case StorageOrder::NCHW:
return CUDNN_TENSOR_NCHW;
default:
LOG(FATAL) << "Unknown cudnn equivalent for order: " << order;
}
// Just to suppress compiler warnings
return CUDNN_TENSOR_NCHW;
}
/**
* cudnnTensorDescWrapper is the placeholder that wraps around a
* cudnnTensorDescriptor_t, allowing us to do descriptor change as-needed during
* runtime.
*/
class cudnnTensorDescWrapper {
public:
cudnnTensorDescWrapper() {
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
}
~cudnnTensorDescWrapper() noexcept {
CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_));
}
inline cudnnTensorDescriptor_t Descriptor(
const cudnnTensorFormat_t format,
const cudnnDataType_t type,
const vector<int>& dims,
bool* changed) {
if (type_ == type && format_ == format && dims_ == dims) {
// if not changed, simply return the current descriptor.
if (changed)
*changed = false;
return desc_;
}
CAFFE_ENFORCE_EQ(
dims.size(), 4U, "Currently only 4-dimensional descriptor supported.");
format_ = format;
type_ = type;
dims_ = dims;
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
desc_,
format,
type,
dims_[0],
(format == CUDNN_TENSOR_NCHW ? dims_[1] : dims_[3]),
(format == CUDNN_TENSOR_NCHW ? dims_[2] : dims_[1]),
(format == CUDNN_TENSOR_NCHW ? dims_[3] : dims_[2])));
if (changed)
*changed = true;
return desc_;
}
template <typename T>
inline cudnnTensorDescriptor_t Descriptor(
const StorageOrder& order,
const vector<int>& dims) {
return Descriptor(
GetCudnnTensorFormat(order), cudnnTypeWrapper<T>::type, dims, nullptr);
}
private:
cudnnTensorDescriptor_t desc_;
cudnnTensorFormat_t format_;
cudnnDataType_t type_;
vector<int> dims_;
C10_DISABLE_COPY_AND_ASSIGN(cudnnTensorDescWrapper);
};
class cudnnFilterDescWrapper {
public:
cudnnFilterDescWrapper() {
CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&desc_));
}
~cudnnFilterDescWrapper() noexcept {
CUDNN_CHECK(cudnnDestroyFilterDescriptor(desc_));
}
inline cudnnFilterDescriptor_t Descriptor(
const StorageOrder& order,
const cudnnDataType_t type,
const vector<int>& dims,
bool* changed) {
if (type_ == type && order_ == order && dims_ == dims) {
// if not changed, simply return the current descriptor.
if (changed)
*changed = false;
return desc_;
}
CAFFE_ENFORCE_EQ(
dims.size(), 4U, "Currently only 4-dimensional descriptor supported.");
order_ = order;
type_ = type;
dims_ = dims;
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
desc_,
type,
GetCudnnTensorFormat(order),
dims_[0],
// TODO - confirm that this is correct for NHWC
(order == StorageOrder::NCHW ? dims_[1] : dims_[3]),
(order == StorageOrder::NCHW ? dims_[2] : dims_[1]),
(order == StorageOrder::NCHW ? dims_[3] : dims_[2])));
if (changed)
*changed = true;
return desc_;
}
template <typename T>
inline cudnnFilterDescriptor_t Descriptor(
const StorageOrder& order,
const vector<int>& dims) {
return Descriptor(order, cudnnTypeWrapper<T>::type, dims, nullptr);
}
private:
cudnnFilterDescriptor_t desc_;
StorageOrder order_;
cudnnDataType_t type_;
vector<int> dims_;
C10_DISABLE_COPY_AND_ASSIGN(cudnnFilterDescWrapper);
};
} // namespace caffe2
#endif // CAFFE2_CORE_COMMON_CUDNN_H_