1 #ifndef STAN_SERVICES_OPTIMIZE_BFGS_HPP
2 #define STAN_SERVICES_OPTIMIZE_BFGS_HPP
3
4 #include <stan/callbacks/interrupt.hpp>
5 #include <stan/callbacks/logger.hpp>
6 #include <stan/callbacks/writer.hpp>
7 #include <stan/io/var_context.hpp>
8 #include <stan/services/error_codes.hpp>
9 #include <stan/optimization/bfgs.hpp>
10 #include <stan/services/util/initialize.hpp>
11 #include <stan/services/util/create_rng.hpp>
12 #include <fstream>
13 #include <iostream>
14 #include <iomanip>
15 #include <string>
16 #include <vector>
17
18 namespace stan {
19 namespace services {
20 namespace optimize {
21
22 /**
23 * Runs the BFGS algorithm for a model.
24 *
25 * @tparam Model A model implementation
26 * @param[in] model Input model to test (with data already instantiated)
27 * @param[in] init var context for initialization
28 * @param[in] random_seed random seed for the random number generator
29 * @param[in] chain chain id to advance the pseudo random number generator
30 * @param[in] init_radius radius to initialize
31 * @param[in] init_alpha line search step size for first iteration
32 * @param[in] tol_obj convergence tolerance on absolute changes in
33 * objective function value
34 * @param[in] tol_rel_obj convergence tolerance on relative changes
35 * in objective function value
36 * @param[in] tol_grad convergence tolerance on the norm of the gradient
37 * @param[in] tol_rel_grad convergence tolerance on the relative norm of
38 * the gradient
39 * @param[in] tol_param convergence tolerance on changes in parameter
40 * value
41 * @param[in] num_iterations maximum number of iterations
42 * @param[in] save_iterations indicates whether all the iterations should
43 * be saved to the parameter_writer
44 * @param[in] refresh how often to write output to logger
45 * @param[in,out] interrupt callback to be called every iteration
46 * @param[in,out] logger Logger for messages
47 * @param[in,out] init_writer Writer callback for unconstrained inits
48 * @param[in,out] parameter_writer output for parameter values
49 * @return error_codes::OK if successful
50 */
51 template <class Model>
bfgs(Model & model,const stan::io::var_context & init,unsigned int random_seed,unsigned int chain,double init_radius,double init_alpha,double tol_obj,double tol_rel_obj,double tol_grad,double tol_rel_grad,double tol_param,int num_iterations,bool save_iterations,int refresh,callbacks::interrupt & interrupt,callbacks::logger & logger,callbacks::writer & init_writer,callbacks::writer & parameter_writer)52 int bfgs(Model& model, const stan::io::var_context& init,
53 unsigned int random_seed, unsigned int chain, double init_radius,
54 double init_alpha, double tol_obj, double tol_rel_obj, double tol_grad,
55 double tol_rel_grad, double tol_param, int num_iterations,
56 bool save_iterations, int refresh, callbacks::interrupt& interrupt,
57 callbacks::logger& logger, callbacks::writer& init_writer,
58 callbacks::writer& parameter_writer) {
59 boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
60
61 std::vector<int> disc_vector;
62 std::vector<double> cont_vector = util::initialize<false>(
63 model, init, rng, init_radius, false, logger, init_writer);
64
65 std::stringstream bfgs_ss;
66 typedef stan::optimization::BFGSLineSearch<
67 Model, stan::optimization::BFGSUpdate_HInv<> >
68 Optimizer;
69 Optimizer bfgs(model, cont_vector, disc_vector, &bfgs_ss);
70 bfgs._ls_opts.alpha0 = init_alpha;
71 bfgs._conv_opts.tolAbsF = tol_obj;
72 bfgs._conv_opts.tolRelF = tol_rel_obj;
73 bfgs._conv_opts.tolAbsGrad = tol_grad;
74 bfgs._conv_opts.tolRelGrad = tol_rel_grad;
75 bfgs._conv_opts.tolAbsX = tol_param;
76 bfgs._conv_opts.maxIts = num_iterations;
77
78 double lp = bfgs.logp();
79
80 std::stringstream initial_msg;
81 initial_msg << "Initial log joint probability = " << lp;
82 logger.info(initial_msg);
83
84 std::vector<std::string> names;
85 names.push_back("lp__");
86 model.constrained_param_names(names, true, true);
87 parameter_writer(names);
88
89 if (save_iterations) {
90 std::vector<double> values;
91 std::stringstream msg;
92 model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
93 if (msg.str().length() > 0)
94 logger.info(msg);
95
96 values.insert(values.begin(), lp);
97 parameter_writer(values);
98 }
99 int ret = 0;
100
101 while (ret == 0) {
102 interrupt();
103 if (refresh > 0
104 && (bfgs.iter_num() == 0 || ((bfgs.iter_num() + 1) % refresh == 0)))
105 logger.info(
106 " Iter"
107 " log prob"
108 " ||dx||"
109 " ||grad||"
110 " alpha"
111 " alpha0"
112 " # evals"
113 " Notes ");
114
115 ret = bfgs.step();
116 lp = bfgs.logp();
117 bfgs.params_r(cont_vector);
118
119 if (refresh > 0
120 && (ret != 0 || !bfgs.note().empty() || bfgs.iter_num() == 0
121 || ((bfgs.iter_num() + 1) % refresh == 0))) {
122 std::stringstream msg;
123 msg << " " << std::setw(7) << bfgs.iter_num() << " ";
124 msg << " " << std::setw(12) << std::setprecision(6) << lp << " ";
125 msg << " " << std::setw(12) << std::setprecision(6)
126 << bfgs.prev_step_size() << " ";
127 msg << " " << std::setw(12) << std::setprecision(6)
128 << bfgs.curr_g().norm() << " ";
129 msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha()
130 << " ";
131 msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha0()
132 << " ";
133 msg << " " << std::setw(7) << bfgs.grad_evals() << " ";
134 msg << " " << bfgs.note() << " ";
135 logger.info(msg);
136 }
137
138 if (bfgs_ss.str().length() > 0) {
139 logger.info(bfgs_ss);
140 bfgs_ss.str("");
141 }
142
143 if (save_iterations) {
144 std::vector<double> values;
145 std::stringstream msg;
146 model.write_array(rng, cont_vector, disc_vector, values, true, true,
147 &msg);
148 // This if is here to match the pre-refactor behavior
149 if (msg.str().length() > 0)
150 logger.info(msg);
151
152 values.insert(values.begin(), lp);
153 parameter_writer(values);
154 }
155 }
156
157 if (!save_iterations) {
158 std::vector<double> values;
159 std::stringstream msg;
160 model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
161 if (msg.str().length() > 0)
162 logger.info(msg);
163 values.insert(values.begin(), lp);
164 parameter_writer(values);
165 }
166
167 int return_code;
168 if (ret >= 0) {
169 logger.info("Optimization terminated normally: ");
170 return_code = error_codes::OK;
171 } else {
172 logger.info("Optimization terminated with error: ");
173 return_code = error_codes::SOFTWARE;
174 }
175 logger.info(" " + bfgs.get_code_string(ret));
176
177 return return_code;
178 }
179
180 } // namespace optimize
181 } // namespace services
182 } // namespace stan
183 #endif
184