Caffe
solver_factory.hpp
1 
38 #ifndef CAFFE_SOLVER_FACTORY_H_
39 #define CAFFE_SOLVER_FACTORY_H_
40 
41 #include <map>
42 #include <string>
43 #include <vector>
44 
45 #include "caffe/common.hpp"
46 #include "caffe/proto/caffe.pb.h"
47 
48 namespace caffe {
49 
50 template <typename Dtype>
51 class Solver;
52 
53 template <typename Dtype>
55  public:
56  typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
57  typedef std::map<string, Creator> CreatorRegistry;
58 
59  static CreatorRegistry& Registry() {
60  static CreatorRegistry* g_registry_ = new CreatorRegistry();
61  return *g_registry_;
62  }
63 
64  // Adds a creator.
65  static void AddCreator(const string& type, Creator creator) {
66  CreatorRegistry& registry = Registry();
67  CHECK_EQ(registry.count(type), 0)
68  << "Solver type " << type << " already registered.";
69  registry[type] = creator;
70  }
71 
72  // Get a solver using a SolverParameter.
73  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
74  const string& type = param.type();
75  CreatorRegistry& registry = Registry();
76  CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
77  << " (known types: " << SolverTypeListString() << ")";
78  return registry[type](param);
79  }
80 
81  static vector<string> SolverTypeList() {
82  CreatorRegistry& registry = Registry();
83  vector<string> solver_types;
84  for (typename CreatorRegistry::iterator iter = registry.begin();
85  iter != registry.end(); ++iter) {
86  solver_types.push_back(iter->first);
87  }
88  return solver_types;
89  }
90 
91  private:
92  // Solver registry should never be instantiated - everything is done with its
93  // static variables.
94  SolverRegistry() {}
95 
96  static string SolverTypeListString() {
97  vector<string> solver_types = SolverTypeList();
98  string solver_types_str;
99  for (vector<string>::iterator iter = solver_types.begin();
100  iter != solver_types.end(); ++iter) {
101  if (iter != solver_types.begin()) {
102  solver_types_str += ", ";
103  }
104  solver_types_str += *iter;
105  }
106  return solver_types_str;
107  }
108 };
109 
110 
111 template <typename Dtype>
113  public:
114  SolverRegisterer(const string& type,
115  Solver<Dtype>* (*creator)(const SolverParameter&)) {
116  // LOG(INFO) << "Registering solver type: " << type;
117  SolverRegistry<Dtype>::AddCreator(type, creator);
118  }
119 };
120 
121 
// Registers the creator function template `creator` for both float and double
// under the name "type", by defining two namespace-scope SolverRegisterer
// statics (g_creator_f_<type>, g_creator_d_<type>).
// The expansion omits the final semicolon so the call site supplies it:
//   REGISTER_SOLVER_CREATOR(Foo, CreatorFoo);
// The last line intentionally has no trailing '\': a dangling continuation
// would silently swallow whatever line follows the macro.
#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)
125 
// Defines a creator function template Creator_<type>Solver<Dtype> that returns
// `new <type>Solver<Dtype>(param)`, then registers it for float and double
// under the name "<type>" via REGISTER_SOLVER_CREATOR.
// Use at namespace scope with a trailing semicolon, e.g. for a class
// FooSolver<Dtype>:
//   REGISTER_SOLVER_CLASS(Foo);
#define REGISTER_SOLVER_CLASS(type)                                            \
  template <typename Dtype>                                                    \
  Solver<Dtype>* Creator_##type##Solver(const SolverParameter& param) {        \
    return new type##Solver<Dtype>(param);                                     \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
134 
135 } // namespace caffe
136 
137 #endif // CAFFE_SOLVER_FACTORY_H_
A layer factory that allows one to register layers. During runtime, registered layers can be called b...
Definition: blob.hpp:14
An interface for classes that perform optimization on Nets.
Definition: solver.hpp:42
Definition: solver_factory.hpp:54
Definition: solver_factory.hpp:112