1 #include <config.h>
2 #include "LDAFactory.h"
3 #include "LDA.h"
4 
5 #include <sampler/MutableSampler.h>
6 #include <sampler/MutableSampleMethod.h>
7 #include <sampler/SingletonGraphView.h>
8 #include <graph/Graph.h>
9 #include <graph/StochasticNode.h>
10 #include <graph/MixtureNode.h>
11 #include <graph/MixTab.h>
12 #include <distribution/Distribution.h>
13 
14 #include <graph/VectorStochasticNode.h>
15 
16 #include <set>
17 #include <map>
18 #include <algorithm>
19 
20 using std::set;
21 using std::vector;
22 using std::string;
23 using std::map;
24 using std::list;
25 using std::find;
26 
27 namespace jags {
28 
29     /* Struct to hold Dirichlet nodes that will be marginalized out
30        by the LDA sampler */
31     struct DirichletPriors {
32 	vector<StochasticNode*> words;
33 	vector<StochasticNode*> topics;
34     };
35 
36     typedef map<MixTab const *, DirichletPriors> LDAMap;
37 
isCat(StochasticNode const * snode)38     static inline bool isCat(StochasticNode const *snode) {
39 	return snode->distribution()->name() == "dcat";
40     }
41 
checkTopicPrior(GraphView const & gv,Graph const & graph)42     MixTab const *checkTopicPrior(GraphView const &gv, Graph const &graph)
43     {
44 	/*
45 	  Dirichlet node TopicPrior has categorical stochastic
46 	  children.  There are no intermediate deterministic children.
47 	*/
48 	if (!gv.deterministicChildren().empty()) return 0;
49 	vector<StochasticNode *> const &schild = gv.stochasticChildren();
50 	for (unsigned int i = 0; i < schild.size(); ++i) {
51 	    if (!isCat(schild[i])) return 0;
52 	}
53 
54 	/*
55 	  Each stochastic child acts as the index of a single mixture
56 	  node.  This mixture node has a single stochastic child with
57 	  a categorical distribution.
58 
59 	  All the mixture nodes have a common MixTab.
60 	*/
61 
62 	MixTab const *mtab = 0;
63 	for (unsigned int i = 0; i < schild.size(); ++i) {
64 
65 	    SingletonGraphView gvi(schild[i], graph);
66 
67 	    vector<StochasticNode *> const &si = gvi.stochasticChildren();
68 	    if (si.size() != 1) return 0;
69 	    if (!isCat(si[0])) return 0;
70 
71 	    vector<DeterministicNode *> const &di = gvi.deterministicChildren();
72 	    if (di.size() != 1) return 0;
73 	    MixtureNode const *m = asMixture(di[0]);
74 	    if (m == 0) return 0;
75 
76 	    //Check that schild[i] is the index of the mixture node
77 	    if (m->index_size() != 1) return 0;
78 	    if (m->parents()[0] != schild[i]) return 0;
79 	    for (unsigned int j = 1; j < m->parents().size(); ++j) {
80 		if (m->parents()[j] == schild[i]) return 0;
81 	    }
82 
83 	    if (i == 0) {
84 		mtab = m->mixTab();
85 	    }
86 	    else {
87 		if (m->mixTab() != mtab) return 0;
88 	    }
89 	}
90 	return mtab;
91     }
92 
checkWordPrior(GraphView const & gv,Graph const & graph)93     MixTab const *checkWordPrior(GraphView const &gv, Graph const &graph)
94     {
95 	/*
96 	   Dirichlet node WordPrior is related to multiple categorical
97 	   outcomes via a set of mixture nodes that all share the same
98 	   mixTable.
99 	*/
100 	vector<StochasticNode *> const &schild = gv.stochasticChildren();
101 	for (unsigned int i = 0; i < schild.size(); ++i) {
102 	    if (!isCat(schild[i])) return 0;
103 	}
104 
105 	MixTab const *mtab = 0;
106 	vector<DeterministicNode *> const &dchild = gv.deterministicChildren();
107 	for (unsigned int j = 0; j < dchild.size(); ++j) {
108 
109 	    MixtureNode const *m = asMixture(dchild[j]);
110 	    if (m == 0) return 0;
111 
112 	    if (j == 0) {
113 		mtab = m->mixTab();
114 	    }
115 	    else if (mtab != m->mixTab()) {
116 		return 0;
117 	    }
118 	}
119 
120 	return mtab;
121     }
122 
123     namespace mix {
124 
125 	Sampler *
makeSampler(vector<StochasticNode * > const & topicPriors,vector<StochasticNode * > const & wordPriors,list<StochasticNode * > const & free_nodes,Graph const & graph) const126 	LDAFactory::makeSampler(vector<StochasticNode*> const &topicPriors,
127 				vector<StochasticNode*> const &wordPriors,
128 				list<StochasticNode*> const &free_nodes,
129 				Graph const &graph) const
130 	{
131 	    if (topicPriors.empty() || wordPriors.empty()) return 0;
132 
133 	    unsigned int nDoc = topicPriors.size();
134 	    vector<vector<StochasticNode*> > topics(nDoc), words(nDoc);
135 	    vector<StochasticNode*> snodes;
136 	    for (unsigned int d = 0; d < nDoc; ++d) {
137 		SingletonGraphView gvd(topicPriors[d], graph);
138 		topics[d] = gvd.stochasticChildren();
139 		for (unsigned int i = 0; i < topics[d].size(); ++i) {
140 		    if (find(free_nodes.begin(), free_nodes.end(), topics[d][i])
141 			== free_nodes.end())
142 		    {
143 			return 0;
144 		    }
145 		    SingletonGraphView gvi(topics[d][i], graph);
146 		    words[d].push_back(gvi.stochasticChildren()[0]);
147 		    snodes.push_back(topics[d][i]);
148 		}
149 	    }
150 
151 	    if (LDA::canSample(topics, words, topicPriors, wordPriors, graph)) {
152 
153 		GraphView *view = new GraphView(snodes, graph);
154 		unsigned int N = nchain(view);
155 		vector<MutableSampleMethod*> methods(N);
156 		for (unsigned int ch = 0; ch < N; ++ch) {
157 		    methods[ch] = new LDA(topics, words, topicPriors,
158 					  wordPriors, view, ch);
159 		}
160 		return new MutableSampler(view, methods, "mix::LDA");
161 	    }
162 	    else return 0;
163 	}
164 
name() const165 	string LDAFactory::name() const
166 	{
167 	    return "mix::LDA";
168 	}
169 
170 	vector<Sampler*>
makeSamplers(list<StochasticNode * > const & free_nodes,Graph const & graph) const171 	LDAFactory::makeSamplers(list<StochasticNode*> const &free_nodes,
172 				 Graph const &graph) const
173 	{
174 	    //First we need to traverse the graph looking for
175 	    //Dirichlet nodes.  We are not interested in sampling
176 	    //them, but they are the basis for finding the categorical
177 	    //nodes that we do want to sample
178 
179 	    set<StochasticNode*> dirichlet_nodes;
180 
181 	    for (Graph::const_iterator p = graph.begin(); p != graph.end(); ++p)
182 	    {
183 		VectorStochasticNode *vsnode =
184 		    dynamic_cast<VectorStochasticNode *>(*p);
185 		if (vsnode && vsnode->distribution()->name() == "ddirch") {
186 	    dirichlet_nodes.insert(vsnode);
187 		}
188 	    }
189 
190 	    // Now classify them according to their MixTab
191 
192 	    LDAMap dirichlet_map;
193 	    for (set<StochasticNode*>::iterator p = dirichlet_nodes.begin();
194 		 p != dirichlet_nodes.end(); ++p)
195 	    {
196 		SingletonGraphView gv(*p, graph);
197 		if (MixTab const *mtab = checkWordPrior(gv, graph)) {
198 		    dirichlet_map[mtab].words.push_back(*p);
199 		}
200 		else if (MixTab const *mtab = checkTopicPrior(gv, graph)) {
201 		    dirichlet_map[mtab].topics.push_back(*p);
202 		}
203 	    }
204 
205 	    // Traverse the elements of the mixtable
206 	    vector<Sampler*> samplers;
207 	    for (LDAMap::const_iterator p = dirichlet_map.begin();
208 		 p != dirichlet_map.end(); ++p)
209 	    {
210 		Sampler *s = makeSampler(p->second.topics, p->second.words,
211 					 free_nodes, graph);
212 
213 		if (s) samplers.push_back(s);
214 	    }
215 	    return samplers;
216 	}
217     }
218 }
219