1 #ifndef STAN_SERVICES_OPTIMIZE_BFGS_HPP
2 #define STAN_SERVICES_OPTIMIZE_BFGS_HPP
3 
4 #include <stan/callbacks/interrupt.hpp>
5 #include <stan/callbacks/logger.hpp>
6 #include <stan/callbacks/writer.hpp>
7 #include <stan/io/var_context.hpp>
8 #include <stan/services/error_codes.hpp>
9 #include <stan/optimization/bfgs.hpp>
10 #include <stan/services/util/initialize.hpp>
11 #include <stan/services/util/create_rng.hpp>
12 #include <fstream>
13 #include <iostream>
14 #include <iomanip>
15 #include <string>
16 #include <vector>
17 
18 namespace stan {
19 namespace services {
20 namespace optimize {
21 
22 /**
23  * Runs the BFGS algorithm for a model.
24  *
25  * @tparam Model A model implementation
26  * @param[in] model Input model to test (with data already instantiated)
27  * @param[in] init var context for initialization
28  * @param[in] random_seed random seed for the random number generator
29  * @param[in] chain chain id to advance the pseudo random number generator
30  * @param[in] init_radius radius to initialize
31  * @param[in] init_alpha line search step size for first iteration
32  * @param[in] tol_obj convergence tolerance on absolute changes in
33  *   objective function value
34  * @param[in] tol_rel_obj convergence tolerance on relative changes
35  *   in objective function value
36  * @param[in] tol_grad convergence tolerance on the norm of the gradient
37  * @param[in] tol_rel_grad convergence tolerance on the relative norm of
38  *   the gradient
39  * @param[in] tol_param convergence tolerance on changes in parameter
40  *   value
41  * @param[in] num_iterations maximum number of iterations
42  * @param[in] save_iterations indicates whether all the iterations should
43  *   be saved to the parameter_writer
44  * @param[in] refresh how often to write output to logger
45  * @param[in,out] interrupt callback to be called every iteration
46  * @param[in,out] logger Logger for messages
47  * @param[in,out] init_writer Writer callback for unconstrained inits
48  * @param[in,out] parameter_writer output for parameter values
49  * @return error_codes::OK if successful
50  */
51 template <class Model>
bfgs(Model & model,const stan::io::var_context & init,unsigned int random_seed,unsigned int chain,double init_radius,double init_alpha,double tol_obj,double tol_rel_obj,double tol_grad,double tol_rel_grad,double tol_param,int num_iterations,bool save_iterations,int refresh,callbacks::interrupt & interrupt,callbacks::logger & logger,callbacks::writer & init_writer,callbacks::writer & parameter_writer)52 int bfgs(Model& model, const stan::io::var_context& init,
53          unsigned int random_seed, unsigned int chain, double init_radius,
54          double init_alpha, double tol_obj, double tol_rel_obj, double tol_grad,
55          double tol_rel_grad, double tol_param, int num_iterations,
56          bool save_iterations, int refresh, callbacks::interrupt& interrupt,
57          callbacks::logger& logger, callbacks::writer& init_writer,
58          callbacks::writer& parameter_writer) {
59   boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
60 
61   std::vector<int> disc_vector;
62   std::vector<double> cont_vector = util::initialize<false>(
63       model, init, rng, init_radius, false, logger, init_writer);
64 
65   std::stringstream bfgs_ss;
66   typedef stan::optimization::BFGSLineSearch<
67       Model, stan::optimization::BFGSUpdate_HInv<> >
68       Optimizer;
69   Optimizer bfgs(model, cont_vector, disc_vector, &bfgs_ss);
70   bfgs._ls_opts.alpha0 = init_alpha;
71   bfgs._conv_opts.tolAbsF = tol_obj;
72   bfgs._conv_opts.tolRelF = tol_rel_obj;
73   bfgs._conv_opts.tolAbsGrad = tol_grad;
74   bfgs._conv_opts.tolRelGrad = tol_rel_grad;
75   bfgs._conv_opts.tolAbsX = tol_param;
76   bfgs._conv_opts.maxIts = num_iterations;
77 
78   double lp = bfgs.logp();
79 
80   std::stringstream initial_msg;
81   initial_msg << "Initial log joint probability = " << lp;
82   logger.info(initial_msg);
83 
84   std::vector<std::string> names;
85   names.push_back("lp__");
86   model.constrained_param_names(names, true, true);
87   parameter_writer(names);
88 
89   if (save_iterations) {
90     std::vector<double> values;
91     std::stringstream msg;
92     model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
93     if (msg.str().length() > 0)
94       logger.info(msg);
95 
96     values.insert(values.begin(), lp);
97     parameter_writer(values);
98   }
99   int ret = 0;
100 
101   while (ret == 0) {
102     interrupt();
103     if (refresh > 0
104         && (bfgs.iter_num() == 0 || ((bfgs.iter_num() + 1) % refresh == 0)))
105       logger.info(
106           "    Iter"
107           "      log prob"
108           "        ||dx||"
109           "      ||grad||"
110           "       alpha"
111           "      alpha0"
112           "  # evals"
113           "  Notes ");
114 
115     ret = bfgs.step();
116     lp = bfgs.logp();
117     bfgs.params_r(cont_vector);
118 
119     if (refresh > 0
120         && (ret != 0 || !bfgs.note().empty() || bfgs.iter_num() == 0
121             || ((bfgs.iter_num() + 1) % refresh == 0))) {
122       std::stringstream msg;
123       msg << " " << std::setw(7) << bfgs.iter_num() << " ";
124       msg << " " << std::setw(12) << std::setprecision(6) << lp << " ";
125       msg << " " << std::setw(12) << std::setprecision(6)
126           << bfgs.prev_step_size() << " ";
127       msg << " " << std::setw(12) << std::setprecision(6)
128           << bfgs.curr_g().norm() << " ";
129       msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha()
130           << " ";
131       msg << " " << std::setw(10) << std::setprecision(4) << bfgs.alpha0()
132           << " ";
133       msg << " " << std::setw(7) << bfgs.grad_evals() << " ";
134       msg << " " << bfgs.note() << " ";
135       logger.info(msg);
136     }
137 
138     if (bfgs_ss.str().length() > 0) {
139       logger.info(bfgs_ss);
140       bfgs_ss.str("");
141     }
142 
143     if (save_iterations) {
144       std::vector<double> values;
145       std::stringstream msg;
146       model.write_array(rng, cont_vector, disc_vector, values, true, true,
147                         &msg);
148       // This if is here to match the pre-refactor behavior
149       if (msg.str().length() > 0)
150         logger.info(msg);
151 
152       values.insert(values.begin(), lp);
153       parameter_writer(values);
154     }
155   }
156 
157   if (!save_iterations) {
158     std::vector<double> values;
159     std::stringstream msg;
160     model.write_array(rng, cont_vector, disc_vector, values, true, true, &msg);
161     if (msg.str().length() > 0)
162       logger.info(msg);
163     values.insert(values.begin(), lp);
164     parameter_writer(values);
165   }
166 
167   int return_code;
168   if (ret >= 0) {
169     logger.info("Optimization terminated normally: ");
170     return_code = error_codes::OK;
171   } else {
172     logger.info("Optimization terminated with error: ");
173     return_code = error_codes::SOFTWARE;
174   }
175   logger.info("  " + bfgs.get_code_string(ret));
176 
177   return return_code;
178 }
179 
180 }  // namespace optimize
181 }  // namespace services
182 }  // namespace stan
183 #endif
184