1 #include <config.h>
2 
3 #include "REFactory2.h"
4 #include "REMethod2.h"
5 #include "GLMSampler.h"
6 
7 #include <graph/StochasticNode.h>
8 #include <distribution/Distribution.h>
9 #include <sampler/SingletonGraphView.h>
10 #include <sampler/MutableSampler.h>
11 
12 #include <algorithm>
13 #include <utility>
14 
15 using std::vector;
16 using std::string;
17 using std::list;
18 using std::set;
19 
20 namespace jags {
21     namespace glm {
22 
REFactory2(std::string const & name)23 	REFactory2::REFactory2(std::string const &name)
24 	    : _name(name)
25 	{}
26 
checkTau(SingletonGraphView const * tau,GraphView const * glmview) const27 	bool REFactory2::checkTau(SingletonGraphView const *tau,
28 				  GraphView const *glmview) const
29 	{
30 	    if (!tau->deterministicChildren().empty()) {
31 		return false;
32 	    }
33 
34 	    vector<StochasticNode *> const &eps = tau->stochasticChildren();
35 	    for (unsigned int i = 0; i < eps.size(); ++i) {
36 		if (isObserved(eps[i])) {
37 		    return false;
38 		}
39 		if (isBounded(eps[i])) {
40 		    return false;
41 		}
42 		if (eps[i]->distribution()->name() != "dnorm" &&
43 		    eps[i]->distribution()->name() != "dmnorm") {
44 		    return false;
45 		}
46 
47 		Node const *mu_tau = eps[i]->parents()[1];
48 		if (mu_tau != tau->node()) {
49 		    return false;
50 		}
51 		if (tau->isDependent(eps[i]->parents()[0])) {
52 		    return false; //mean parameter depends on snode
53 		}
54 	    }
55 
56 	    //Check that all stochastic children of tau are included in
57 	    //the linear predictor of the glm
58 	    if (glmview->nodes().size() < eps.size()) {
59 		return false;
60 	    }
61 
62 	    set<StochasticNode*> lpset;
63 	    lpset.insert(glmview->nodes().begin(), glmview->nodes().end());
64 	    for (unsigned int i = 0; i < eps.size(); ++i) {
65 		if (lpset.count(eps[i]) == 0) {
66 		    return false;
67 		}
68 	    }
69 
70 	    return true; //We made it!
71 	}
72 
~REFactory2()73 	REFactory2::~REFactory2()
74 	{}
75 
makeSampler(list<StochasticNode * > const & free_nodes,set<StochasticNode * > & used_nodes,GLMSampler const * glmsampler,Graph const & graph) const76 	Sampler * REFactory2::makeSampler(
77 	    list<StochasticNode*> const &free_nodes,
78 	    set<StochasticNode*> &used_nodes,
79 	    GLMSampler const *glmsampler, Graph const &graph) const
80 	{
81 	    SingletonGraphView *tau = 0;
82 	    for (list<StochasticNode*>::const_iterator p = free_nodes.begin();
83 		 p != free_nodes.end(); ++p)
84 	    {
85 		if (used_nodes.count(*p)) continue;
86 
87 		if (canSample(*p)) {
88 		    tau = new SingletonGraphView(*p, graph);
89 		    if (checkTau(tau, glmsampler->_view)) {
90 			break;
91 		    }
92 		    else {
93 			delete tau; tau = 0;
94 		    }
95 		}
96 	    }
97 
98 	    /* Create a single GraphView containing all sampled nodes
99 	       (from tau and from eps). This is required by the Sampler
100 	       class. Note that this is a multilevel GraphView.
101 
102 	    vector<StochasticNode*> snodes = eps->nodes();
103 	    snodes.push_back(tau->node());
104 	    GraphView *view = new GraphView(snodes, graph, true);
105 	    */
106 
107 	    if (tau) {
108 		unsigned int nchain = glmsampler->_methods.size();
109 		vector<MutableSampleMethod*> methods(nchain);
110 		for (unsigned int i = 0; i < nchain; ++i) {
111 		    methods[i] = newMethod(tau, glmsampler->_methods[i]);
112 		}
113 		used_nodes.insert(tau->node());
114 		return new MutableSampler(tau, methods, _name);
115 	    }
116 	    return 0;
117 	}
118 
119     } // namespace glm
120 } //namespace jags
121 
122