1 #ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_HPP
2 #define STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_HPP
3 
#include <stan/math/prim.hpp>
#include <stan/callbacks/interrupt.hpp>
#include <stan/callbacks/logger.hpp>
#include <stan/callbacks/writer.hpp>
#include <stan/io/var_context.hpp>
#include <stan/mcmc/fixed_param_sampler.hpp>
#include <stan/services/error_codes.hpp>
#include <stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp>
#include <stan/services/util/run_adaptive_sampler.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/inv_metric.hpp>
// NOTE(review): tbb::parallel_for / tbb::blocked_range are used below but no
// TBB header is included here directly; they presumably arrive transitively
// via stan/math/prim.hpp. Consider including <tbb/parallel_for.h> and
// <tbb/blocked_range.h> explicitly.
#include <cmath>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
17 
18 namespace stan {
19 namespace services {
20 namespace sample {
21 
22 /**
23  * Runs HMC with NUTS with adaptation using diagonal Euclidean metric
24  * with a pre-specified Euclidean metric.
25  *
26  * @tparam Model Model class
27  * @tparam InitContextPtr A type derived from `stan::io::var_context`
28  * @tparam InitMetricContext A type derived from `stan::io::var_context`
29  * @tparam SamplerWriter A type derived from `stan::callbacks::writer`
30  * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
31  * @tparam InitWriter A type derived from `stan::callbacks::writer`
32  * @param[in] model Input model to test (with data already instantiated)
33  * @param[in] init var context for initialization
34  * @param[in] init_inv_metric var context exposing an initial diagonal
35               inverse Euclidean metric (must be positive definite)
36  * @param[in] random_seed random seed for the random number generator
37  * @param[in] chain chain id to advance the pseudo random number generator
38  * @param[in] init_radius radius to initialize
39  * @param[in] num_warmup Number of warmup samples
40  * @param[in] num_samples Number of samples
41  * @param[in] num_thin Number to thin the samples
42  * @param[in] save_warmup Indicates whether to save the warmup iterations
43  * @param[in] refresh Controls the output
44  * @param[in] stepsize initial stepsize for discrete evolution
45  * @param[in] stepsize_jitter uniform random jitter of stepsize
46  * @param[in] max_depth Maximum tree depth
47  * @param[in] delta adaptation target acceptance statistic
48  * @param[in] gamma adaptation regularization scale
49  * @param[in] kappa adaptation relaxation exponent
50  * @param[in] t0 adaptation iteration offset
51  * @param[in] init_buffer width of initial fast adaptation interval
52  * @param[in] term_buffer width of final fast adaptation interval
53  * @param[in] window initial width of slow adaptation interval
54  * @param[in,out] interrupt Callback for interrupts
55  * @param[in,out] logger Logger for messages
56  * @param[in,out] init_writer Writer callback for unconstrained inits
57  * @param[in,out] sample_writer Writer for draws
58  * @param[in,out] diagnostic_writer Writer for diagnostic information
59  * @return error_codes::OK if successful
60  */
61 template <typename Model>
hmc_nuts_diag_e_adapt(Model & model,const stan::io::var_context & init,const stan::io::var_context & init_inv_metric,unsigned int random_seed,unsigned int chain,double init_radius,int num_warmup,int num_samples,int num_thin,bool save_warmup,int refresh,double stepsize,double stepsize_jitter,int max_depth,double delta,double gamma,double kappa,double t0,unsigned int init_buffer,unsigned int term_buffer,unsigned int window,callbacks::interrupt & interrupt,callbacks::logger & logger,callbacks::writer & init_writer,callbacks::writer & sample_writer,callbacks::writer & diagnostic_writer)62 int hmc_nuts_diag_e_adapt(
63     Model& model, const stan::io::var_context& init,
64     const stan::io::var_context& init_inv_metric, unsigned int random_seed,
65     unsigned int chain, double init_radius, int num_warmup, int num_samples,
66     int num_thin, bool save_warmup, int refresh, double stepsize,
67     double stepsize_jitter, int max_depth, double delta, double gamma,
68     double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer,
69     unsigned int window, callbacks::interrupt& interrupt,
70     callbacks::logger& logger, callbacks::writer& init_writer,
71     callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) {
72   boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
73 
74   std::vector<double> cont_vector = util::initialize(
75       model, init, rng, init_radius, true, logger, init_writer);
76 
77   Eigen::VectorXd inv_metric;
78   try {
79     inv_metric = util::read_diag_inv_metric(init_inv_metric,
80                                             model.num_params_r(), logger);
81     util::validate_diag_inv_metric(inv_metric, logger);
82   } catch (const std::domain_error& e) {
83     return error_codes::CONFIG;
84   }
85 
86   stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);
87 
88   sampler.set_metric(inv_metric);
89   sampler.set_nominal_stepsize(stepsize);
90   sampler.set_stepsize_jitter(stepsize_jitter);
91   sampler.set_max_depth(max_depth);
92 
93   sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize));
94   sampler.get_stepsize_adaptation().set_delta(delta);
95   sampler.get_stepsize_adaptation().set_gamma(gamma);
96   sampler.get_stepsize_adaptation().set_kappa(kappa);
97   sampler.get_stepsize_adaptation().set_t0(t0);
98 
99   sampler.set_window_params(num_warmup, init_buffer, term_buffer, window,
100                             logger);
101 
102   util::run_adaptive_sampler(
103       sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh,
104       save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer);
105 
106   return error_codes::OK;
107 }
108 
109 /**
110  * Runs HMC with NUTS with adaptation using diagonal Euclidean metric.
111  *
112  * @tparam Model Model class
113  * @param[in] model Input model to test (with data already instantiated)
114  * @param[in] init var context for initialization
115  * @param[in] random_seed random seed for the random number generator
116  * @param[in] chain chain id to advance the pseudo random number generator
117  * @param[in] init_radius radius to initialize
118  * @param[in] num_warmup Number of warmup samples
119  * @param[in] num_samples Number of samples
120  * @param[in] num_thin Number to thin the samples
121  * @param[in] save_warmup Indicates whether to save the warmup iterations
122  * @param[in] refresh Controls the output
123  * @param[in] stepsize initial stepsize for discrete evolution
124  * @param[in] stepsize_jitter uniform random jitter of stepsize
125  * @param[in] max_depth Maximum tree depth
126  * @param[in] delta adaptation target acceptance statistic
127  * @param[in] gamma adaptation regularization scale
128  * @param[in] kappa adaptation relaxation exponent
129  * @param[in] t0 adaptation iteration offset
130  * @param[in] init_buffer width of initial fast adaptation interval
131  * @param[in] term_buffer width of final fast adaptation interval
132  * @param[in] window initial width of slow adaptation interval
133  * @param[in,out] interrupt Callback for interrupts
134  * @param[in,out] logger Logger for messages
135  * @param[in,out] init_writer Writer callback for unconstrained inits
136  * @param[in,out] sample_writer Writer for draws
137  * @param[in,out] diagnostic_writer Writer for diagnostic information
138  * @return error_codes::OK if successful
139  */
140 template <typename Model>
hmc_nuts_diag_e_adapt(Model & model,const stan::io::var_context & init,unsigned int random_seed,unsigned int chain,double init_radius,int num_warmup,int num_samples,int num_thin,bool save_warmup,int refresh,double stepsize,double stepsize_jitter,int max_depth,double delta,double gamma,double kappa,double t0,unsigned int init_buffer,unsigned int term_buffer,unsigned int window,callbacks::interrupt & interrupt,callbacks::logger & logger,callbacks::writer & init_writer,callbacks::writer & sample_writer,callbacks::writer & diagnostic_writer)141 int hmc_nuts_diag_e_adapt(
142     Model& model, const stan::io::var_context& init, unsigned int random_seed,
143     unsigned int chain, double init_radius, int num_warmup, int num_samples,
144     int num_thin, bool save_warmup, int refresh, double stepsize,
145     double stepsize_jitter, int max_depth, double delta, double gamma,
146     double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer,
147     unsigned int window, callbacks::interrupt& interrupt,
148     callbacks::logger& logger, callbacks::writer& init_writer,
149     callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) {
150   stan::io::dump unit_e_metric
151       = util::create_unit_e_diag_inv_metric(model.num_params_r());
152   return hmc_nuts_diag_e_adapt(
153       model, init, unit_e_metric, random_seed, chain, init_radius, num_warmup,
154       num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
155       max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window,
156       interrupt, logger, init_writer, sample_writer, diagnostic_writer);
157 }
158 
159 /**
160  * Runs multiple chains of HMC with NUTS with adaptation using diagonal
161  * Euclidean metric with a pre-specified Euclidean metric.
162  *
163  * @tparam Model Model class
164  * @tparam InitContextPtr A pointer with underlying type derived from
165  `stan::io::var_context`
166  * @tparam InitInvContextPtr A pointer with underlying type derived from
167  `stan::io::var_context`
168  * @tparam SamplerWriter A type derived from `stan::callbacks::writer`
169  * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
170  * @tparam InitWriter A type derived from `stan::callbacks::writer`
171  * @param[in] model Input model to test (with data already instantiated)
172  * @param[in] num_chains The number of chains to run in parallel. `init`,
173  * `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer`
174  must
175  * be the same length as this value.
176  * @param[in] init An std vector of init var contexts for initialization of each
177  * chain.
178  * @param[in] init_inv_metric An std vector of var contexts exposing an initial
179  diagonal inverse Euclidean metric for each chain (must be positive definite)
180  * @param[in] random_seed random seed for the random number generator
181  * @param[in] init_chain_id first chain id. The pseudo random number generator
182  * will advance for each chain by an integer sequence from `init_chain_id` to
183  * `init_chain_id + num_chains - 1`
184  * @param[in] init_radius radius to initialize
185  * @param[in] num_warmup Number of warmup samples
186  * @param[in] num_samples Number of samples
187  * @param[in] num_thin Number to thin the samples
188  * @param[in] save_warmup Indicates whether to save the warmup iterations
189  * @param[in] refresh Controls the output
190  * @param[in] stepsize initial stepsize for discrete evolution
191  * @param[in] stepsize_jitter uniform random jitter of stepsize
192  * @param[in] max_depth Maximum tree depth
193  * @param[in] delta adaptation target acceptance statistic
194  * @param[in] gamma adaptation regularization scale
195  * @param[in] kappa adaptation relaxation exponent
196  * @param[in] t0 adaptation iteration offset
197  * @param[in] init_buffer width of initial fast adaptation interval
198  * @param[in] term_buffer width of final fast adaptation interval
199  * @param[in] window initial width of slow adaptation interval
200  * @param[in,out] interrupt Callback for interrupts
201  * @param[in,out] logger Logger for messages
202  * @param[in,out] init_writer std vector of Writer callbacks for unconstrained
203  * inits of each chain.
204  * @param[in,out] sample_writer std vector of Writers for draws of each chain.
205  * @param[in,out] diagnostic_writer std vector of Writers for diagnostic
206  * information of each chain.
207  * @return error_codes::OK if successful
208  */
209 template <class Model, typename InitContextPtr, typename InitInvContextPtr,
210           typename InitWriter, typename SampleWriter, typename DiagnosticWriter>
hmc_nuts_diag_e_adapt(Model & model,size_t num_chains,const std::vector<InitContextPtr> & init,const std::vector<InitInvContextPtr> & init_inv_metric,unsigned int random_seed,unsigned int init_chain_id,double init_radius,int num_warmup,int num_samples,int num_thin,bool save_warmup,int refresh,double stepsize,double stepsize_jitter,int max_depth,double delta,double gamma,double kappa,double t0,unsigned int init_buffer,unsigned int term_buffer,unsigned int window,callbacks::interrupt & interrupt,callbacks::logger & logger,std::vector<InitWriter> & init_writer,std::vector<SampleWriter> & sample_writer,std::vector<DiagnosticWriter> & diagnostic_writer)211 int hmc_nuts_diag_e_adapt(
212     Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
213     const std::vector<InitInvContextPtr>& init_inv_metric,
214     unsigned int random_seed, unsigned int init_chain_id, double init_radius,
215     int num_warmup, int num_samples, int num_thin, bool save_warmup,
216     int refresh, double stepsize, double stepsize_jitter, int max_depth,
217     double delta, double gamma, double kappa, double t0,
218     unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
219     callbacks::interrupt& interrupt, callbacks::logger& logger,
220     std::vector<InitWriter>& init_writer,
221     std::vector<SampleWriter>& sample_writer,
222     std::vector<DiagnosticWriter>& diagnostic_writer) {
223   if (num_chains == 1) {
224     return hmc_nuts_diag_e_adapt(
225         model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
226         init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
227         stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
228         init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
229         sample_writer[0], diagnostic_writer[0]);
230   }
231   using sample_t = stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988>;
232   std::vector<boost::ecuyer1988> rngs;
233   rngs.reserve(num_chains);
234   std::vector<std::vector<double>> cont_vectors;
235   cont_vectors.reserve(num_chains);
236   std::vector<sample_t> samplers;
237   samplers.reserve(num_chains);
238   try {
239     for (int i = 0; i < num_chains; ++i) {
240       rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i));
241       cont_vectors.emplace_back(util::initialize(
242           model, *init[i], rngs[i], init_radius, true, logger, init_writer[i]));
243       samplers.emplace_back(model, rngs[i]);
244       Eigen::VectorXd inv_metric = util::read_diag_inv_metric(
245           *init_inv_metric[i], model.num_params_r(), logger);
246       util::validate_diag_inv_metric(inv_metric, logger);
247 
248       samplers[i].set_metric(inv_metric);
249       samplers[i].set_nominal_stepsize(stepsize);
250       samplers[i].set_stepsize_jitter(stepsize_jitter);
251       samplers[i].set_max_depth(max_depth);
252 
253       samplers[i].get_stepsize_adaptation().set_mu(log(10 * stepsize));
254       samplers[i].get_stepsize_adaptation().set_delta(delta);
255       samplers[i].get_stepsize_adaptation().set_gamma(gamma);
256       samplers[i].get_stepsize_adaptation().set_kappa(kappa);
257       samplers[i].get_stepsize_adaptation().set_t0(t0);
258       samplers[i].set_window_params(num_warmup, init_buffer, term_buffer,
259                                     window, logger);
260     }
261   } catch (const std::domain_error& e) {
262     return error_codes::CONFIG;
263   }
264   tbb::parallel_for(tbb::blocked_range<size_t>(0, num_chains, 1),
265                     [num_warmup, num_samples, num_thin, refresh, save_warmup,
266                      num_chains, init_chain_id, &samplers, &model, &rngs,
267                      &interrupt, &logger, &sample_writer, &cont_vectors,
268                      &diagnostic_writer](const tbb::blocked_range<size_t>& r) {
269                       for (size_t i = r.begin(); i != r.end(); ++i) {
270                         util::run_adaptive_sampler(
271                             samplers[i], model, cont_vectors[i], num_warmup,
272                             num_samples, num_thin, refresh, save_warmup,
273                             rngs[i], interrupt, logger, sample_writer[i],
274                             diagnostic_writer[i], init_chain_id + i,
275                             num_chains);
276                       }
277                     },
278                     tbb::simple_partitioner());
279   return error_codes::OK;
280 }
281 
282 /**
283  * Runs multiple chains of HMC with NUTS with adaptation using diagonal
284  * Euclidean metric.
285  *
286  * @tparam Model Model class
287  * @tparam InitContextPtr A pointer with underlying type derived from
288  * `stan::io::var_context`
289  * @tparam SamplerWriter A type derived from `stan::callbacks::writer`
290  * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
291  * @tparam InitWriter A type derived from `stan::callbacks::writer`
292  * @param[in] model Input model to test (with data already instantiated)
293  * @param[in] num_chains The number of chains to run in parallel. `init`,
294  * `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
295  * length as this value.
296  * @param[in] init An std vector of init var contexts for initialization of each
297  * chain.
298  * @param[in] random_seed random seed for the random number generator
299  * @param[in] init_chain_id first chain id. The pseudo random number generator
300  * will advance by for each chain by an integer sequence from `init_chain_id` to
301  * `init_chain_id+num_chains-1`
302  * @param[in] init_radius radius to initialize
303  * @param[in] num_warmup Number of warmup samples
304  * @param[in] num_samples Number of samples
305  * @param[in] num_thin Number to thin the samples
306  * @param[in] save_warmup Indicates whether to save the warmup iterations
307  * @param[in] refresh Controls the output
308  * @param[in] stepsize initial stepsize for discrete evolution
309  * @param[in] stepsize_jitter uniform random jitter of stepsize
310  * @param[in] max_depth Maximum tree depth
311  * @param[in] delta adaptation target acceptance statistic
312  * @param[in] gamma adaptation regularization scale
313  * @param[in] kappa adaptation relaxation exponent
314  * @param[in] t0 adaptation iteration offset
315  * @param[in] init_buffer width of initial fast adaptation interval
316  * @param[in] term_buffer width of final fast adaptation interval
317  * @param[in] window initial width of slow adaptation interval
318  * @param[in,out] interrupt Callback for interrupts
319  * @param[in,out] logger Logger for messages
320  * @param[in,out] init_writer std vector of Writer callbacks for unconstrained
321  * inits of each chain.
322  * @param[in,out] sample_writer std vector of Writers for draws of each chain.
323  * @param[in,out] diagnostic_writer std vector of Writers for diagnostic
324  * information of each chain.
325  * @return error_codes::OK if successful
326  */
327 template <class Model, typename InitContextPtr, typename InitWriter,
328           typename SampleWriter, typename DiagnosticWriter>
hmc_nuts_diag_e_adapt(Model & model,size_t num_chains,const std::vector<InitContextPtr> & init,unsigned int random_seed,unsigned int init_chain_id,double init_radius,int num_warmup,int num_samples,int num_thin,bool save_warmup,int refresh,double stepsize,double stepsize_jitter,int max_depth,double delta,double gamma,double kappa,double t0,unsigned int init_buffer,unsigned int term_buffer,unsigned int window,callbacks::interrupt & interrupt,callbacks::logger & logger,std::vector<InitWriter> & init_writer,std::vector<SampleWriter> & sample_writer,std::vector<DiagnosticWriter> & diagnostic_writer)329 int hmc_nuts_diag_e_adapt(
330     Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
331     unsigned int random_seed, unsigned int init_chain_id, double init_radius,
332     int num_warmup, int num_samples, int num_thin, bool save_warmup,
333     int refresh, double stepsize, double stepsize_jitter, int max_depth,
334     double delta, double gamma, double kappa, double t0,
335     unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
336     callbacks::interrupt& interrupt, callbacks::logger& logger,
337     std::vector<InitWriter>& init_writer,
338     std::vector<SampleWriter>& sample_writer,
339     std::vector<DiagnosticWriter>& diagnostic_writer) {
340   if (num_chains == 1) {
341     return hmc_nuts_diag_e_adapt(
342         model, *init[0], random_seed, init_chain_id, init_radius, num_warmup,
343         num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
344         max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window,
345         interrupt, logger, init_writer[0], sample_writer[0],
346         diagnostic_writer[0]);
347   }
348   std::vector<std::unique_ptr<stan::io::dump>> unit_e_metrics;
349   unit_e_metrics.reserve(num_chains);
350   for (size_t i = 0; i < num_chains; ++i) {
351     unit_e_metrics.emplace_back(std::make_unique<stan::io::dump>(
352         util::create_unit_e_diag_inv_metric(model.num_params_r())));
353   }
354   return hmc_nuts_diag_e_adapt(
355       model, num_chains, init, unit_e_metrics, random_seed, init_chain_id,
356       init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
357       stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
358       init_buffer, term_buffer, window, interrupt, logger, init_writer,
359       sample_writer, diagnostic_writer);
360 }
361 
362 }  // namespace sample
363 }  // namespace services
364 }  // namespace stan
365 #endif
366