1 #ifndef STAN_MCMC_HMC_NUTS_CLASSIC_BASE_NUTS_CLASSIC_HPP 2 #define STAN_MCMC_HMC_NUTS_CLASSIC_BASE_NUTS_CLASSIC_HPP 3 4 #include <stan/callbacks/logger.hpp> 5 #include <boost/math/special_functions/fpclassify.hpp> 6 #include <stan/mcmc/hmc/base_hmc.hpp> 7 #include <stan/mcmc/hmc/hamiltonians/ps_point.hpp> 8 #include <algorithm> 9 #include <cmath> 10 #include <limits> 11 #include <string> 12 #include <vector> 13 14 namespace stan { 15 namespace mcmc { 16 17 struct nuts_util { 18 // Constants through each recursion 19 double log_u; 20 double H0; 21 int sign; 22 23 // Aggregators through each recursion 24 int n_tree; 25 double sum_prob; 26 bool criterion; 27 28 // just to guarantee bool initializes to valid value nuts_utilstan::mcmc::nuts_util29 nuts_util() : criterion(false) { } 30 }; 31 32 // The No-U-Turn Sampler (NUTS) with the 33 // original slice sampler implementation 34 template <class Model, template<class, class> class Hamiltonian, 35 template<class> class Integrator, class BaseRNG> 36 class base_nuts_classic: 37 public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> { 38 public: base_nuts_classic(const Model & model,BaseRNG & rng)39 base_nuts_classic(const Model& model, BaseRNG& rng): 40 base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng), 41 depth_(0), max_depth_(5), max_delta_(1000), 42 n_leapfrog_(0), divergent_(0), energy_(0) { 43 } 44 ~base_nuts_classic()45 ~base_nuts_classic() {} 46 set_max_depth(int d)47 void set_max_depth(int d) { 48 if (d > 0) 49 max_depth_ = d; 50 } 51 set_max_delta(double d)52 void set_max_delta(double d) { 53 max_delta_ = d; 54 } 55 get_max_depth()56 int get_max_depth() { return this->max_depth_; } get_max_delta()57 double get_max_delta() { return this->max_delta_; } 58 59 sample transition(sample & init_sample,callbacks::logger & logger)60 transition(sample& init_sample, callbacks::logger& logger) { 61 // Initialize the algorithm 62 this->sample_stepsize(); 63 64 nuts_util util; 65 66 this->seed(init_sample.cont_params()); 67 68 this->hamiltonian_.sample_p(this->z_, this->rand_int_); 69 this->hamiltonian_.init(this->z_, logger); 70 71 ps_point z_plus(this->z_); 72 ps_point z_minus(z_plus); 73 74 ps_point z_sample(z_plus); 75 ps_point z_propose(z_plus); 76 77 int n_cont = init_sample.cont_params().size(); 78 79 Eigen::VectorXd rho_init = this->z_.p; 80 Eigen::VectorXd rho_plus(n_cont); rho_plus.setZero(); 81 Eigen::VectorXd rho_minus(n_cont); rho_minus.setZero(); 82 83 util.H0 = this->hamiltonian_.H(this->z_); 84 85 // Sample the slice variable 86 util.log_u = std::log(this->rand_uniform_()); 87 88 // Build a balanced binary tree until the NUTS criterion fails 89 util.criterion = true; 90 int n_valid = 0; 91 92 this->depth_ = 0; 93 this->divergent_ = 0; 94 95 util.n_tree = 0; 96 util.sum_prob = 0; 97 98 while (util.criterion && (this->depth_ <= this->max_depth_)) { 99 // Randomly sample a direction in time 100 ps_point* z = 0; 101 Eigen::VectorXd* rho = 0; 102 103 if (this->rand_uniform_() > 0.5) { 104 z = &z_plus; 105 rho = &rho_plus; 106 util.sign = 1; 107 } else { 108 z = &z_minus; 109 rho = &rho_minus; 110 util.sign = -1; 111 } 112 113 // And build a new subtree in that direction 114 this->z_.ps_point::operator=(*z); 115 116 int n_valid_subtree = build_tree(depth_, *rho, 0, z_propose, util, 117 logger); 118 ++(this->depth_); 119 120 *z = this->z_; 121 122 // Metropolis-Hastings sample the fresh subtree 123 if (!util.criterion) 124 break; 125 126 double subtree_prob = 0; 127 128 if (n_valid) { 129 subtree_prob = static_cast<double>(n_valid_subtree) / 130 static_cast<double>(n_valid); 131 } else { 132 subtree_prob = n_valid_subtree ? 1 : 0; 133 } 134 135 if (this->rand_uniform_() < subtree_prob) 136 z_sample = z_propose; 137 138 n_valid += n_valid_subtree; 139 140 // Check validity of completed tree 141 this->z_.ps_point::operator=(z_plus); 142 Eigen::VectorXd delta_rho = rho_minus + rho_init + rho_plus; 143 144 util.criterion = compute_criterion(z_minus, this->z_, delta_rho); 145 } 146 147 this->n_leapfrog_ = util.n_tree; 148 149 double accept_prob = util.sum_prob / static_cast<double>(util.n_tree); 150 151 this->z_.ps_point::operator=(z_sample); 152 this->energy_ = this->hamiltonian_.H(this->z_); 153 return sample(this->z_.q, - this->z_.V, accept_prob); 154 } 155 get_sampler_param_names(std::vector<std::string> & names)156 void get_sampler_param_names(std::vector<std::string>& names) { 157 names.push_back("stepsize__"); 158 names.push_back("treedepth__"); 159 names.push_back("n_leapfrog__"); 160 names.push_back("divergent__"); 161 names.push_back("energy__"); 162 } 163 get_sampler_params(std::vector<double> & values)164 void get_sampler_params(std::vector<double>& values) { 165 values.push_back(this->epsilon_); 166 values.push_back(this->depth_); 167 values.push_back(this->n_leapfrog_); 168 values.push_back(this->divergent_); 169 values.push_back(this->energy_); 170 } 171 172 virtual bool compute_criterion(ps_point& start, 173 typename Hamiltonian<Model, BaseRNG> 174 ::PointType& finish, 175 Eigen::VectorXd& rho) = 0; 176 177 // Returns number of valid points in the completed subtree build_tree(int depth,Eigen::VectorXd & rho,ps_point * z_init_parent,ps_point & z_propose,nuts_util & util,callbacks::logger & logger)178 int build_tree(int depth, Eigen::VectorXd& rho, 179 ps_point* z_init_parent, ps_point& z_propose, 180 nuts_util& util, 181 callbacks::logger& logger) { 182 // Base case 183 if (depth == 0) { 184 this->integrator_.evolve(this->z_, this->hamiltonian_, 185 util.sign * this->epsilon_, 186 logger); 187 rho += this->z_.p; 188 189 if (z_init_parent) *z_init_parent = this->z_; 190 z_propose = this->z_; 191 192 double h = this->hamiltonian_.H(this->z_); 193 if (boost::math::isnan(h)) 194 h = std::numeric_limits<double>::infinity(); 195 196 util.criterion = util.log_u + (h - util.H0) < this->max_delta_; 197 if (!util.criterion) ++(this->divergent_); 198 199 util.sum_prob += std::min(1.0, std::exp(util.H0 - h)); 200 util.n_tree += 1; 201 202 return (util.log_u + (h - util.H0) < 0); 203 204 } else { 205 // General recursion 206 Eigen::VectorXd left_subtree_rho(rho.size()); 207 left_subtree_rho.setZero(); 208 ps_point z_init(this->z_); 209 210 int n1 = build_tree(depth - 1, left_subtree_rho, &z_init, 211 z_propose, util, 212 logger); 213 214 if (z_init_parent) *z_init_parent = z_init; 215 216 if (!util.criterion) return 0; 217 218 Eigen::VectorXd right_subtree_rho(rho.size()); 219 right_subtree_rho.setZero(); 220 ps_point z_propose_right(z_init); 221 222 int n2 = build_tree(depth - 1, right_subtree_rho, 0, 223 z_propose_right, util, 224 logger); 225 226 double accept_prob = static_cast<double>(n2) / 227 static_cast<double>(n1 + n2); 228 229 if ( util.criterion && (this->rand_uniform_() < accept_prob) ) 230 z_propose = z_propose_right; 231 232 Eigen::VectorXd& subtree_rho = left_subtree_rho; 233 subtree_rho += right_subtree_rho; 234 235 rho += subtree_rho; 236 237 util.criterion &= compute_criterion(z_init, this->z_, subtree_rho); 238 239 return n1 + n2; 240 } 241 } 242 243 int depth_; 244 int max_depth_; 245 double max_delta_; 246 247 int n_leapfrog_; 248 int divergent_; 249 double energy_; 250 }; 251 252 } // mcmc 253 } // stan 254 #endif 255