1 #include <config.h>
2 #include <model/NodeArray.h>
3 #include <model/Model.h>
4 #include <graph/ConstantNode.h>
5 #include <graph/StochasticNode.h>
6 #include <graph/AggNode.h>
7 #include <sarray/RangeIterator.h>
8 #include <graph/NodeError.h>
9 #include <sarray/SArray.h>
10 #include <util/nainf.h>
11 #include <util/dim.h>
12 
13 #include <string>
14 #include <stdexcept>
15 #include <limits>
16 
17 using std::pair;
18 using std::vector;
19 using std::map;
20 using std::string;
21 using std::runtime_error;
22 using std::logic_error;
23 using std::set;
24 using std::numeric_limits;
25 
26 
hasRepeats(jags::Range const & target_range)27 static bool hasRepeats(jags::Range const &target_range)
28 {
29     /* Returns true if the target range has any repeated indices
30 
31        We choose the vectorized version of set::insert as it is
32        amortized linear time in the length of the index vector
33        scope[i] if the indices are in increasing order, which should
34        be true most of the time.
35     */
36 
37     vector<vector<int> > const &scope = target_range.scope();
38     for (unsigned int i = 0; i < scope.size(); ++i) {
39 	set<int> seen;
40 	seen.insert(scope[i].begin(), scope[i].end());
41 	if (seen.size() != scope[i].size()) return true;
42     }
43     return false;
44 }
45 
46 namespace jags {
47 
NodeArray(string const & name,vector<unsigned int> const & dim,unsigned int nchain)48     NodeArray::NodeArray(string const &name, vector<unsigned int> const &dim,
49 			 unsigned int nchain)
50 	: _name(name), _range(dim), _nchain(nchain),
51 	  _node_pointers(product(dim), 0),
52 	  _offsets(product(dim), numeric_limits<unsigned int>::max())
53 
54     {
55     }
56 
insert(Node * node,Range const & target_range)57     void NodeArray::insert(Node *node, Range const &target_range)
58     {
59 	if (!node) {
60 	    throw logic_error(string("Attempt to insert NULL node at ") +
61 			      name() + print(target_range));
62 	}
63 	if (node->dim() != target_range.dim(true)) {
64 	    throw runtime_error(string("Cannot insert node into ") + name() +
65 				print(target_range) + ". Dimension mismatch");
66 	}
67 	if (!_range.contains(target_range)) {
68 	    throw runtime_error(string("Cannot insert node into ") + name() +
69 				print(target_range) + ". Range out of bounds");
70 	}
71 	if (hasRepeats(target_range)) {
72 	    throw runtime_error(string("Cannot insert node into ") + name() +
73 				print(target_range) +
74 				". Range has repeat indices");
75 	}
76 
77 	/* Check that the range is not already occupied, even partially */
78 	for (RangeIterator p(target_range); !p.atEnd(); p.nextLeft()) {
79 	    if (_node_pointers[_range.leftOffset(p)] != 0) {
80 		throw runtime_error(string("Node ") + name()
81 				    + print(target_range)
82 				    + " overlaps previously defined nodes");
83 	    }
84 	}
85 
86 	/* Set the _node_pointers array and the offset array */
87 	unsigned int k = 0;
88 	for (RangeIterator p(target_range); !p.atEnd(); p.nextLeft())
89 	{
90 	    unsigned int i = _range.leftOffset(p);
91 	    _node_pointers[i] = node;
92 	    _offsets[i] = k++;
93 	}
94 
95 	/* Add multivariate nodes to range map */
96 	if (node->length() > 1) {
97 	    _mv_nodes[target_range] = node;
98 	}
99 
100 	/* Add node to the graph */
101 	_member_graph.insert(node);
102     }
103 
getSubset(Range const & target_range,Model & model)104     Node *NodeArray::getSubset(Range const &target_range, Model &model)
105     {
106 	//Check validity of target range
107 	if (!_range.contains(target_range)) {
108 	    throw runtime_error(string("Cannot get subset ") + name() +
109 				print(target_range) + ". Range out of bounds");
110 	}
111 
112 	if (target_range.length() == 1) {
113 	    unsigned int start = _range.leftOffset(target_range.first());
114 	    Node *node = _node_pointers[start];
115 	    if (node && node->length() == 1) {
116 		if (_offsets[start] != 0) {
117 		    throw logic_error("Invalid scalar node in NodeArray");
118 		}
119 		return node;
120 	    }
121 	}
122 	else {
123 	    map<Range, Node *>::const_iterator p = _mv_nodes.find(target_range);
124 	    if (p != _mv_nodes.end()) {
125 		return p->second;
126 	    }
127 	}
128 
129 	/* If range corresponds to a previously created subset, then
130 	 * return this */
131 	map<Range, AggNode *>::iterator p = _generated_nodes.find(target_range);
132 	if (p != _generated_nodes.end()) {
133 	    return p->second;
134 	}
135 
136 	/* Otherwise create an aggregate node */
137 
138 	vector<Node const *> nodes;
139 	vector<unsigned int> offsets;
140 	for (RangeIterator p(target_range); !p.atEnd(); p.nextLeft()) {
141 	    unsigned int i = _range.leftOffset(p);
142 	    if (_node_pointers[i] == 0) {
143 		return 0;
144 	    }
145 	    nodes.push_back(_node_pointers[i]);
146 	    offsets.push_back(_offsets[i]);
147 	}
148 	AggNode *anode = new AggNode(target_range.dim(true), _nchain,
149 				     nodes, offsets);
150 	_generated_nodes[target_range] = anode;
151 	model.addNode(anode);
152 	_member_graph.insert(anode);
153 	return anode;
154     }
155 
setValue(SArray const & value,unsigned int chain)156     void NodeArray::setValue(SArray const &value, unsigned int chain)
157     {
158 	if (!(_range == value.range())) {
159 	    throw runtime_error(string("Dimension mismatch in ") + name());
160 	}
161 
162 	vector<double> const &x = value.value();
163 	unsigned int N = value.length();
164 
165 	//Gather all the nodes for which a data value is supplied
166 	set<Node*> setnodes;
167 	for (unsigned int i = 0; i < _range.length(); ++i) {
168 	    if (x[i] != JAGS_NA) {
169 		Node *node = _node_pointers[i];
170 		if (node == 0) {
171 		    string msg = "Attempt to set value of undefined node ";
172 		    throw runtime_error(msg + name() +
173 					print(value.range().leftIndex(i)));
174 		}
175 		switch(node->randomVariableStatus()) {
176 		case RV_FALSE:
177 		    throw NodeError(node,
178 				    "Cannot set value of non-variable node");
179 		    break;
180 		case RV_TRUE_OBSERVED:
181 		    throw NodeError(node,
182 				    "Cannot overwrite value of observed node");
183 		    break;
184 		case RV_TRUE_UNOBSERVED:
185 		    setnodes.insert(node);
186 		    break;
187 		}
188 	    }
189 	}
190 
191 
192 	for (set<Node*>::const_iterator p = setnodes.begin();
193 	     p != setnodes.end(); ++p)
194 	{
195 	    //Step through each node
196 	    Node *node = *p;
197 
198 	    vector<double> node_value(node->length());
199 
200 	    //Get vector of values for this node
201 	    for (unsigned int i = 0; i < N; ++i) {
202 		if (_node_pointers[i] == node) {
203 		    if (_offsets[i] > node->length()) {
204 			throw logic_error("Invalid offset in NodeArray::setValue");
205 		    }
206 		    else {
207 			node_value[_offsets[i]] = x[i];
208 		    }
209 		}
210 	    }
211 	    // If there are any missing values, they must all be missing
212 	    bool missing = node_value[0] == JAGS_NA;
213 	    for (unsigned int j = 1; j < node->length(); ++j) {
214 		if ((node_value[j] == JAGS_NA) != missing) {
215 		    throw NodeError(node,"Values supplied for node are partially missing");
216 		}
217 	    }
218 	    if (!missing) {
219 		node->setValue(&node_value[0], node->length(), chain);
220 	    }
221 	}
222     }
223 
getValue(SArray & value,unsigned int chain,bool (* condition)(Node const *)) const224 void NodeArray::getValue(SArray &value, unsigned int chain,
225 			 bool (*condition)(Node const *)) const
226 {
227     if (!(_range == value.range())) {
228 	string msg("Dimension mismatch when getting value of node array ");
229 	msg.append(name());
230 	throw runtime_error(msg);
231     }
232 
233     unsigned int array_length = _range.length();
234     vector<double> array_value(array_length);
235     for (unsigned int j = 0; j < array_length; ++j) {
236 	Node const *node = _node_pointers[j];
237 	if (node && condition(node)) {
238 	    array_value[j] = node->value(chain)[_offsets[j]];
239 	}
240 	else {
241 	    array_value[j] = JAGS_NA;
242 	}
243     }
244 
245     value.setValue(array_value);
246 }
247 
setData(SArray const & value,Model * model)248 void NodeArray::setData(SArray const &value, Model *model)
249 {
250     if (!(_range == value.range())) {
251 	throw runtime_error(string("Dimension mismatch when setting value of node array ") + name());
252     }
253 
254     vector<double> const &x = value.value();
255 
256     //Gather all the nodes for which a data value is supplied
257     for (unsigned int i = 0; i < _range.length(); ++i) {
258 	if (x[i] != JAGS_NA) {
259 	    if (_node_pointers[i] == 0) {
260 		//Insert a new constant data node
261 		ConstantNode *cnode = new ConstantNode(x[i], _nchain, true);
262 		model->addNode(cnode);
263 		SimpleRange target_range(_range.leftIndex(i));
264 		insert(cnode, target_range);
265 	    }
266 	    else {
267 		throw logic_error("Error in NodeArray::setData");
268 	    }
269 	}
270     }
271 }
272 
273 
name() const274     string const &NodeArray::name() const
275     {
276 	return _name;
277     }
278 
range() const279     SimpleRange const &NodeArray::range() const
280     {
281 	return _range;
282     }
283 
getRange(Node const * node) const284     Range NodeArray::getRange(Node const *node) const
285     {
286 	if (!_member_graph.contains(node)) {
287 	    return Range();
288 	}
289 
290 	//Look among inserted nodes first
291 	if (node->length() == 1) {
292 	    for (unsigned int i = 0; i < _range.length(); ++i) {
293 		if (_node_pointers[i] == node) {
294 		    return SimpleRange(_range.leftIndex(i));
295 		}
296 	    }
297 	}
298 	else {
299 	    for (map<Range, Node *>::const_iterator p = _mv_nodes.begin();
300 		 p != _mv_nodes.end(); ++p)
301 	    {
302 		if (node == p->second) {
303 		    return p->first;
304 		}
305 	    }
306 	}
307 
308 	//Then among generated nodes
309 	for (map<Range, AggNode *>::const_iterator p = _generated_nodes.begin();
310 	     p != _generated_nodes.end(); ++p)
311 	{
312 	    if (node == p->second) {
313 		return p->first;
314 	    }
315 	}
316 
317 	throw logic_error("Failed to find Node range");
318 	return Range(); //Wall
319     }
320 
nchain() const321     unsigned int NodeArray::nchain() const
322     {
323 	return _nchain;
324     }
325 
326 } //namespace jags
327