1 #include <config.h> 2 #include "LDAFactory.h" 3 #include "LDA.h" 4 5 #include <sampler/MutableSampler.h> 6 #include <sampler/MutableSampleMethod.h> 7 #include <sampler/SingletonGraphView.h> 8 #include <graph/Graph.h> 9 #include <graph/StochasticNode.h> 10 #include <graph/MixtureNode.h> 11 #include <graph/MixTab.h> 12 #include <distribution/Distribution.h> 13 14 #include <graph/VectorStochasticNode.h> 15 16 #include <set> 17 #include <map> 18 #include <algorithm> 19 20 using std::set; 21 using std::vector; 22 using std::string; 23 using std::map; 24 using std::list; 25 using std::find; 26 27 namespace jags { 28 29 /* Struct to hold Dirichlet nodes that will be marginalized out 30 by the LDA sampler */ 31 struct DirichletPriors { 32 vector<StochasticNode*> words; 33 vector<StochasticNode*> topics; 34 }; 35 36 typedef map<MixTab const *, DirichletPriors> LDAMap; 37 isCat(StochasticNode const * snode)38 static inline bool isCat(StochasticNode const *snode) { 39 return snode->distribution()->name() == "dcat"; 40 } 41 checkTopicPrior(GraphView const & gv,Graph const & graph)42 MixTab const *checkTopicPrior(GraphView const &gv, Graph const &graph) 43 { 44 /* 45 Dirichlet node TopicPrior has categorical stochastic 46 children. There are no intermediate deterministic children. 47 */ 48 if (!gv.deterministicChildren().empty()) return 0; 49 vector<StochasticNode *> const &schild = gv.stochasticChildren(); 50 for (unsigned int i = 0; i < schild.size(); ++i) { 51 if (!isCat(schild[i])) return 0; 52 } 53 54 /* 55 Each stochastic child acts as the index of a single mixture 56 node. This mixture node has a single stochastic child with 57 a categorical distribution. 58 59 All the mixture nodes have a common MixTab. 60 */ 61 62 MixTab const *mtab = 0; 63 for (unsigned int i = 0; i < schild.size(); ++i) { 64 65 SingletonGraphView gvi(schild[i], graph); 66 67 vector<StochasticNode *> const &si = gvi.stochasticChildren(); 68 if (si.size() != 1) return 0; 69 if (!isCat(si[0])) return 0; 70 71 vector<DeterministicNode *> const &di = gvi.deterministicChildren(); 72 if (di.size() != 1) return 0; 73 MixtureNode const *m = asMixture(di[0]); 74 if (m == 0) return 0; 75 76 //Check that schild[i] is the index of the mixture node 77 if (m->index_size() != 1) return 0; 78 if (m->parents()[0] != schild[i]) return 0; 79 for (unsigned int j = 1; j < m->parents().size(); ++j) { 80 if (m->parents()[j] == schild[i]) return 0; 81 } 82 83 if (i == 0) { 84 mtab = m->mixTab(); 85 } 86 else { 87 if (m->mixTab() != mtab) return 0; 88 } 89 } 90 return mtab; 91 } 92 checkWordPrior(GraphView const & gv,Graph const & graph)93 MixTab const *checkWordPrior(GraphView const &gv, Graph const &graph) 94 { 95 /* 96 Dirichlet node WordPrior is related to multiple categorical 97 outcomes via a set of mixture nodes that all share the same 98 mixTable. 99 */ 100 vector<StochasticNode *> const &schild = gv.stochasticChildren(); 101 for (unsigned int i = 0; i < schild.size(); ++i) { 102 if (!isCat(schild[i])) return 0; 103 } 104 105 MixTab const *mtab = 0; 106 vector<DeterministicNode *> const &dchild = gv.deterministicChildren(); 107 for (unsigned int j = 0; j < dchild.size(); ++j) { 108 109 MixtureNode const *m = asMixture(dchild[j]); 110 if (m == 0) return 0; 111 112 if (j == 0) { 113 mtab = m->mixTab(); 114 } 115 else if (mtab != m->mixTab()) { 116 return 0; 117 } 118 } 119 120 return mtab; 121 } 122 123 namespace mix { 124 125 Sampler * makeSampler(vector<StochasticNode * > const & topicPriors,vector<StochasticNode * > const & wordPriors,list<StochasticNode * > const & free_nodes,Graph const & graph) const126 LDAFactory::makeSampler(vector<StochasticNode*> const &topicPriors, 127 vector<StochasticNode*> const &wordPriors, 128 list<StochasticNode*> const &free_nodes, 129 Graph const &graph) const 130 { 131 if (topicPriors.empty() || wordPriors.empty()) return 0; 132 133 unsigned int nDoc = topicPriors.size(); 134 vector<vector<StochasticNode*> > topics(nDoc), words(nDoc); 135 vector<StochasticNode*> snodes; 136 for (unsigned int d = 0; d < nDoc; ++d) { 137 SingletonGraphView gvd(topicPriors[d], graph); 138 topics[d] = gvd.stochasticChildren(); 139 for (unsigned int i = 0; i < topics[d].size(); ++i) { 140 if (find(free_nodes.begin(), free_nodes.end(), topics[d][i]) 141 == free_nodes.end()) 142 { 143 return 0; 144 } 145 SingletonGraphView gvi(topics[d][i], graph); 146 words[d].push_back(gvi.stochasticChildren()[0]); 147 snodes.push_back(topics[d][i]); 148 } 149 } 150 151 if (LDA::canSample(topics, words, topicPriors, wordPriors, graph)) { 152 153 GraphView *view = new GraphView(snodes, graph); 154 unsigned int N = nchain(view); 155 vector<MutableSampleMethod*> methods(N); 156 for (unsigned int ch = 0; ch < N; ++ch) { 157 methods[ch] = new LDA(topics, words, topicPriors, 158 wordPriors, view, ch); 159 } 160 return new MutableSampler(view, methods, "mix::LDA"); 161 } 162 else return 0; 163 } 164 name() const165 string LDAFactory::name() const 166 { 167 return "mix::LDA"; 168 } 169 170 vector<Sampler*> makeSamplers(list<StochasticNode * > const & free_nodes,Graph const & graph) const171 LDAFactory::makeSamplers(list<StochasticNode*> const &free_nodes, 172 Graph const &graph) const 173 { 174 //First we need to traverse the graph looking for 175 //Dirichlet nodes. We are not interested in sampling 176 //them, but they are the basis for finding the categorical 177 //nodes that we do want to sample 178 179 set<StochasticNode*> dirichlet_nodes; 180 181 for (Graph::const_iterator p = graph.begin(); p != graph.end(); ++p) 182 { 183 VectorStochasticNode *vsnode = 184 dynamic_cast<VectorStochasticNode *>(*p); 185 if (vsnode && vsnode->distribution()->name() == "ddirch") { 186 dirichlet_nodes.insert(vsnode); 187 } 188 } 189 190 // Now classify them according to their MixTab 191 192 LDAMap dirichlet_map; 193 for (set<StochasticNode*>::iterator p = dirichlet_nodes.begin(); 194 p != dirichlet_nodes.end(); ++p) 195 { 196 SingletonGraphView gv(*p, graph); 197 if (MixTab const *mtab = checkWordPrior(gv, graph)) { 198 dirichlet_map[mtab].words.push_back(*p); 199 } 200 else if (MixTab const *mtab = checkTopicPrior(gv, graph)) { 201 dirichlet_map[mtab].topics.push_back(*p); 202 } 203 } 204 205 // Traverse the elements of the mixtable 206 vector<Sampler*> samplers; 207 for (LDAMap::const_iterator p = dirichlet_map.begin(); 208 p != dirichlet_map.end(); ++p) 209 { 210 Sampler *s = makeSampler(p->second.topics, p->second.words, 211 free_nodes, graph); 212 213 if (s) samplers.push_back(s); 214 } 215 return samplers; 216 } 217 } 218 } 219