1 #ifndef CAFFE_UTIL_CUDNN_H_ 2 #define CAFFE_UTIL_CUDNN_H_ 7 #include "caffe/common.hpp" 8 #include "caffe/proto/caffe.pb.h" 10 #define CUDNN_VERSION_MIN(major, minor, patch) \ 11 (CUDNN_VERSION >= (major * 1000 + minor * 100 + patch)) 13 #define CUDNN_CHECK(condition) \ 15 cudnnStatus_t status = condition; \ 16 CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " "\ 17 << cudnnGetErrorString(status); \ 20 inline const char* cudnnGetErrorString(cudnnStatus_t status) {
22 case CUDNN_STATUS_SUCCESS:
23 return "CUDNN_STATUS_SUCCESS";
24 case CUDNN_STATUS_NOT_INITIALIZED:
25 return "CUDNN_STATUS_NOT_INITIALIZED";
26 case CUDNN_STATUS_ALLOC_FAILED:
27 return "CUDNN_STATUS_ALLOC_FAILED";
28 case CUDNN_STATUS_BAD_PARAM:
29 return "CUDNN_STATUS_BAD_PARAM";
30 case CUDNN_STATUS_INTERNAL_ERROR:
31 return "CUDNN_STATUS_INTERNAL_ERROR";
32 case CUDNN_STATUS_INVALID_VALUE:
33 return "CUDNN_STATUS_INVALID_VALUE";
34 case CUDNN_STATUS_ARCH_MISMATCH:
35 return "CUDNN_STATUS_ARCH_MISMATCH";
36 case CUDNN_STATUS_MAPPING_ERROR:
37 return "CUDNN_STATUS_MAPPING_ERROR";
38 case CUDNN_STATUS_EXECUTION_FAILED:
39 return "CUDNN_STATUS_EXECUTION_FAILED";
40 case CUDNN_STATUS_NOT_SUPPORTED:
41 return "CUDNN_STATUS_NOT_SUPPORTED";
42 case CUDNN_STATUS_LICENSE_ERROR:
43 return "CUDNN_STATUS_LICENSE_ERROR";
44 #if CUDNN_VERSION_MIN(6, 0, 0) 45 case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
46 return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
49 return "Unknown cudnn status";
56 template <
typename Dtype>
class dataType;
57 template<>
class dataType<float> {
59 static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
60 static float oneval, zeroval;
61 static const void *one, *zero;
63 template<>
class dataType<double> {
65 static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
66 static double oneval, zeroval;
67 static const void *one, *zero;
70 template <
typename Dtype>
71 inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) {
72 CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
75 template <
typename Dtype>
76 inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
77 int n,
int c,
int h,
int w,
78 int stride_n,
int stride_c,
int stride_h,
int stride_w) {
79 CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
80 n, c, h, w, stride_n, stride_c, stride_h, stride_w));
83 template <
typename Dtype>
84 inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
85 int n,
int c,
int h,
int w) {
86 const int stride_w = 1;
87 const int stride_h = w * stride_w;
88 const int stride_c = h * stride_h;
89 const int stride_n = c * stride_c;
90 setTensor4dDesc<Dtype>(desc, n, c, h, w,
91 stride_n, stride_c, stride_h, stride_w);
94 template <
typename Dtype>
95 inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
96 int n,
int c,
int h,
int w) {
97 CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
98 #if CUDNN_VERSION_MIN(5, 0, 0) 99 CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
100 CUDNN_TENSOR_NCHW, n, c, h, w));
102 CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, dataType<Dtype>::type,
103 CUDNN_TENSOR_NCHW, n, c, h, w));
107 template <
typename Dtype>
108 inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {
109 CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv));
112 template <
typename Dtype>
113 inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
114 cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
115 int pad_h,
int pad_w,
int stride_h,
int stride_w) {
116 #if CUDNN_VERSION_MIN(6, 0, 0) 117 CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
118 pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION,
119 dataType<Dtype>::type));
121 CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
122 pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
126 template <
typename Dtype>
127 inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
128 PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
129 int h,
int w,
int pad_h,
int pad_w,
int stride_h,
int stride_w) {
130 switch (poolmethod) {
131 case PoolingParameter_PoolMethod_MAX:
132 *mode = CUDNN_POOLING_MAX;
134 case PoolingParameter_PoolMethod_AVE:
135 *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
138 LOG(FATAL) <<
"Unknown pooling method.";
140 CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
141 #if CUDNN_VERSION_MIN(5, 0, 0) 142 CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode,
143 CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
145 CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(*pool_desc, *mode,
146 CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
150 template <
typename Dtype>
151 inline void createActivationDescriptor(cudnnActivationDescriptor_t* activ_desc,
152 cudnnActivationMode_t mode) {
153 CUDNN_CHECK(cudnnCreateActivationDescriptor(activ_desc));
154 CUDNN_CHECK(cudnnSetActivationDescriptor(*activ_desc, mode,
155 CUDNN_PROPAGATE_NAN, Dtype(0)));
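#endif // CAFFE_UTIL_CUDNN_H_
// Stray documentation-viewer tooltip text captured along with this header
// (not part of the code): "A layer factory that allows one to register
// layers. During runtime, registered layers can be called b..."
// Definition: blob.hpp:14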