1 #ifndef STAN_MCMC_HMC_NUTS_CLASSIC_BASE_NUTS_CLASSIC_HPP
2 #define STAN_MCMC_HMC_NUTS_CLASSIC_BASE_NUTS_CLASSIC_HPP
3 
4 #include <stan/callbacks/logger.hpp>
5 #include <boost/math/special_functions/fpclassify.hpp>
6 #include <stan/mcmc/hmc/base_hmc.hpp>
7 #include <stan/mcmc/hmc/hamiltonians/ps_point.hpp>
8 #include <algorithm>
9 #include <cmath>
10 #include <limits>
11 #include <string>
12 #include <vector>
13 
14 namespace stan {
15   namespace mcmc {
16 
17     struct nuts_util {
18       // Constants through each recursion
19       double log_u;
20       double H0;
21       int sign;
22 
23       // Aggregators through each recursion
24       int n_tree;
25       double sum_prob;
26       bool criterion;
27 
28       // just to guarantee bool initializes to valid value
nuts_utilstan::mcmc::nuts_util29       nuts_util() : criterion(false) { }
30     };
31 
32     // The No-U-Turn Sampler (NUTS) with the
33     // original slice sampler implementation
34     template <class Model, template<class, class> class Hamiltonian,
35               template<class> class Integrator, class BaseRNG>
36     class base_nuts_classic:
37       public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
38     public:
base_nuts_classic(const Model & model,BaseRNG & rng)39       base_nuts_classic(const Model& model, BaseRNG& rng):
40         base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
41         depth_(0), max_depth_(5), max_delta_(1000),
42         n_leapfrog_(0), divergent_(0), energy_(0) {
43       }
44 
~base_nuts_classic()45       ~base_nuts_classic() {}
46 
set_max_depth(int d)47       void set_max_depth(int d) {
48         if (d > 0)
49           max_depth_ = d;
50       }
51 
set_max_delta(double d)52       void set_max_delta(double d) {
53         max_delta_ = d;
54       }
55 
get_max_depth()56       int get_max_depth() { return this->max_depth_; }
get_max_delta()57       double get_max_delta() { return this->max_delta_; }
58 
59       sample
transition(sample & init_sample,callbacks::logger & logger)60       transition(sample& init_sample, callbacks::logger& logger) {
61         // Initialize the algorithm
62         this->sample_stepsize();
63 
64         nuts_util util;
65 
66         this->seed(init_sample.cont_params());
67 
68         this->hamiltonian_.sample_p(this->z_, this->rand_int_);
69         this->hamiltonian_.init(this->z_, logger);
70 
71         ps_point z_plus(this->z_);
72         ps_point z_minus(z_plus);
73 
74         ps_point z_sample(z_plus);
75         ps_point z_propose(z_plus);
76 
77         int n_cont = init_sample.cont_params().size();
78 
79         Eigen::VectorXd rho_init = this->z_.p;
80         Eigen::VectorXd rho_plus(n_cont); rho_plus.setZero();
81         Eigen::VectorXd rho_minus(n_cont); rho_minus.setZero();
82 
83         util.H0 = this->hamiltonian_.H(this->z_);
84 
85         // Sample the slice variable
86         util.log_u = std::log(this->rand_uniform_());
87 
88         // Build a balanced binary tree until the NUTS criterion fails
89         util.criterion = true;
90         int n_valid = 0;
91 
92         this->depth_ = 0;
93         this->divergent_ = 0;
94 
95         util.n_tree = 0;
96         util.sum_prob = 0;
97 
98         while (util.criterion && (this->depth_ <= this->max_depth_)) {
99           // Randomly sample a direction in time
100           ps_point* z = 0;
101           Eigen::VectorXd* rho = 0;
102 
103           if (this->rand_uniform_() > 0.5) {
104             z = &z_plus;
105             rho = &rho_plus;
106             util.sign = 1;
107           } else {
108             z = &z_minus;
109             rho = &rho_minus;
110             util.sign = -1;
111           }
112 
113           // And build a new subtree in that direction
114           this->z_.ps_point::operator=(*z);
115 
116           int n_valid_subtree = build_tree(depth_, *rho, 0, z_propose, util,
117                                            logger);
118           ++(this->depth_);
119 
120           *z = this->z_;
121 
122           // Metropolis-Hastings sample the fresh subtree
123           if (!util.criterion)
124             break;
125 
126           double subtree_prob = 0;
127 
128           if (n_valid) {
129             subtree_prob = static_cast<double>(n_valid_subtree) /
130               static_cast<double>(n_valid);
131           } else {
132             subtree_prob = n_valid_subtree ? 1 : 0;
133           }
134 
135           if (this->rand_uniform_() < subtree_prob)
136             z_sample = z_propose;
137 
138           n_valid += n_valid_subtree;
139 
140           // Check validity of completed tree
141           this->z_.ps_point::operator=(z_plus);
142           Eigen::VectorXd delta_rho = rho_minus + rho_init + rho_plus;
143 
144           util.criterion = compute_criterion(z_minus, this->z_, delta_rho);
145         }
146 
147         this->n_leapfrog_ = util.n_tree;
148 
149         double accept_prob = util.sum_prob / static_cast<double>(util.n_tree);
150 
151         this->z_.ps_point::operator=(z_sample);
152         this->energy_ = this->hamiltonian_.H(this->z_);
153         return sample(this->z_.q, - this->z_.V, accept_prob);
154       }
155 
get_sampler_param_names(std::vector<std::string> & names)156       void get_sampler_param_names(std::vector<std::string>& names) {
157         names.push_back("stepsize__");
158         names.push_back("treedepth__");
159         names.push_back("n_leapfrog__");
160         names.push_back("divergent__");
161         names.push_back("energy__");
162       }
163 
get_sampler_params(std::vector<double> & values)164       void get_sampler_params(std::vector<double>& values) {
165         values.push_back(this->epsilon_);
166         values.push_back(this->depth_);
167         values.push_back(this->n_leapfrog_);
168         values.push_back(this->divergent_);
169         values.push_back(this->energy_);
170       }
171 
172       virtual bool compute_criterion(ps_point& start,
173                                      typename Hamiltonian<Model, BaseRNG>
174                                      ::PointType& finish,
175                                      Eigen::VectorXd& rho) = 0;
176 
177       // Returns number of valid points in the completed subtree
build_tree(int depth,Eigen::VectorXd & rho,ps_point * z_init_parent,ps_point & z_propose,nuts_util & util,callbacks::logger & logger)178       int build_tree(int depth, Eigen::VectorXd& rho,
179                      ps_point* z_init_parent, ps_point& z_propose,
180                      nuts_util& util,
181                      callbacks::logger& logger) {
182         // Base case
183         if (depth == 0) {
184             this->integrator_.evolve(this->z_, this->hamiltonian_,
185                                      util.sign * this->epsilon_,
186                                      logger);
187             rho += this->z_.p;
188 
189             if (z_init_parent) *z_init_parent = this->z_;
190             z_propose = this->z_;
191 
192             double h = this->hamiltonian_.H(this->z_);
193             if (boost::math::isnan(h))
194               h = std::numeric_limits<double>::infinity();
195 
196             util.criterion = util.log_u + (h - util.H0) < this->max_delta_;
197             if (!util.criterion) ++(this->divergent_);
198 
199             util.sum_prob += std::min(1.0, std::exp(util.H0 - h));
200             util.n_tree += 1;
201 
202             return (util.log_u + (h - util.H0) < 0);
203 
204           } else {
205           // General recursion
206           Eigen::VectorXd left_subtree_rho(rho.size());
207           left_subtree_rho.setZero();
208           ps_point z_init(this->z_);
209 
210           int n1 = build_tree(depth - 1, left_subtree_rho, &z_init,
211                               z_propose, util,
212                               logger);
213 
214           if (z_init_parent) *z_init_parent = z_init;
215 
216           if (!util.criterion) return 0;
217 
218           Eigen::VectorXd right_subtree_rho(rho.size());
219           right_subtree_rho.setZero();
220           ps_point z_propose_right(z_init);
221 
222           int n2 = build_tree(depth - 1, right_subtree_rho, 0,
223                               z_propose_right, util,
224                               logger);
225 
226           double accept_prob = static_cast<double>(n2) /
227             static_cast<double>(n1 + n2);
228 
229           if ( util.criterion && (this->rand_uniform_() < accept_prob) )
230             z_propose = z_propose_right;
231 
232           Eigen::VectorXd& subtree_rho = left_subtree_rho;
233           subtree_rho += right_subtree_rho;
234 
235           rho += subtree_rho;
236 
237           util.criterion &= compute_criterion(z_init, this->z_, subtree_rho);
238 
239           return n1 + n2;
240         }
241       }
242 
243       int depth_;
244       int max_depth_;
245       double max_delta_;
246 
247       int n_leapfrog_;
248       int divergent_;
249       double energy_;
250     };
251 
252   }  // mcmc
253 }  // stan
254 #endif
255