1 #include <config.h>
2 #include "DirichletCatFactory.h"
3 #include "DirichletCat.h"
4 
5 #include <sampler/MutableSampler.h>
6 #include <sampler/MutableSampleMethod.h>
7 #include <sampler/SingletonGraphView.h>
8 #include <graph/Graph.h>
9 #include <graph/StochasticNode.h>
10 #include <graph/MixtureNode.h>
11 #include <distribution/Distribution.h>
12 
13 #include <set>
14 #include <map>
15 
16 using std::list;
17 using std::set;
18 using std::vector;
19 using std::string;
20 using std::map;
21 
// NOTE(review): NLEVEL, MAX_TEMP and NREP are not referenced anywhere in
// this file. They look like leftovers from another sampler factory;
// consider removing them once it is confirmed nothing else relies on them.
#define NLEVEL 200
#define MAX_TEMP 100
#define NREP 5
25 
26 namespace jags {
27 
28     //STL syntax becomes unreadable if we don't use a typedef
29     typedef map<vector<StochasticNode *>, vector<StochasticNode*> > DCMap;
30 
31     static
isCandidate(SingletonGraphView const & gv)32     bool isCandidate(SingletonGraphView const &gv)
33     {
34 	vector<StochasticNode *> const &schild = gv.stochasticChildren();
35 	vector<DeterministicNode *> const &dchild = gv.deterministicChildren();
36 
37 	//A necessary but not sufficient condition: we do this first
38 	//because it is fast
39 	if (schild.size() != dchild.size()) return false;
40 
41 	//Stochastic children must all have dcat distribution
42 	Distribution const *dist0 = schild[0]->distribution();
43 	if (dist0->name() != "dcat") return false;
44 	for (unsigned int i = 1; i < schild.size(); ++i) {
45 	    if (schild[i]->distribution() != dist0) return false;
46 	}
47 
48 	 //Deterministic descendants must all be mixture nodes
49 	 for (unsigned int j = 0; j < dchild.size(); ++j) {
50 	     if (!isMixture(dchild[j])) return false;
51 	 }
52 
53 	 return true;
54      }
55 
56      namespace mix {
57 
58 	 Sampler *
makeSampler(vector<StochasticNode * > const & snodes,Graph const & graph) const59 	 DirichletCatFactory::makeSampler(vector<StochasticNode*> const &snodes,
60 					  Graph const &graph) const
61 	 {
62 	     GraphView *gv = new GraphView(snodes, graph);
63 	     Sampler * sampler = 0;
64 	     unsigned int nchain = snodes[0]->nchain();
65 
66 	     if (DirichletCat::canSample(gv)) {
67 		 vector<MutableSampleMethod*> methods(nchain);
68 		 for (unsigned int ch = 0; ch < nchain; ++ch) {
69 		     methods[ch] = new DirichletCat(gv, ch);
70 		 }
71 		 sampler = new MutableSampler(gv, methods, "mix::DirichletCat");
72 	     }
73 	     else {
74 		 delete gv;
75 	     }
76 	     return sampler;
77 	 }
78 
name() const79 	 string DirichletCatFactory::name() const
80 	 {
81 	     return "mix::DirichletCat";
82 	 }
83 
84 
85 	vector<Sampler*>
makeSamplers(list<StochasticNode * > const & nodes,Graph const & graph) const86 	DirichletCatFactory::makeSamplers(list<StochasticNode*> const &nodes,
87 					  Graph const &graph) const
88 	{
89 	    //Assemble candidates from available nodes and classify
90 	    //them by their stochastic children
91 	    DCMap cmap;
92 
93 	    for (list<StochasticNode*>::const_iterator p = nodes.begin();
94 		 p != nodes.end(); ++p)
95 	    {
96 		if ((*p)->distribution()->name() != "ddirch") continue;
97 		SingletonGraphView gv(*p, graph);
98 		if (isCandidate(gv)) {
99 		    cmap[gv.stochasticChildren()].push_back(*p);
100 		}
101 	    }
102 
103 	    //Now traverse the candidate map and generate samplers
104 	    vector<Sampler*> samplers;
105 	    for (DCMap::const_iterator q = cmap.begin(); q != cmap.end(); ++q)
106 	    {
107 		Sampler *s = makeSampler(q->second, graph);
108 		if (s) samplers.push_back(s);
109 	    }
110 
111 	    return samplers;
112 	}
113 
114     }
115 }
116