Caffe
sgd_solvers.hpp
1 #ifndef CAFFE_SGD_SOLVERS_HPP_
2 #define CAFFE_SGD_SOLVERS_HPP_
3 
4 #include <string>
5 #include <vector>
6 
7 #include "caffe/solver.hpp"
8 
9 namespace caffe {
10 
15 template <typename Dtype>
16 class SGDSolver : public Solver<Dtype> {
17  public:
18  explicit SGDSolver(const SolverParameter& param)
19  : Solver<Dtype>(param) { PreSolve(); }
20  explicit SGDSolver(const string& param_file)
21  : Solver<Dtype>(param_file) { PreSolve(); }
22  virtual inline const char* type() const { return "SGD"; }
23 
24  const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
25 
26  protected:
27  void PreSolve();
28  Dtype GetLearningRate();
29  virtual void ApplyUpdate();
30  virtual void Normalize(int param_id);
31  virtual void Regularize(int param_id);
32  virtual void ComputeUpdateValue(int param_id, Dtype rate);
33  virtual void ClipGradients();
34  virtual void SnapshotSolverState(const string& model_filename);
35  virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
36  virtual void SnapshotSolverStateToHDF5(const string& model_filename);
37  virtual void RestoreSolverStateFromHDF5(const string& state_file);
38  virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
39  // history maintains the historical momentum data.
40  // update maintains update related data and is not needed in snapshots.
41  // temp maintains other information that might be needed in computation
42  // of gradients/updates and is not needed in snapshots
43  vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
44 
45  DISABLE_COPY_AND_ASSIGN(SGDSolver);
46 };
47 
48 template <typename Dtype>
49 class NesterovSolver : public SGDSolver<Dtype> {
50  public:
51  explicit NesterovSolver(const SolverParameter& param)
52  : SGDSolver<Dtype>(param) {}
53  explicit NesterovSolver(const string& param_file)
54  : SGDSolver<Dtype>(param_file) {}
55  virtual inline const char* type() const { return "Nesterov"; }
56 
57  protected:
58  virtual void ComputeUpdateValue(int param_id, Dtype rate);
59 
60  DISABLE_COPY_AND_ASSIGN(NesterovSolver);
61 };
62 
63 template <typename Dtype>
64 class AdaGradSolver : public SGDSolver<Dtype> {
65  public:
66  explicit AdaGradSolver(const SolverParameter& param)
67  : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
68  explicit AdaGradSolver(const string& param_file)
69  : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
70  virtual inline const char* type() const { return "AdaGrad"; }
71 
72  protected:
73  virtual void ComputeUpdateValue(int param_id, Dtype rate);
74  void constructor_sanity_check() {
75  CHECK_EQ(0, this->param_.momentum())
76  << "Momentum cannot be used with AdaGrad.";
77  }
78 
79  DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
80 };
81 
82 
83 template <typename Dtype>
84 class RMSPropSolver : public SGDSolver<Dtype> {
85  public:
86  explicit RMSPropSolver(const SolverParameter& param)
87  : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
88  explicit RMSPropSolver(const string& param_file)
89  : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
90  virtual inline const char* type() const { return "RMSProp"; }
91 
92  protected:
93  virtual void ComputeUpdateValue(int param_id, Dtype rate);
94  void constructor_sanity_check() {
95  CHECK_EQ(0, this->param_.momentum())
96  << "Momentum cannot be used with RMSProp.";
97  CHECK_GE(this->param_.rms_decay(), 0)
98  << "rms_decay should lie between 0 and 1.";
99  CHECK_LT(this->param_.rms_decay(), 1)
100  << "rms_decay should lie between 0 and 1.";
101  }
102 
103  DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
104 };
105 
106 template <typename Dtype>
107 class AdaDeltaSolver : public SGDSolver<Dtype> {
108  public:
109  explicit AdaDeltaSolver(const SolverParameter& param)
110  : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
111  explicit AdaDeltaSolver(const string& param_file)
112  : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
113  virtual inline const char* type() const { return "AdaDelta"; }
114 
115  protected:
116  void AdaDeltaPreSolve();
117  virtual void ComputeUpdateValue(int param_id, Dtype rate);
118 
119  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
120 };
121 
130 template <typename Dtype>
131 class AdamSolver : public SGDSolver<Dtype> {
132  public:
133  explicit AdamSolver(const SolverParameter& param)
134  : SGDSolver<Dtype>(param) { AdamPreSolve();}
135  explicit AdamSolver(const string& param_file)
136  : SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
137  virtual inline const char* type() const { return "Adam"; }
138 
139  protected:
140  void AdamPreSolve();
141  virtual void ComputeUpdateValue(int param_id, Dtype rate);
142 
143  DISABLE_COPY_AND_ASSIGN(AdamSolver);
144 };
145 
146 } // namespace caffe
147 
148 #endif // CAFFE_SGD_SOLVERS_HPP_
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:55
A layer factory that allows one to register layers. During runtime, registered layers can be called by passing a LayerParameter protobuffer to the CreateLayer function.
Definition: blob.hpp:14
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:22
Optimizes the parameters of a Net using stochastic gradient descent (SGD) with momentum.
Definition: sgd_solvers.hpp:16
An interface for classes that perform optimization on Nets.
Definition: solver.hpp:42
AdamSolver, an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order moments.
Definition: sgd_solvers.hpp:131
Definition: sgd_solvers.hpp:107
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:113
Definition: sgd_solvers.hpp:49
Definition: sgd_solvers.hpp:84
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:90
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:70
Definition: sgd_solvers.hpp:64
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:137