1 #include <config.h> 2 #include "DirichletCatFactory.h" 3 #include "DirichletCat.h" 4 5 #include <sampler/MutableSampler.h> 6 #include <sampler/MutableSampleMethod.h> 7 #include <sampler/SingletonGraphView.h> 8 #include <graph/Graph.h> 9 #include <graph/StochasticNode.h> 10 #include <graph/MixtureNode.h> 11 #include <distribution/Distribution.h> 12 13 #include <set> 14 #include <map> 15 16 using std::list; 17 using std::set; 18 using std::vector; 19 using std::string; 20 using std::map; 21 22 #define NLEVEL 200 23 #define MAX_TEMP 100 24 #define NREP 5 25 26 namespace jags { 27 28 //STL syntax becomes unreadable if we don't use a typedef 29 typedef map<vector<StochasticNode *>, vector<StochasticNode*> > DCMap; 30 31 static isCandidate(SingletonGraphView const & gv)32 bool isCandidate(SingletonGraphView const &gv) 33 { 34 vector<StochasticNode *> const &schild = gv.stochasticChildren(); 35 vector<DeterministicNode *> const &dchild = gv.deterministicChildren(); 36 37 //A necessary but not sufficient condition: we do this first 38 //because it is fast 39 if (schild.size() != dchild.size()) return false; 40 41 //Stochastic children must all have dcat distribution 42 Distribution const *dist0 = schild[0]->distribution(); 43 if (dist0->name() != "dcat") return false; 44 for (unsigned int i = 1; i < schild.size(); ++i) { 45 if (schild[i]->distribution() != dist0) return false; 46 } 47 48 //Deterministic descendants must all be mixture nodes 49 for (unsigned int j = 0; j < dchild.size(); ++j) { 50 if (!isMixture(dchild[j])) return false; 51 } 52 53 return true; 54 } 55 56 namespace mix { 57 58 Sampler * makeSampler(vector<StochasticNode * > const & snodes,Graph const & graph) const59 DirichletCatFactory::makeSampler(vector<StochasticNode*> const &snodes, 60 Graph const &graph) const 61 { 62 GraphView *gv = new GraphView(snodes, graph); 63 Sampler * sampler = 0; 64 unsigned int nchain = snodes[0]->nchain(); 65 66 if (DirichletCat::canSample(gv)) { 67 vector<MutableSampleMethod*> methods(nchain); 68 for (unsigned int ch = 0; ch < nchain; ++ch) { 69 methods[ch] = new DirichletCat(gv, ch); 70 } 71 sampler = new MutableSampler(gv, methods, "mix::DirichletCat"); 72 } 73 else { 74 delete gv; 75 } 76 return sampler; 77 } 78 name() const79 string DirichletCatFactory::name() const 80 { 81 return "mix::DirichletCat"; 82 } 83 84 85 vector<Sampler*> makeSamplers(list<StochasticNode * > const & nodes,Graph const & graph) const86 DirichletCatFactory::makeSamplers(list<StochasticNode*> const &nodes, 87 Graph const &graph) const 88 { 89 //Assemble candidates from available nodes and classify 90 //them by their stochastic children 91 DCMap cmap; 92 93 for (list<StochasticNode*>::const_iterator p = nodes.begin(); 94 p != nodes.end(); ++p) 95 { 96 if ((*p)->distribution()->name() != "ddirch") continue; 97 SingletonGraphView gv(*p, graph); 98 if (isCandidate(gv)) { 99 cmap[gv.stochasticChildren()].push_back(*p); 100 } 101 } 102 103 //Now traverse the candidate map and generate samplers 104 vector<Sampler*> samplers; 105 for (DCMap::const_iterator q = cmap.begin(); q != cmap.end(); ++q) 106 { 107 Sampler *s = makeSampler(q->second, graph); 108 if (s) samplers.push_back(s); 109 } 110 111 return samplers; 112 } 113 114 } 115 } 116