1 #include <config.h> 2 3 #include "REFactory2.h" 4 #include "REMethod2.h" 5 #include "GLMSampler.h" 6 7 #include <graph/StochasticNode.h> 8 #include <distribution/Distribution.h> 9 #include <sampler/SingletonGraphView.h> 10 #include <sampler/MutableSampler.h> 11 12 #include <algorithm> 13 #include <utility> 14 15 using std::vector; 16 using std::string; 17 using std::list; 18 using std::set; 19 20 namespace jags { 21 namespace glm { 22 REFactory2(std::string const & name)23 REFactory2::REFactory2(std::string const &name) 24 : _name(name) 25 {} 26 checkTau(SingletonGraphView const * tau,GraphView const * glmview) const27 bool REFactory2::checkTau(SingletonGraphView const *tau, 28 GraphView const *glmview) const 29 { 30 if (!tau->deterministicChildren().empty()) { 31 return false; 32 } 33 34 vector<StochasticNode *> const &eps = tau->stochasticChildren(); 35 for (unsigned int i = 0; i < eps.size(); ++i) { 36 if (isObserved(eps[i])) { 37 return false; 38 } 39 if (isBounded(eps[i])) { 40 return false; 41 } 42 if (eps[i]->distribution()->name() != "dnorm" && 43 eps[i]->distribution()->name() != "dmnorm") { 44 return false; 45 } 46 47 Node const *mu_tau = eps[i]->parents()[1]; 48 if (mu_tau != tau->node()) { 49 return false; 50 } 51 if (tau->isDependent(eps[i]->parents()[0])) { 52 return false; //mean parameter depends on snode 53 } 54 } 55 56 //Check that all stochastic children of tau are included in 57 //the linear predictor of the glm 58 if (glmview->nodes().size() < eps.size()) { 59 return false; 60 } 61 62 set<StochasticNode*> lpset; 63 lpset.insert(glmview->nodes().begin(), glmview->nodes().end()); 64 for (unsigned int i = 0; i < eps.size(); ++i) { 65 if (lpset.count(eps[i]) == 0) { 66 return false; 67 } 68 } 69 70 return true; //We made it! 71 } 72 ~REFactory2()73 REFactory2::~REFactory2() 74 {} 75 makeSampler(list<StochasticNode * > const & free_nodes,set<StochasticNode * > & used_nodes,GLMSampler const * glmsampler,Graph const & graph) const76 Sampler * REFactory2::makeSampler( 77 list<StochasticNode*> const &free_nodes, 78 set<StochasticNode*> &used_nodes, 79 GLMSampler const *glmsampler, Graph const &graph) const 80 { 81 SingletonGraphView *tau = 0; 82 for (list<StochasticNode*>::const_iterator p = free_nodes.begin(); 83 p != free_nodes.end(); ++p) 84 { 85 if (used_nodes.count(*p)) continue; 86 87 if (canSample(*p)) { 88 tau = new SingletonGraphView(*p, graph); 89 if (checkTau(tau, glmsampler->_view)) { 90 break; 91 } 92 else { 93 delete tau; tau = 0; 94 } 95 } 96 } 97 98 /* Create a single GraphView containing all sampled nodes 99 (from tau and from eps). This is required by the Sampler 100 class. Note that this is a multilevel GraphView. 101 102 vector<StochasticNode*> snodes = eps->nodes(); 103 snodes.push_back(tau->node()); 104 GraphView *view = new GraphView(snodes, graph, true); 105 */ 106 107 if (tau) { 108 unsigned int nchain = glmsampler->_methods.size(); 109 vector<MutableSampleMethod*> methods(nchain); 110 for (unsigned int i = 0; i < nchain; ++i) { 111 methods[i] = newMethod(tau, glmsampler->_methods[i]); 112 } 113 used_nodes.insert(tau->node()); 114 return new MutableSampler(tau, methods, _name); 115 } 116 return 0; 117 } 118 119 } // namespace glm 120 } //namespace jags 121 122