1 #include <config.h>
2 #include <graph/Node.h>
3 #include <graph/NodeError.h>
4 #include <util/nainf.h>
5 #include <util/dim.h>
6
7 #include <stdexcept>
8 #include <algorithm>
9
10 using std::string;
11 using std::vector;
12 using std::logic_error;
13 using std::copy;
14 using std::find;
15 using std::list;
16
17 namespace jags {
18
// Forward declarations of the two node types stored in the child lists
// used throughout this file.
class DeterministicNode;
class StochasticNode;
21
Node(vector<unsigned int> const & dim,unsigned int nchain)22 Node::Node(vector<unsigned int> const &dim, unsigned int nchain)
23 : _parents(0), _stoch_children(0), _dtrm_children(0),
24 _dim(getUnique(dim)), _length(product(dim)), _nchain(nchain), _data(0)
25 {
26 if (nchain==0)
27 throw logic_error("Node must have at least one chain");
28
29 unsigned int N = _length * _nchain;
30 _data = new double[N];
31 for (unsigned int i = 0; i < N; ++i) {
32 _data[i] = JAGS_NA;
33 }
34
35 _dtrm_children = new list<DeterministicNode*>;
36 _stoch_children = new list<StochasticNode*>;
37 }
38
Node(vector<unsigned int> const & dim,unsigned int nchain,vector<Node const * > const & parents)39 Node::Node(vector<unsigned int> const &dim, unsigned int nchain,
40 vector<Node const *> const &parents)
41 : _parents(parents), _stoch_children(0), _dtrm_children(0),
42 _dim(getUnique(dim)), _length(product(dim)),
43 _nchain(nchain), _data(0)
44 {
45 if (nchain==0)
46 throw logic_error("Node must have at least one chain");
47
48 unsigned int N = _length * _nchain;
49 _data = new double[N];
50 for (unsigned int i = 0; i < N; ++i) {
51 _data[i] = JAGS_NA;
52 }
53
54 _stoch_children = new list<StochasticNode*>;
55 _dtrm_children = new list<DeterministicNode*>;
56 }
57
~Node()58 Node::~Node()
59 {
60 delete [] _data;
61 delete _stoch_children;
62 delete _dtrm_children;
63 }
64
parents() const65 vector <Node const *> const &Node::parents() const
66 {
67 return _parents;
68 }
69
stochasticChildren()70 list<StochasticNode*> const *Node::stochasticChildren()
71 {
72 return _stoch_children;
73 }
74
deterministicChildren()75 list<DeterministicNode*> const *Node::deterministicChildren()
76 {
77 return _dtrm_children;
78 }
79
isInitialized(Node const * node,unsigned int n)80 static bool isInitialized(Node const *node, unsigned int n)
81 {
82 double const *value = node->value(n);
83 for (unsigned int i = 0; i < node->length(); ++i) {
84 if (value[i] == JAGS_NA)
85 return false;
86 }
87 return true;
88 }
89
initialize(unsigned int n)90 bool Node::initialize(unsigned int n)
91 {
92 // Test whether node is already initialized and, if so, skip it
93 if (isInitialized(this, n))
94 return true;
95
96 // Check that parents are initialized
97 for (unsigned int i = 0; i < _parents.size(); ++i) {
98 if (!isInitialized(_parents[i], n)) {
99 return false; // Uninitialized parent
100 }
101 }
102
103 deterministicSample(n);
104
105 return true;
106 }
107
nchain() const108 unsigned int Node::nchain() const
109 {
110 return _nchain;
111 }
112
countChains(vector<Node const * > const & parameters)113 unsigned int countChains(vector<Node const *> const ¶meters)
114 {
115 unsigned int nchain = parameters.empty() ? 0 : parameters[0]->nchain();
116
117 for (unsigned int i = 1; i < parameters.size(); ++i) {
118 if (parameters[i]->nchain() != nchain) {
119 nchain = 0;
120 break;
121 }
122 }
123
124 return nchain;
125 }
126
setValue(double const * value,unsigned int length,unsigned int chain)127 void Node::setValue(double const *value, unsigned int length, unsigned int chain)
128 {
129 if (length != _length)
130 throw NodeError(this, "Length mismatch in Node::setValue");
131 if (chain >= _nchain)
132 throw NodeError(this, "Invalid chain in Node::setValue");
133
134 copy(value, value + _length, _data + chain * _length);
135 }
136
swapValue(unsigned int chain1,unsigned int chain2)137 void Node::swapValue(unsigned int chain1, unsigned int chain2)
138 {
139 double *value1 = _data + chain1 * _length;
140 double *value2 = _data + chain2 * _length;
141 for (unsigned int i = 0; i < _length; ++i) {
142 double v = value1[i];
143 value1[i] = value2[i];
144 value2[i] = v;
145 }
146 }
147
value(unsigned int chain) const148 double const *Node::value(unsigned int chain) const
149 {
150 return _data + chain * _length;
151 }
152
dim() const153 vector<unsigned int> const &Node::dim() const
154 {
155 return _dim;
156 }
157
length() const158 unsigned int Node::length() const
159 {
160 return _length;
161 }
162
addChild(DeterministicNode * node) const163 void Node::addChild(DeterministicNode *node) const
164 {
165 _dtrm_children->push_back(node);
166 }
167
addChild(StochasticNode * node) const168 void Node::addChild(StochasticNode *node) const
169 {
170 _stoch_children->push_back(node);
171 }
172
removeChild(DeterministicNode * node) const173 void Node::removeChild(DeterministicNode *node) const
174 {
175 /*
176 Removes the given node from the list of deterministic children.
177
178 When Model::~Model is called, all nodes are deleted in reverse
179 order of construction. In this case, the element of _dtrm_node
180 to remove is the last one. For efficiency, we therefore search
181 for the element of _dtrm_node to remove starting at the
182 end. (NB Searching from the beginning results in quadratic
183 complexity in the size of _dtrm_node, which can cause real
184 efficiency problems when deleting a model)
185 */
186
187 list<DeterministicNode*>::reverse_iterator p =
188 find(_dtrm_children->rbegin(), _dtrm_children->rend(), node);
189 if (p != _dtrm_children->rend()) {
190 //The erase function only accepts iterators. Some shennanigans
191 //are required to convert the reverse iterator correctly.
192 _dtrm_children->erase((++p).base());
193 }
194 }
195
removeChild(StochasticNode * node) const196 void Node::removeChild(StochasticNode *node) const
197 {
198 /* See comments in removeChild for DeterministicNodes */
199
200 list<StochasticNode*>::reverse_iterator p =
201 find(_stoch_children->rbegin(), _stoch_children->rend(), node);
202 if (p != _stoch_children->rend()) {
203 _stoch_children->erase((++p).base());
204 }
205 }
206
207 } //namespace jags
208