Caffe
filler.hpp
1 // Fillers are random number generators that fill a blob using the specified
2 // algorithm. The expectation is that they are only going to be used during
3 // initialization time and will not involve any GPUs.
4 
5 #ifndef CAFFE_FILLER_HPP
6 #define CAFFE_FILLER_HPP
7 
8 #include <string>
9 
10 #include "caffe/blob.hpp"
11 #include "caffe/proto/caffe.pb.h"
12 #include "caffe/syncedmem.hpp"
13 #include "caffe/util/math_functions.hpp"
14 
15 namespace caffe {
16 
18 template <typename Dtype>
19 class Filler {
20  public:
21  explicit Filler(const FillerParameter& param) : filler_param_(param) {}
22  virtual ~Filler() {}
23  virtual void Fill(Blob<Dtype>* blob) = 0;
24  protected:
25  FillerParameter filler_param_;
26 }; // class Filler
27 
28 
30 template <typename Dtype>
31 class ConstantFiller : public Filler<Dtype> {
32  public:
33  explicit ConstantFiller(const FillerParameter& param)
34  : Filler<Dtype>(param) {}
35  virtual void Fill(Blob<Dtype>* blob) {
36  Dtype* data = blob->mutable_cpu_data();
37  const int count = blob->count();
38  const Dtype value = this->filler_param_.value();
39  CHECK(count);
40  for (int i = 0; i < count; ++i) {
41  data[i] = value;
42  }
43  CHECK_EQ(this->filler_param_.sparse(), -1)
44  << "Sparsity not supported by this Filler.";
45  }
46 };
47 
49 template <typename Dtype>
50 class UniformFiller : public Filler<Dtype> {
51  public:
52  explicit UniformFiller(const FillerParameter& param)
53  : Filler<Dtype>(param) {}
54  virtual void Fill(Blob<Dtype>* blob) {
55  CHECK(blob->count());
56  caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
57  Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
58  CHECK_EQ(this->filler_param_.sparse(), -1)
59  << "Sparsity not supported by this Filler.";
60  }
61 };
62 
64 template <typename Dtype>
65 class GaussianFiller : public Filler<Dtype> {
66  public:
67  explicit GaussianFiller(const FillerParameter& param)
68  : Filler<Dtype>(param) {}
69  virtual void Fill(Blob<Dtype>* blob) {
70  Dtype* data = blob->mutable_cpu_data();
71  CHECK(blob->count());
72  caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
73  Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
74  int sparse = this->filler_param_.sparse();
75  CHECK_GE(sparse, -1);
76  if (sparse >= 0) {
77  // Sparse initialization is implemented for "weight" blobs; i.e. matrices.
78  // These have num == channels == 1; width is number of inputs; height is
79  // number of outputs. The 'sparse' variable specifies the mean number
80  // of non-zero input weights for a given output.
81  CHECK_GE(blob->num_axes(), 1);
82  const int num_outputs = blob->shape(0);
83  Dtype non_zero_probability = Dtype(sparse) / Dtype(num_outputs);
84  rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
85  int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());
86  caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);
87  for (int i = 0; i < blob->count(); ++i) {
88  data[i] *= mask[i];
89  }
90  }
91  }
92 
93  protected:
94  shared_ptr<SyncedMemory> rand_vec_;
95 };
96 
100 template <typename Dtype>
101 class PositiveUnitballFiller : public Filler<Dtype> {
102  public:
103  explicit PositiveUnitballFiller(const FillerParameter& param)
104  : Filler<Dtype>(param) {}
105  virtual void Fill(Blob<Dtype>* blob) {
106  Dtype* data = blob->mutable_cpu_data();
107  DCHECK(blob->count());
108  caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
109  // We expect the filler to not be called very frequently, so we will
110  // just use a simple implementation
111  int dim = blob->count() / blob->num();
112  CHECK(dim);
113  for (int i = 0; i < blob->num(); ++i) {
114  Dtype sum = 0;
115  for (int j = 0; j < dim; ++j) {
116  sum += data[i * dim + j];
117  }
118  for (int j = 0; j < dim; ++j) {
119  data[i * dim + j] /= sum;
120  }
121  }
122  CHECK_EQ(this->filler_param_.sparse(), -1)
123  << "Sparsity not supported by this Filler.";
124  }
125 };
126 
143 template <typename Dtype>
144 class XavierFiller : public Filler<Dtype> {
145  public:
146  explicit XavierFiller(const FillerParameter& param)
147  : Filler<Dtype>(param) {}
148  virtual void Fill(Blob<Dtype>* blob) {
149  CHECK(blob->count());
150  int fan_in = blob->count() / blob->num();
151  int fan_out = blob->count() / blob->channels();
152  Dtype n = fan_in; // default to fan_in
153  if (this->filler_param_.variance_norm() ==
154  FillerParameter_VarianceNorm_AVERAGE) {
155  n = (fan_in + fan_out) / Dtype(2);
156  } else if (this->filler_param_.variance_norm() ==
157  FillerParameter_VarianceNorm_FAN_OUT) {
158  n = fan_out;
159  }
160  Dtype scale = sqrt(Dtype(3) / n);
161  caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
162  blob->mutable_cpu_data());
163  CHECK_EQ(this->filler_param_.sparse(), -1)
164  << "Sparsity not supported by this Filler.";
165  }
166 };
167 
185 template <typename Dtype>
186 class MSRAFiller : public Filler<Dtype> {
187  public:
188  explicit MSRAFiller(const FillerParameter& param)
189  : Filler<Dtype>(param) {}
190  virtual void Fill(Blob<Dtype>* blob) {
191  CHECK(blob->count());
192  int fan_in = blob->count() / blob->num();
193  int fan_out = blob->count() / blob->channels();
194  Dtype n = fan_in; // default to fan_in
195  if (this->filler_param_.variance_norm() ==
196  FillerParameter_VarianceNorm_AVERAGE) {
197  n = (fan_in + fan_out) / Dtype(2);
198  } else if (this->filler_param_.variance_norm() ==
199  FillerParameter_VarianceNorm_FAN_OUT) {
200  n = fan_out;
201  }
202  Dtype std = sqrt(Dtype(2) / n);
203  caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0), std,
204  blob->mutable_cpu_data());
205  CHECK_EQ(this->filler_param_.sparse(), -1)
206  << "Sparsity not supported by this Filler.";
207  }
208 };
209 
243 template <typename Dtype>
244 class BilinearFiller : public Filler<Dtype> {
245  public:
246  explicit BilinearFiller(const FillerParameter& param)
247  : Filler<Dtype>(param) {}
248  virtual void Fill(Blob<Dtype>* blob) {
249  CHECK_EQ(blob->num_axes(), 4) << "Blob must be 4 dim.";
250  CHECK_EQ(blob->width(), blob->height()) << "Filter must be square";
251  Dtype* data = blob->mutable_cpu_data();
252  int f = ceil(blob->width() / 2.);
253  float c = (2 * f - 1 - f % 2) / (2. * f);
254  for (int i = 0; i < blob->count(); ++i) {
255  float x = i % blob->width();
256  float y = (i / blob->width()) % blob->height();
257  data[i] = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
258  }
259  CHECK_EQ(this->filler_param_.sparse(), -1)
260  << "Sparsity not supported by this Filler.";
261  }
262 };
263 
270 template <typename Dtype>
271 Filler<Dtype>* GetFiller(const FillerParameter& param) {
272  const std::string& type = param.type();
273  if (type == "constant") {
274  return new ConstantFiller<Dtype>(param);
275  } else if (type == "gaussian") {
276  return new GaussianFiller<Dtype>(param);
277  } else if (type == "positive_unitball") {
278  return new PositiveUnitballFiller<Dtype>(param);
279  } else if (type == "uniform") {
280  return new UniformFiller<Dtype>(param);
281  } else if (type == "xavier") {
282  return new XavierFiller<Dtype>(param);
283  } else if (type == "msra") {
284  return new MSRAFiller<Dtype>(param);
285  } else if (type == "bilinear") {
286  return new BilinearFiller<Dtype>(param);
287  } else {
288  CHECK(false) << "Unknown filler name: " << param.type();
289  }
290  return (Filler<Dtype>*)(NULL);
291 }
292 
293 } // namespace caffe
294 
295 #endif // CAFFE_FILLER_HPP_
int channels() const
Deprecated legacy shape accessor channels: use shape(1) instead.
Definition: blob.hpp:134
Fills a Blob with constant or randomly-generated data.
Definition: filler.hpp:19
Fills a Blob with values $x \in [0, 1]$ such that $\forall i \sum_j x_{ij} = 1$.
Definition: filler.hpp:101
A layer factory that allows one to register layers. During runtime, registered layers can be called b...
Definition: blob.hpp:14
int height() const
Deprecated legacy shape accessor height: use shape(2) instead.
Definition: blob.hpp:136
Fills a Blob with coefficients for bilinear interpolation.
Definition: filler.hpp:244
Filler< Dtype > * GetFiller(const FillerParameter &param)
Get a specific filler from the specification given in FillerParameter.
Definition: filler.hpp:271
Fills a Blob with values $x \sim U(-a, +a)$ where $a$ is set inversely proportional to the number of incoming nodes, outgoing nodes, or their average.
Definition: filler.hpp:144
Fills a Blob with uniformly distributed values $x \sim U(a, b)$.
Definition: filler.hpp:50
Fills a Blob with Gaussian-distributed values $x \sim \mathcal{N}(\mu, \sigma^2)$.
Definition: filler.hpp:65
Manages memory allocation and synchronization between the host (CPU) and device (GPU).
Definition: syncedmem.hpp:57
Fills a Blob with values $x \sim \mathcal{N}(0, \sigma^2)$ where $\sigma^2$ is set inversely proportional to the number of incoming nodes, outgoing nodes, or their average.
Definition: filler.hpp:186
int num() const
Deprecated legacy shape accessor num: use shape(0) instead.
Definition: blob.hpp:132
int width() const
Deprecated legacy shape accessor width: use shape(3) instead.
Definition: blob.hpp:138
Fills a Blob with constant values $x = c$.
Definition: filler.hpp:31
A wrapper around SyncedMemory holders serving as the basic computational unit through which Layers...
Definition: blob.hpp:24