1 #include <config.h>
2 #include <model/NodeArray.h>
3 #include <model/Model.h>
4 #include <graph/ConstantNode.h>
5 #include <graph/StochasticNode.h>
6 #include <graph/AggNode.h>
7 #include <sarray/RangeIterator.h>
8 #include <graph/NodeError.h>
9 #include <sarray/SArray.h>
10 #include <util/nainf.h>
11 #include <util/dim.h>
12
13 #include <string>
14 #include <stdexcept>
15 #include <limits>
16
17 using std::pair;
18 using std::vector;
19 using std::map;
20 using std::string;
21 using std::runtime_error;
22 using std::logic_error;
23 using std::set;
24 using std::numeric_limits;
25
26
hasRepeats(jags::Range const & target_range)27 static bool hasRepeats(jags::Range const &target_range)
28 {
29 /* Returns true if the target range has any repeated indices
30
31 We choose the vectorized version of set::insert as it is
32 amortized linear time in the length of the index vector
33 scope[i] if the indices are in increasing order, which should
34 be true most of the time.
35 */
36
37 vector<vector<int> > const &scope = target_range.scope();
38 for (unsigned int i = 0; i < scope.size(); ++i) {
39 set<int> seen;
40 seen.insert(scope[i].begin(), scope[i].end());
41 if (seen.size() != scope[i].size()) return true;
42 }
43 return false;
44 }
45
46 namespace jags {
47
NodeArray(string const & name,vector<unsigned int> const & dim,unsigned int nchain)48 NodeArray::NodeArray(string const &name, vector<unsigned int> const &dim,
49 unsigned int nchain)
50 : _name(name), _range(dim), _nchain(nchain),
51 _node_pointers(product(dim), 0),
52 _offsets(product(dim), numeric_limits<unsigned int>::max())
53
54 {
55 }
56
insert(Node * node,Range const & target_range)57 void NodeArray::insert(Node *node, Range const &target_range)
58 {
59 if (!node) {
60 throw logic_error(string("Attempt to insert NULL node at ") +
61 name() + print(target_range));
62 }
63 if (node->dim() != target_range.dim(true)) {
64 throw runtime_error(string("Cannot insert node into ") + name() +
65 print(target_range) + ". Dimension mismatch");
66 }
67 if (!_range.contains(target_range)) {
68 throw runtime_error(string("Cannot insert node into ") + name() +
69 print(target_range) + ". Range out of bounds");
70 }
71 if (hasRepeats(target_range)) {
72 throw runtime_error(string("Cannot insert node into ") + name() +
73 print(target_range) +
74 ". Range has repeat indices");
75 }
76
77 /* Check that the range is not already occupied, even partially */
78 for (RangeIterator p(target_range); !p.atEnd(); p.nextLeft()) {
79 if (_node_pointers[_range.leftOffset(p)] != 0) {
80 throw runtime_error(string("Node ") + name()
81 + print(target_range)
82 + " overlaps previously defined nodes");
83 }
84 }
85
86 /* Set the _node_pointers array and the offset array */
87 unsigned int k = 0;
88 for (RangeIterator p(target_range); !p.atEnd(); p.nextLeft())
89 {
90 unsigned int i = _range.leftOffset(p);
91 _node_pointers[i] = node;
92 _offsets[i] = k++;
93 }
94
95 /* Add multivariate nodes to range map */
96 if (node->length() > 1) {
97 _mv_nodes[target_range] = node;
98 }
99
100 /* Add node to the graph */
101 _member_graph.insert(node);
102 }
103
getSubset(Range const & target_range,Model & model)104 Node *NodeArray::getSubset(Range const &target_range, Model &model)
105 {
106 //Check validity of target range
107 if (!_range.contains(target_range)) {
108 throw runtime_error(string("Cannot get subset ") + name() +
109 print(target_range) + ". Range out of bounds");
110 }
111
112 if (target_range.length() == 1) {
113 unsigned int start = _range.leftOffset(target_range.first());
114 Node *node = _node_pointers[start];
115 if (node && node->length() == 1) {
116 if (_offsets[start] != 0) {
117 throw logic_error("Invalid scalar node in NodeArray");
118 }
119 return node;
120 }
121 }
122 else {
123 map<Range, Node *>::const_iterator p = _mv_nodes.find(target_range);
124 if (p != _mv_nodes.end()) {
125 return p->second;
126 }
127 }
128
129 /* If range corresponds to a previously created subset, then
130 * return this */
131 map<Range, AggNode *>::iterator p = _generated_nodes.find(target_range);
132 if (p != _generated_nodes.end()) {
133 return p->second;
134 }
135
136 /* Otherwise create an aggregate node */
137
138 vector<Node const *> nodes;
139 vector<unsigned int> offsets;
140 for (RangeIterator p(target_range); !p.atEnd(); p.nextLeft()) {
141 unsigned int i = _range.leftOffset(p);
142 if (_node_pointers[i] == 0) {
143 return 0;
144 }
145 nodes.push_back(_node_pointers[i]);
146 offsets.push_back(_offsets[i]);
147 }
148 AggNode *anode = new AggNode(target_range.dim(true), _nchain,
149 nodes, offsets);
150 _generated_nodes[target_range] = anode;
151 model.addNode(anode);
152 _member_graph.insert(anode);
153 return anode;
154 }
155
setValue(SArray const & value,unsigned int chain)156 void NodeArray::setValue(SArray const &value, unsigned int chain)
157 {
158 if (!(_range == value.range())) {
159 throw runtime_error(string("Dimension mismatch in ") + name());
160 }
161
162 vector<double> const &x = value.value();
163 unsigned int N = value.length();
164
165 //Gather all the nodes for which a data value is supplied
166 set<Node*> setnodes;
167 for (unsigned int i = 0; i < _range.length(); ++i) {
168 if (x[i] != JAGS_NA) {
169 Node *node = _node_pointers[i];
170 if (node == 0) {
171 string msg = "Attempt to set value of undefined node ";
172 throw runtime_error(msg + name() +
173 print(value.range().leftIndex(i)));
174 }
175 switch(node->randomVariableStatus()) {
176 case RV_FALSE:
177 throw NodeError(node,
178 "Cannot set value of non-variable node");
179 break;
180 case RV_TRUE_OBSERVED:
181 throw NodeError(node,
182 "Cannot overwrite value of observed node");
183 break;
184 case RV_TRUE_UNOBSERVED:
185 setnodes.insert(node);
186 break;
187 }
188 }
189 }
190
191
192 for (set<Node*>::const_iterator p = setnodes.begin();
193 p != setnodes.end(); ++p)
194 {
195 //Step through each node
196 Node *node = *p;
197
198 vector<double> node_value(node->length());
199
200 //Get vector of values for this node
201 for (unsigned int i = 0; i < N; ++i) {
202 if (_node_pointers[i] == node) {
203 if (_offsets[i] > node->length()) {
204 throw logic_error("Invalid offset in NodeArray::setValue");
205 }
206 else {
207 node_value[_offsets[i]] = x[i];
208 }
209 }
210 }
211 // If there are any missing values, they must all be missing
212 bool missing = node_value[0] == JAGS_NA;
213 for (unsigned int j = 1; j < node->length(); ++j) {
214 if ((node_value[j] == JAGS_NA) != missing) {
215 throw NodeError(node,"Values supplied for node are partially missing");
216 }
217 }
218 if (!missing) {
219 node->setValue(&node_value[0], node->length(), chain);
220 }
221 }
222 }
223
getValue(SArray & value,unsigned int chain,bool (* condition)(Node const *)) const224 void NodeArray::getValue(SArray &value, unsigned int chain,
225 bool (*condition)(Node const *)) const
226 {
227 if (!(_range == value.range())) {
228 string msg("Dimension mismatch when getting value of node array ");
229 msg.append(name());
230 throw runtime_error(msg);
231 }
232
233 unsigned int array_length = _range.length();
234 vector<double> array_value(array_length);
235 for (unsigned int j = 0; j < array_length; ++j) {
236 Node const *node = _node_pointers[j];
237 if (node && condition(node)) {
238 array_value[j] = node->value(chain)[_offsets[j]];
239 }
240 else {
241 array_value[j] = JAGS_NA;
242 }
243 }
244
245 value.setValue(array_value);
246 }
247
setData(SArray const & value,Model * model)248 void NodeArray::setData(SArray const &value, Model *model)
249 {
250 if (!(_range == value.range())) {
251 throw runtime_error(string("Dimension mismatch when setting value of node array ") + name());
252 }
253
254 vector<double> const &x = value.value();
255
256 //Gather all the nodes for which a data value is supplied
257 for (unsigned int i = 0; i < _range.length(); ++i) {
258 if (x[i] != JAGS_NA) {
259 if (_node_pointers[i] == 0) {
260 //Insert a new constant data node
261 ConstantNode *cnode = new ConstantNode(x[i], _nchain, true);
262 model->addNode(cnode);
263 SimpleRange target_range(_range.leftIndex(i));
264 insert(cnode, target_range);
265 }
266 else {
267 throw logic_error("Error in NodeArray::setData");
268 }
269 }
270 }
271 }
272
273
name() const274 string const &NodeArray::name() const
275 {
276 return _name;
277 }
278
range() const279 SimpleRange const &NodeArray::range() const
280 {
281 return _range;
282 }
283
getRange(Node const * node) const284 Range NodeArray::getRange(Node const *node) const
285 {
286 if (!_member_graph.contains(node)) {
287 return Range();
288 }
289
290 //Look among inserted nodes first
291 if (node->length() == 1) {
292 for (unsigned int i = 0; i < _range.length(); ++i) {
293 if (_node_pointers[i] == node) {
294 return SimpleRange(_range.leftIndex(i));
295 }
296 }
297 }
298 else {
299 for (map<Range, Node *>::const_iterator p = _mv_nodes.begin();
300 p != _mv_nodes.end(); ++p)
301 {
302 if (node == p->second) {
303 return p->first;
304 }
305 }
306 }
307
308 //Then among generated nodes
309 for (map<Range, AggNode *>::const_iterator p = _generated_nodes.begin();
310 p != _generated_nodes.end(); ++p)
311 {
312 if (node == p->second) {
313 return p->first;
314 }
315 }
316
317 throw logic_error("Failed to find Node range");
318 return Range(); //Wall
319 }
320
nchain() const321 unsigned int NodeArray::nchain() const
322 {
323 return _nchain;
324 }
325
326 } //namespace jags
327