Caffe
blob.hpp
1 #ifndef CAFFE_BLOB_HPP_
2 #define CAFFE_BLOB_HPP_
3 
4 #include <algorithm>
5 #include <string>
6 #include <vector>
7 
8 #include "caffe/common.hpp"
9 #include "caffe/proto/caffe.pb.h"
10 #include "caffe/syncedmem.hpp"
11 
// Upper bound on the number of axes a Blob is expected to have.
// NOTE(review): not referenced in this header; presumably enforced by
// Reshape in blob.cpp -- confirm.
static const int kMaxBlobAxes = 32;
13 
14 namespace caffe {
15 
23 template <typename Dtype>
24 class Blob {
25  public:
26  Blob()
27  : data_(), diff_(), count_(0), capacity_(0) {}
28 
30  explicit Blob(const int num, const int channels, const int height,
31  const int width);
32  explicit Blob(const vector<int>& shape);
33 
35  void Reshape(const int num, const int channels, const int height,
36  const int width);
51  void Reshape(const vector<int>& shape);
52  void Reshape(const BlobShape& shape);
53  void ReshapeLike(const Blob& other);
54  inline string shape_string() const {
55  ostringstream stream;
56  for (int i = 0; i < shape_.size(); ++i) {
57  stream << shape_[i] << " ";
58  }
59  stream << "(" << count_ << ")";
60  return stream.str();
61  }
62  inline const vector<int>& shape() const { return shape_; }
71  inline int shape(int index) const {
72  return shape_[CanonicalAxisIndex(index)];
73  }
74  inline int num_axes() const { return shape_.size(); }
75  inline int count() const { return count_; }
76 
85  inline int count(int start_axis, int end_axis) const {
86  CHECK_LE(start_axis, end_axis);
87  CHECK_GE(start_axis, 0);
88  CHECK_GE(end_axis, 0);
89  CHECK_LE(start_axis, num_axes());
90  CHECK_LE(end_axis, num_axes());
91  int count = 1;
92  for (int i = start_axis; i < end_axis; ++i) {
93  count *= shape(i);
94  }
95  return count;
96  }
103  inline int count(int start_axis) const {
104  return count(start_axis, num_axes());
105  }
106 
118  inline int CanonicalAxisIndex(int axis_index) const {
119  CHECK_GE(axis_index, -num_axes())
120  << "axis " << axis_index << " out of range for " << num_axes()
121  << "-D Blob with shape " << shape_string();
122  CHECK_LT(axis_index, num_axes())
123  << "axis " << axis_index << " out of range for " << num_axes()
124  << "-D Blob with shape " << shape_string();
125  if (axis_index < 0) {
126  return axis_index + num_axes();
127  }
128  return axis_index;
129  }
130 
132  inline int num() const { return LegacyShape(0); }
134  inline int channels() const { return LegacyShape(1); }
136  inline int height() const { return LegacyShape(2); }
138  inline int width() const { return LegacyShape(3); }
139  inline int LegacyShape(int index) const {
140  CHECK_LE(num_axes(), 4)
141  << "Cannot use legacy accessors on Blobs with > 4 axes.";
142  CHECK_LT(index, 4);
143  CHECK_GE(index, -4);
144  if (index >= num_axes() || index < -num_axes()) {
145  // Axis is out of range, but still in [0, 3] (or [-4, -1] for reverse
146  // indexing) -- this special case simulates the one-padding used to fill
147  // extraneous axes of legacy blobs.
148  return 1;
149  }
150  return shape(index);
151  }
152 
153  inline int offset(const int n, const int c = 0, const int h = 0,
154  const int w = 0) const {
155  CHECK_GE(n, 0);
156  CHECK_LE(n, num());
157  CHECK_GE(channels(), 0);
158  CHECK_LE(c, channels());
159  CHECK_GE(height(), 0);
160  CHECK_LE(h, height());
161  CHECK_GE(width(), 0);
162  CHECK_LE(w, width());
163  return ((n * channels() + c) * height() + h) * width() + w;
164  }
165 
166  inline int offset(const vector<int>& indices) const {
167  CHECK_LE(indices.size(), num_axes());
168  int offset = 0;
169  for (int i = 0; i < num_axes(); ++i) {
170  offset *= shape(i);
171  if (indices.size() > i) {
172  CHECK_GE(indices[i], 0);
173  CHECK_LT(indices[i], shape(i));
174  offset += indices[i];
175  }
176  }
177  return offset;
178  }
188  void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,
189  bool reshape = false);
190 
191  inline Dtype data_at(const int n, const int c, const int h,
192  const int w) const {
193  return cpu_data()[offset(n, c, h, w)];
194  }
195 
196  inline Dtype diff_at(const int n, const int c, const int h,
197  const int w) const {
198  return cpu_diff()[offset(n, c, h, w)];
199  }
200 
201  inline Dtype data_at(const vector<int>& index) const {
202  return cpu_data()[offset(index)];
203  }
204 
205  inline Dtype diff_at(const vector<int>& index) const {
206  return cpu_diff()[offset(index)];
207  }
208 
209  inline const shared_ptr<SyncedMemory>& data() const {
210  CHECK(data_);
211  return data_;
212  }
213 
214  inline const shared_ptr<SyncedMemory>& diff() const {
215  CHECK(diff_);
216  return diff_;
217  }
218 
219  const Dtype* cpu_data() const;
220  void set_cpu_data(Dtype* data);
221  const int* gpu_shape() const;
222  const Dtype* gpu_data() const;
223  void set_gpu_data(Dtype* data);
224  const Dtype* cpu_diff() const;
225  const Dtype* gpu_diff() const;
226  Dtype* mutable_cpu_data();
227  Dtype* mutable_gpu_data();
228  Dtype* mutable_cpu_diff();
229  Dtype* mutable_gpu_diff();
230  void Update();
231  void FromProto(const BlobProto& proto, bool reshape = true);
232  void ToProto(BlobProto* proto, bool write_diff = false) const;
233 
235  Dtype asum_data() const;
237  Dtype asum_diff() const;
239  Dtype sumsq_data() const;
241  Dtype sumsq_diff() const;
242 
244  void scale_data(Dtype scale_factor);
246  void scale_diff(Dtype scale_factor);
247 
256  void ShareData(const Blob& other);
265  void ShareDiff(const Blob& other);
266 
267  bool ShapeEquals(const BlobProto& other);
268 
269  protected:
270  shared_ptr<SyncedMemory> data_;
271  shared_ptr<SyncedMemory> diff_;
272  shared_ptr<SyncedMemory> shape_data_;
273  vector<int> shape_;
274  int count_;
275  int capacity_;
276 
277  DISABLE_COPY_AND_ASSIGN(Blob);
278 }; // class Blob
279 
280 } // namespace caffe
281 
282 #endif // CAFFE_BLOB_HPP_
int channels() const
Deprecated legacy shape accessor channels: use shape(1) instead.
Definition: blob.hpp:134
int shape(int index) const
Returns the dimension of the index-th axis (or the negative index-th axis from the end, if index is negative).
Definition: blob.hpp:71
A layer factory that allows one to register layers. During runtime, registered layers can be called b...
Definition: blob.hpp:14
Dtype asum_diff() const
Compute the sum of absolute values (L1 norm) of the diff.
Definition: blob.cpp:245
void ShareDiff(const Blob &other)
Set the diff_ shared_ptr to point to the SyncedMemory holding the diff_ of Blob other – useful in La...
Definition: blob.cpp:162
void ShareData(const Blob &other)
Set the data_ shared_ptr to point to the SyncedMemory holding the data_ of Blob other – useful in La...
Definition: blob.cpp:156
Dtype asum_data() const
Compute the sum of absolute values (L1 norm) of the data.
Definition: blob.cpp:210
int height() const
Deprecated legacy shape accessor height: use shape(2) instead.
Definition: blob.hpp:136
void Reshape(const int num, const int channels, const int height, const int width)
Deprecated; use Reshape(const vector<int>& shape).
Definition: blob.cpp:12
Dtype sumsq_diff() const
Compute the sum of squares (L2 norm squared) of the diff.
Definition: blob.cpp:317
void scale_diff(Dtype scale_factor)
Scale the blob diff by a constant factor.
Definition: blob.cpp:385
int CanonicalAxisIndex(int axis_index) const
Returns the 'canonical' version of a (usually) user-specified axis, allowing for negative indexing (e.g., -1 for the last axis).
Definition: blob.hpp:118
void CopyFrom(const Blob< Dtype > &source, bool copy_diff=false, bool reshape=false)
Copy from a source Blob.
Definition: blob.cpp:433
Dtype sumsq_data() const
Compute the sum of squares (L2 norm squared) of the data.
Definition: blob.cpp:280
void scale_data(Dtype scale_factor)
Scale the blob data by a constant factor.
Definition: blob.cpp:352
int num() const
Deprecated legacy shape accessor num: use shape(0) instead.
Definition: blob.hpp:132
int width() const
Deprecated legacy shape accessor width: use shape(3) instead.
Definition: blob.hpp:138
int count(int start_axis) const
Compute the volume of a slice spanning from a particular first axis to the final axis.
Definition: blob.hpp:103
int count(int start_axis, int end_axis) const
Compute the volume of a slice; i.e., the product of dimensions among a range of axes.
Definition: blob.hpp:85
A wrapper around SyncedMemory holders serving as the basic computational unit through which Layers...
Definition: blob.hpp:24