1 #ifndef CAFFE_SGD_SOLVERS_HPP_ 2 #define CAFFE_SGD_SOLVERS_HPP_ 7 #include "caffe/solver.hpp" 15 template <
typename Dtype>
18 explicit SGDSolver(
const SolverParameter& param)
20 explicit SGDSolver(
const string& param_file)
22 virtual inline const char*
type()
const {
return "SGD"; }
24 const vector<shared_ptr<Blob<Dtype> > >& history() {
return history_; }
28 Dtype GetLearningRate();
29 virtual void ApplyUpdate();
30 virtual void Normalize(
int param_id);
31 virtual void Regularize(
int param_id);
32 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
33 virtual void ClipGradients();
34 virtual void SnapshotSolverState(
const string& model_filename);
35 virtual void SnapshotSolverStateToBinaryProto(
const string& model_filename);
36 virtual void SnapshotSolverStateToHDF5(
const string& model_filename);
37 virtual void RestoreSolverStateFromHDF5(
const string& state_file);
38 virtual void RestoreSolverStateFromBinaryProto(
const string& state_file);
43 vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
48 template <
typename Dtype>
55 virtual inline const char*
type()
const {
return "Nesterov"; }
58 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
63 template <
typename Dtype>
70 virtual inline const char*
type()
const {
return "AdaGrad"; }
73 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
74 void constructor_sanity_check() {
75 CHECK_EQ(0, this->param_.momentum())
76 <<
"Momentum cannot be used with AdaGrad.";
83 template <
typename Dtype>
90 virtual inline const char*
type()
const {
return "RMSProp"; }
93 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
94 void constructor_sanity_check() {
95 CHECK_EQ(0, this->param_.momentum())
96 <<
"Momentum cannot be used with RMSProp.";
97 CHECK_GE(this->param_.rms_decay(), 0)
98 <<
"rms_decay should lie between 0 and 1.";
99 CHECK_LT(this->param_.rms_decay(), 1)
100 <<
"rms_decay should lie between 0 and 1.";
106 template <
typename Dtype>
113 virtual inline const char*
type()
const {
return "AdaDelta"; }
116 void AdaDeltaPreSolve();
117 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
130 template <
typename Dtype>
133 explicit AdamSolver(
const SolverParameter& param)
137 virtual inline const char*
type()
const {
return "Adam"; }
141 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
148 #endif // CAFFE_SGD_SOLVERS_HPP_
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:55
A layer factory that allows one to register layers. During runtime, registered layers can be called by passing a LayerParameter protobuffer to the CreateLayer function.
Definition: blob.hpp:14
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:22
Optimizes the parameters of a Net using stochastic gradient descent (SGD) with momentum.
Definition: sgd_solvers.hpp:16
An interface for classes that perform optimization on Nets.
Definition: solver.hpp:42
AdamSolver, an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order moments.
Definition: sgd_solvers.hpp:131
Definition: sgd_solvers.hpp:107
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:113
Definition: sgd_solvers.hpp:49
Definition: sgd_solvers.hpp:84
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:90
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:70
Definition: sgd_solvers.hpp:64
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:137