1 #ifndef STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_HPP
2 #define STAN_SERVICES_SAMPLE_HMC_NUTS_DIAG_E_ADAPT_HPP
3
4 #include <stan/math/prim.hpp>
5 #include <stan/callbacks/interrupt.hpp>
6 #include <stan/callbacks/logger.hpp>
7 #include <stan/callbacks/writer.hpp>
8 #include <stan/io/var_context.hpp>
9 #include <stan/mcmc/fixed_param_sampler.hpp>
10 #include <stan/services/error_codes.hpp>
11 #include <stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp>
12 #include <stan/services/util/run_adaptive_sampler.hpp>
13 #include <stan/services/util/create_rng.hpp>
14 #include <stan/services/util/initialize.hpp>
15 #include <stan/services/util/inv_metric.hpp>
16 #include <vector>
17
18 namespace stan {
19 namespace services {
20 namespace sample {
21
22 /**
23 * Runs HMC with NUTS with adaptation using diagonal Euclidean metric
24 * with a pre-specified Euclidean metric.
25 *
26 * @tparam Model Model class
27 * @tparam InitContextPtr A type derived from `stan::io::var_context`
28 * @tparam InitMetricContext A type derived from `stan::io::var_context`
29 * @tparam SamplerWriter A type derived from `stan::callbacks::writer`
30 * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
31 * @tparam InitWriter A type derived from `stan::callbacks::writer`
32 * @param[in] model Input model to test (with data already instantiated)
33 * @param[in] init var context for initialization
34 * @param[in] init_inv_metric var context exposing an initial diagonal
35 inverse Euclidean metric (must be positive definite)
36 * @param[in] random_seed random seed for the random number generator
37 * @param[in] chain chain id to advance the pseudo random number generator
38 * @param[in] init_radius radius to initialize
39 * @param[in] num_warmup Number of warmup samples
40 * @param[in] num_samples Number of samples
41 * @param[in] num_thin Number to thin the samples
42 * @param[in] save_warmup Indicates whether to save the warmup iterations
43 * @param[in] refresh Controls the output
44 * @param[in] stepsize initial stepsize for discrete evolution
45 * @param[in] stepsize_jitter uniform random jitter of stepsize
46 * @param[in] max_depth Maximum tree depth
47 * @param[in] delta adaptation target acceptance statistic
48 * @param[in] gamma adaptation regularization scale
49 * @param[in] kappa adaptation relaxation exponent
50 * @param[in] t0 adaptation iteration offset
51 * @param[in] init_buffer width of initial fast adaptation interval
52 * @param[in] term_buffer width of final fast adaptation interval
53 * @param[in] window initial width of slow adaptation interval
54 * @param[in,out] interrupt Callback for interrupts
55 * @param[in,out] logger Logger for messages
56 * @param[in,out] init_writer Writer callback for unconstrained inits
57 * @param[in,out] sample_writer Writer for draws
58 * @param[in,out] diagnostic_writer Writer for diagnostic information
59 * @return error_codes::OK if successful
60 */
61 template <typename Model>
hmc_nuts_diag_e_adapt(Model & model,const stan::io::var_context & init,const stan::io::var_context & init_inv_metric,unsigned int random_seed,unsigned int chain,double init_radius,int num_warmup,int num_samples,int num_thin,bool save_warmup,int refresh,double stepsize,double stepsize_jitter,int max_depth,double delta,double gamma,double kappa,double t0,unsigned int init_buffer,unsigned int term_buffer,unsigned int window,callbacks::interrupt & interrupt,callbacks::logger & logger,callbacks::writer & init_writer,callbacks::writer & sample_writer,callbacks::writer & diagnostic_writer)62 int hmc_nuts_diag_e_adapt(
63 Model& model, const stan::io::var_context& init,
64 const stan::io::var_context& init_inv_metric, unsigned int random_seed,
65 unsigned int chain, double init_radius, int num_warmup, int num_samples,
66 int num_thin, bool save_warmup, int refresh, double stepsize,
67 double stepsize_jitter, int max_depth, double delta, double gamma,
68 double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer,
69 unsigned int window, callbacks::interrupt& interrupt,
70 callbacks::logger& logger, callbacks::writer& init_writer,
71 callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) {
72 boost::ecuyer1988 rng = util::create_rng(random_seed, chain);
73
74 std::vector<double> cont_vector = util::initialize(
75 model, init, rng, init_radius, true, logger, init_writer);
76
77 Eigen::VectorXd inv_metric;
78 try {
79 inv_metric = util::read_diag_inv_metric(init_inv_metric,
80 model.num_params_r(), logger);
81 util::validate_diag_inv_metric(inv_metric, logger);
82 } catch (const std::domain_error& e) {
83 return error_codes::CONFIG;
84 }
85
86 stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);
87
88 sampler.set_metric(inv_metric);
89 sampler.set_nominal_stepsize(stepsize);
90 sampler.set_stepsize_jitter(stepsize_jitter);
91 sampler.set_max_depth(max_depth);
92
93 sampler.get_stepsize_adaptation().set_mu(log(10 * stepsize));
94 sampler.get_stepsize_adaptation().set_delta(delta);
95 sampler.get_stepsize_adaptation().set_gamma(gamma);
96 sampler.get_stepsize_adaptation().set_kappa(kappa);
97 sampler.get_stepsize_adaptation().set_t0(t0);
98
99 sampler.set_window_params(num_warmup, init_buffer, term_buffer, window,
100 logger);
101
102 util::run_adaptive_sampler(
103 sampler, model, cont_vector, num_warmup, num_samples, num_thin, refresh,
104 save_warmup, rng, interrupt, logger, sample_writer, diagnostic_writer);
105
106 return error_codes::OK;
107 }
108
109 /**
110 * Runs HMC with NUTS with adaptation using diagonal Euclidean metric.
111 *
112 * @tparam Model Model class
113 * @param[in] model Input model to test (with data already instantiated)
114 * @param[in] init var context for initialization
115 * @param[in] random_seed random seed for the random number generator
116 * @param[in] chain chain id to advance the pseudo random number generator
117 * @param[in] init_radius radius to initialize
118 * @param[in] num_warmup Number of warmup samples
119 * @param[in] num_samples Number of samples
120 * @param[in] num_thin Number to thin the samples
121 * @param[in] save_warmup Indicates whether to save the warmup iterations
122 * @param[in] refresh Controls the output
123 * @param[in] stepsize initial stepsize for discrete evolution
124 * @param[in] stepsize_jitter uniform random jitter of stepsize
125 * @param[in] max_depth Maximum tree depth
126 * @param[in] delta adaptation target acceptance statistic
127 * @param[in] gamma adaptation regularization scale
128 * @param[in] kappa adaptation relaxation exponent
129 * @param[in] t0 adaptation iteration offset
130 * @param[in] init_buffer width of initial fast adaptation interval
131 * @param[in] term_buffer width of final fast adaptation interval
132 * @param[in] window initial width of slow adaptation interval
133 * @param[in,out] interrupt Callback for interrupts
134 * @param[in,out] logger Logger for messages
135 * @param[in,out] init_writer Writer callback for unconstrained inits
136 * @param[in,out] sample_writer Writer for draws
137 * @param[in,out] diagnostic_writer Writer for diagnostic information
138 * @return error_codes::OK if successful
139 */
140 template <typename Model>
hmc_nuts_diag_e_adapt(Model & model,const stan::io::var_context & init,unsigned int random_seed,unsigned int chain,double init_radius,int num_warmup,int num_samples,int num_thin,bool save_warmup,int refresh,double stepsize,double stepsize_jitter,int max_depth,double delta,double gamma,double kappa,double t0,unsigned int init_buffer,unsigned int term_buffer,unsigned int window,callbacks::interrupt & interrupt,callbacks::logger & logger,callbacks::writer & init_writer,callbacks::writer & sample_writer,callbacks::writer & diagnostic_writer)141 int hmc_nuts_diag_e_adapt(
142 Model& model, const stan::io::var_context& init, unsigned int random_seed,
143 unsigned int chain, double init_radius, int num_warmup, int num_samples,
144 int num_thin, bool save_warmup, int refresh, double stepsize,
145 double stepsize_jitter, int max_depth, double delta, double gamma,
146 double kappa, double t0, unsigned int init_buffer, unsigned int term_buffer,
147 unsigned int window, callbacks::interrupt& interrupt,
148 callbacks::logger& logger, callbacks::writer& init_writer,
149 callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) {
150 stan::io::dump unit_e_metric
151 = util::create_unit_e_diag_inv_metric(model.num_params_r());
152 return hmc_nuts_diag_e_adapt(
153 model, init, unit_e_metric, random_seed, chain, init_radius, num_warmup,
154 num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
155 max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window,
156 interrupt, logger, init_writer, sample_writer, diagnostic_writer);
157 }
158
159 /**
160 * Runs multiple chains of HMC with NUTS with adaptation using diagonal
161 * Euclidean metric with a pre-specified Euclidean metric.
162 *
163 * @tparam Model Model class
164 * @tparam InitContextPtr A pointer with underlying type derived from
165 `stan::io::var_context`
166 * @tparam InitInvContextPtr A pointer with underlying type derived from
167 `stan::io::var_context`
168 * @tparam SamplerWriter A type derived from `stan::callbacks::writer`
169 * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
170 * @tparam InitWriter A type derived from `stan::callbacks::writer`
171 * @param[in] model Input model to test (with data already instantiated)
172 * @param[in] num_chains The number of chains to run in parallel. `init`,
173 * `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer`
174 must
175 * be the same length as this value.
176 * @param[in] init An std vector of init var contexts for initialization of each
177 * chain.
178 * @param[in] init_inv_metric An std vector of var contexts exposing an initial
179 diagonal inverse Euclidean metric for each chain (must be positive definite)
180 * @param[in] random_seed random seed for the random number generator
181 * @param[in] init_chain_id first chain id. The pseudo random number generator
182 * will advance for each chain by an integer sequence from `init_chain_id` to
183 * `init_chain_id + num_chains - 1`
184 * @param[in] init_radius radius to initialize
185 * @param[in] num_warmup Number of warmup samples
186 * @param[in] num_samples Number of samples
187 * @param[in] num_thin Number to thin the samples
188 * @param[in] save_warmup Indicates whether to save the warmup iterations
189 * @param[in] refresh Controls the output
190 * @param[in] stepsize initial stepsize for discrete evolution
191 * @param[in] stepsize_jitter uniform random jitter of stepsize
192 * @param[in] max_depth Maximum tree depth
193 * @param[in] delta adaptation target acceptance statistic
194 * @param[in] gamma adaptation regularization scale
195 * @param[in] kappa adaptation relaxation exponent
196 * @param[in] t0 adaptation iteration offset
197 * @param[in] init_buffer width of initial fast adaptation interval
198 * @param[in] term_buffer width of final fast adaptation interval
199 * @param[in] window initial width of slow adaptation interval
200 * @param[in,out] interrupt Callback for interrupts
201 * @param[in,out] logger Logger for messages
202 * @param[in,out] init_writer std vector of Writer callbacks for unconstrained
203 * inits of each chain.
204 * @param[in,out] sample_writer std vector of Writers for draws of each chain.
205 * @param[in,out] diagnostic_writer std vector of Writers for diagnostic
206 * information of each chain.
207 * @return error_codes::OK if successful
208 */
209 template <class Model, typename InitContextPtr, typename InitInvContextPtr,
210 typename InitWriter, typename SampleWriter, typename DiagnosticWriter>
hmc_nuts_diag_e_adapt(Model & model,size_t num_chains,const std::vector<InitContextPtr> & init,const std::vector<InitInvContextPtr> & init_inv_metric,unsigned int random_seed,unsigned int init_chain_id,double init_radius,int num_warmup,int num_samples,int num_thin,bool save_warmup,int refresh,double stepsize,double stepsize_jitter,int max_depth,double delta,double gamma,double kappa,double t0,unsigned int init_buffer,unsigned int term_buffer,unsigned int window,callbacks::interrupt & interrupt,callbacks::logger & logger,std::vector<InitWriter> & init_writer,std::vector<SampleWriter> & sample_writer,std::vector<DiagnosticWriter> & diagnostic_writer)211 int hmc_nuts_diag_e_adapt(
212 Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
213 const std::vector<InitInvContextPtr>& init_inv_metric,
214 unsigned int random_seed, unsigned int init_chain_id, double init_radius,
215 int num_warmup, int num_samples, int num_thin, bool save_warmup,
216 int refresh, double stepsize, double stepsize_jitter, int max_depth,
217 double delta, double gamma, double kappa, double t0,
218 unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
219 callbacks::interrupt& interrupt, callbacks::logger& logger,
220 std::vector<InitWriter>& init_writer,
221 std::vector<SampleWriter>& sample_writer,
222 std::vector<DiagnosticWriter>& diagnostic_writer) {
223 if (num_chains == 1) {
224 return hmc_nuts_diag_e_adapt(
225 model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
226 init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
227 stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
228 init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
229 sample_writer[0], diagnostic_writer[0]);
230 }
231 using sample_t = stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988>;
232 std::vector<boost::ecuyer1988> rngs;
233 rngs.reserve(num_chains);
234 std::vector<std::vector<double>> cont_vectors;
235 cont_vectors.reserve(num_chains);
236 std::vector<sample_t> samplers;
237 samplers.reserve(num_chains);
238 try {
239 for (int i = 0; i < num_chains; ++i) {
240 rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i));
241 cont_vectors.emplace_back(util::initialize(
242 model, *init[i], rngs[i], init_radius, true, logger, init_writer[i]));
243 samplers.emplace_back(model, rngs[i]);
244 Eigen::VectorXd inv_metric = util::read_diag_inv_metric(
245 *init_inv_metric[i], model.num_params_r(), logger);
246 util::validate_diag_inv_metric(inv_metric, logger);
247
248 samplers[i].set_metric(inv_metric);
249 samplers[i].set_nominal_stepsize(stepsize);
250 samplers[i].set_stepsize_jitter(stepsize_jitter);
251 samplers[i].set_max_depth(max_depth);
252
253 samplers[i].get_stepsize_adaptation().set_mu(log(10 * stepsize));
254 samplers[i].get_stepsize_adaptation().set_delta(delta);
255 samplers[i].get_stepsize_adaptation().set_gamma(gamma);
256 samplers[i].get_stepsize_adaptation().set_kappa(kappa);
257 samplers[i].get_stepsize_adaptation().set_t0(t0);
258 samplers[i].set_window_params(num_warmup, init_buffer, term_buffer,
259 window, logger);
260 }
261 } catch (const std::domain_error& e) {
262 return error_codes::CONFIG;
263 }
264 tbb::parallel_for(tbb::blocked_range<size_t>(0, num_chains, 1),
265 [num_warmup, num_samples, num_thin, refresh, save_warmup,
266 num_chains, init_chain_id, &samplers, &model, &rngs,
267 &interrupt, &logger, &sample_writer, &cont_vectors,
268 &diagnostic_writer](const tbb::blocked_range<size_t>& r) {
269 for (size_t i = r.begin(); i != r.end(); ++i) {
270 util::run_adaptive_sampler(
271 samplers[i], model, cont_vectors[i], num_warmup,
272 num_samples, num_thin, refresh, save_warmup,
273 rngs[i], interrupt, logger, sample_writer[i],
274 diagnostic_writer[i], init_chain_id + i,
275 num_chains);
276 }
277 },
278 tbb::simple_partitioner());
279 return error_codes::OK;
280 }
281
282 /**
283 * Runs multiple chains of HMC with NUTS with adaptation using diagonal
284 * Euclidean metric.
285 *
286 * @tparam Model Model class
287 * @tparam InitContextPtr A pointer with underlying type derived from
288 * `stan::io::var_context`
289 * @tparam SamplerWriter A type derived from `stan::callbacks::writer`
290 * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
291 * @tparam InitWriter A type derived from `stan::callbacks::writer`
292 * @param[in] model Input model to test (with data already instantiated)
293 * @param[in] num_chains The number of chains to run in parallel. `init`,
294 * `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
295 * length as this value.
296 * @param[in] init An std vector of init var contexts for initialization of each
297 * chain.
298 * @param[in] random_seed random seed for the random number generator
299 * @param[in] init_chain_id first chain id. The pseudo random number generator
300 * will advance by for each chain by an integer sequence from `init_chain_id` to
301 * `init_chain_id+num_chains-1`
302 * @param[in] init_radius radius to initialize
303 * @param[in] num_warmup Number of warmup samples
304 * @param[in] num_samples Number of samples
305 * @param[in] num_thin Number to thin the samples
306 * @param[in] save_warmup Indicates whether to save the warmup iterations
307 * @param[in] refresh Controls the output
308 * @param[in] stepsize initial stepsize for discrete evolution
309 * @param[in] stepsize_jitter uniform random jitter of stepsize
310 * @param[in] max_depth Maximum tree depth
311 * @param[in] delta adaptation target acceptance statistic
312 * @param[in] gamma adaptation regularization scale
313 * @param[in] kappa adaptation relaxation exponent
314 * @param[in] t0 adaptation iteration offset
315 * @param[in] init_buffer width of initial fast adaptation interval
316 * @param[in] term_buffer width of final fast adaptation interval
317 * @param[in] window initial width of slow adaptation interval
318 * @param[in,out] interrupt Callback for interrupts
319 * @param[in,out] logger Logger for messages
320 * @param[in,out] init_writer std vector of Writer callbacks for unconstrained
321 * inits of each chain.
322 * @param[in,out] sample_writer std vector of Writers for draws of each chain.
323 * @param[in,out] diagnostic_writer std vector of Writers for diagnostic
324 * information of each chain.
325 * @return error_codes::OK if successful
326 */
327 template <class Model, typename InitContextPtr, typename InitWriter,
328 typename SampleWriter, typename DiagnosticWriter>
hmc_nuts_diag_e_adapt(Model & model,size_t num_chains,const std::vector<InitContextPtr> & init,unsigned int random_seed,unsigned int init_chain_id,double init_radius,int num_warmup,int num_samples,int num_thin,bool save_warmup,int refresh,double stepsize,double stepsize_jitter,int max_depth,double delta,double gamma,double kappa,double t0,unsigned int init_buffer,unsigned int term_buffer,unsigned int window,callbacks::interrupt & interrupt,callbacks::logger & logger,std::vector<InitWriter> & init_writer,std::vector<SampleWriter> & sample_writer,std::vector<DiagnosticWriter> & diagnostic_writer)329 int hmc_nuts_diag_e_adapt(
330 Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
331 unsigned int random_seed, unsigned int init_chain_id, double init_radius,
332 int num_warmup, int num_samples, int num_thin, bool save_warmup,
333 int refresh, double stepsize, double stepsize_jitter, int max_depth,
334 double delta, double gamma, double kappa, double t0,
335 unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
336 callbacks::interrupt& interrupt, callbacks::logger& logger,
337 std::vector<InitWriter>& init_writer,
338 std::vector<SampleWriter>& sample_writer,
339 std::vector<DiagnosticWriter>& diagnostic_writer) {
340 if (num_chains == 1) {
341 return hmc_nuts_diag_e_adapt(
342 model, *init[0], random_seed, init_chain_id, init_radius, num_warmup,
343 num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
344 max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window,
345 interrupt, logger, init_writer[0], sample_writer[0],
346 diagnostic_writer[0]);
347 }
348 std::vector<std::unique_ptr<stan::io::dump>> unit_e_metrics;
349 unit_e_metrics.reserve(num_chains);
350 for (size_t i = 0; i < num_chains; ++i) {
351 unit_e_metrics.emplace_back(std::make_unique<stan::io::dump>(
352 util::create_unit_e_diag_inv_metric(model.num_params_r())));
353 }
354 return hmc_nuts_diag_e_adapt(
355 model, num_chains, init, unit_e_metrics, random_seed, init_chain_id,
356 init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
357 stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
358 init_buffer, term_buffer, window, interrupt, logger, init_writer,
359 sample_writer, diagnostic_writer);
360 }
361
362 } // namespace sample
363 } // namespace services
364 } // namespace stan
365 #endif
366