Caffe
cudnn.hpp
1 #ifndef CAFFE_UTIL_CUDNN_H_
2 #define CAFFE_UTIL_CUDNN_H_
3 #ifdef USE_CUDNN
4 
5 #include <cudnn.h>
6 
7 #include "caffe/common.hpp"
8 #include "caffe/proto/caffe.pb.h"
9 
/**
 * @brief True when CUDNN_VERSION is at least version major.minor.patch.
 *
 * The requested version is encoded as major * 1000 + minor * 100 + patch and
 * compared against CUDNN_VERSION. Every argument is parenthesized so that an
 * expression argument (e.g. `5 + 1`) is evaluated before it is scaled.
 */
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch)))
12 
/**
 * @brief Evaluates a cuDNN call exactly once and checks, through CHECK_EQ,
 *        that it returned CUDNN_STATUS_SUCCESS.
 *
 * On mismatch the name produced by cudnnGetErrorString() is streamed into the
 * CHECK_EQ message. The body is wrapped in do { } while (0) so the macro
 * behaves as a single statement.
 *
 * The argument is parenthesized, and the temporary uses a name that is
 * unlikely to collide with caller variables: a local simply called `status`
 * would shadow a caller's `status` appearing inside `condition` and turn the
 * initialization into a read of the uninitialized temporary.
 */
#define CUDNN_CHECK(condition) \
  do { \
    const cudnnStatus_t cudnn_check_status_ = (condition); \
    CHECK_EQ(cudnn_check_status_, CUDNN_STATUS_SUCCESS) << " " \
        << cudnnGetErrorString(cudnn_check_status_); \
  } while (0)
19 
20 inline const char* cudnnGetErrorString(cudnnStatus_t status) {
21  switch (status) {
22  case CUDNN_STATUS_SUCCESS:
23  return "CUDNN_STATUS_SUCCESS";
24  case CUDNN_STATUS_NOT_INITIALIZED:
25  return "CUDNN_STATUS_NOT_INITIALIZED";
26  case CUDNN_STATUS_ALLOC_FAILED:
27  return "CUDNN_STATUS_ALLOC_FAILED";
28  case CUDNN_STATUS_BAD_PARAM:
29  return "CUDNN_STATUS_BAD_PARAM";
30  case CUDNN_STATUS_INTERNAL_ERROR:
31  return "CUDNN_STATUS_INTERNAL_ERROR";
32  case CUDNN_STATUS_INVALID_VALUE:
33  return "CUDNN_STATUS_INVALID_VALUE";
34  case CUDNN_STATUS_ARCH_MISMATCH:
35  return "CUDNN_STATUS_ARCH_MISMATCH";
36  case CUDNN_STATUS_MAPPING_ERROR:
37  return "CUDNN_STATUS_MAPPING_ERROR";
38  case CUDNN_STATUS_EXECUTION_FAILED:
39  return "CUDNN_STATUS_EXECUTION_FAILED";
40  case CUDNN_STATUS_NOT_SUPPORTED:
41  return "CUDNN_STATUS_NOT_SUPPORTED";
42  case CUDNN_STATUS_LICENSE_ERROR:
43  return "CUDNN_STATUS_LICENSE_ERROR";
44 #if CUDNN_VERSION_MIN(6, 0, 0)
45  case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
46  return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
47 #endif
48  }
49  return "Unknown cudnn status";
50 }
51 
52 namespace caffe {
53 
54 namespace cudnn {
55 
56 template <typename Dtype> class dataType;
57 template<> class dataType<float> {
58  public:
59  static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
60  static float oneval, zeroval;
61  static const void *one, *zero;
62 };
63 template<> class dataType<double> {
64  public:
65  static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
66  static double oneval, zeroval;
67  static const void *one, *zero;
68 };
69 
70 template <typename Dtype>
71 inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) {
72  CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
73 }
74 
75 template <typename Dtype>
76 inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
77  int n, int c, int h, int w,
78  int stride_n, int stride_c, int stride_h, int stride_w) {
79  CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
80  n, c, h, w, stride_n, stride_c, stride_h, stride_w));
81 }
82 
83 template <typename Dtype>
84 inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
85  int n, int c, int h, int w) {
86  const int stride_w = 1;
87  const int stride_h = w * stride_w;
88  const int stride_c = h * stride_h;
89  const int stride_n = c * stride_c;
90  setTensor4dDesc<Dtype>(desc, n, c, h, w,
91  stride_n, stride_c, stride_h, stride_w);
92 }
93 
94 template <typename Dtype>
95 inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
96  int n, int c, int h, int w) {
97  CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
98 #if CUDNN_VERSION_MIN(5, 0, 0)
99  CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
100  CUDNN_TENSOR_NCHW, n, c, h, w));
101 #else
102  CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, dataType<Dtype>::type,
103  CUDNN_TENSOR_NCHW, n, c, h, w));
104 #endif
105 }
106 
107 template <typename Dtype>
108 inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {
109  CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv));
110 }
111 
112 template <typename Dtype>
113 inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
114  cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
115  int pad_h, int pad_w, int stride_h, int stride_w) {
116 #if CUDNN_VERSION_MIN(6, 0, 0)
117  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
118  pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION,
119  dataType<Dtype>::type));
120 #else
121  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
122  pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
123 #endif
124 }
125 
126 template <typename Dtype>
127 inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
128  PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
129  int h, int w, int pad_h, int pad_w, int stride_h, int stride_w) {
130  switch (poolmethod) {
131  case PoolingParameter_PoolMethod_MAX:
132  *mode = CUDNN_POOLING_MAX;
133  break;
134  case PoolingParameter_PoolMethod_AVE:
135  *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
136  break;
137  default:
138  LOG(FATAL) << "Unknown pooling method.";
139  }
140  CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
141 #if CUDNN_VERSION_MIN(5, 0, 0)
142  CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode,
143  CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
144 #else
145  CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(*pool_desc, *mode,
146  CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
147 #endif
148 }
149 
150 template <typename Dtype>
151 inline void createActivationDescriptor(cudnnActivationDescriptor_t* activ_desc,
152  cudnnActivationMode_t mode) {
153  CUDNN_CHECK(cudnnCreateActivationDescriptor(activ_desc));
154  CUDNN_CHECK(cudnnSetActivationDescriptor(*activ_desc, mode,
155  CUDNN_PROPAGATE_NAN, Dtype(0)));
156 }
157 
158 } // namespace cudnn
159 
160 } // namespace caffe
161 
162 #endif // USE_CUDNN
163 #endif // CAFFE_UTIL_CUDNN_H_
A layer factory that allows one to register layers. During runtime, registered layers can be called b...
Definition: blob.hpp:14