1 #include <config.h>
2 #include <graph/Node.h>
3 #include <graph/NodeError.h>
4 #include <util/nainf.h>
5 #include <util/dim.h>
6 
7 #include <stdexcept>
8 #include <algorithm>
9 
10 using std::string;
11 using std::vector;
12 using std::logic_error;
13 using std::copy;
14 using std::find;
15 using std::list;
16 
17 namespace jags {
18 
// Forward declarations of the child node types held (by pointer) in the
// _dtrm_children and _stoch_children lists.
class DeterministicNode;
class StochasticNode;
21 
Node(vector<unsigned int> const & dim,unsigned int nchain)22 Node::Node(vector<unsigned int> const &dim, unsigned int nchain)
23     : _parents(0), _stoch_children(0), _dtrm_children(0),
24       _dim(getUnique(dim)), _length(product(dim)), _nchain(nchain), _data(0)
25 {
26     if (nchain==0)
27 	throw logic_error("Node must have at least one chain");
28 
29     unsigned int N = _length * _nchain;
30     _data = new double[N];
31     for (unsigned int i = 0; i < N; ++i) {
32 	_data[i] = JAGS_NA;
33     }
34 
35     _dtrm_children = new list<DeterministicNode*>;
36     _stoch_children = new list<StochasticNode*>;
37 }
38 
Node(vector<unsigned int> const & dim,unsigned int nchain,vector<Node const * > const & parents)39 Node::Node(vector<unsigned int> const &dim, unsigned int nchain,
40 	   vector<Node const *> const &parents)
41     : _parents(parents), _stoch_children(0), _dtrm_children(0),
42       _dim(getUnique(dim)), _length(product(dim)),
43       _nchain(nchain), _data(0)
44 {
45     if (nchain==0)
46 	throw logic_error("Node must have at least one chain");
47 
48     unsigned int N = _length * _nchain;
49     _data = new double[N];
50     for (unsigned int i = 0; i < N; ++i) {
51 	_data[i] = JAGS_NA;
52     }
53 
54     _stoch_children = new list<StochasticNode*>;
55     _dtrm_children = new list<DeterministicNode*>;
56 }
57 
~Node()58 Node::~Node()
59 {
60     delete [] _data;
61     delete _stoch_children;
62     delete _dtrm_children;
63 }
64 
parents() const65 vector <Node const *> const &Node::parents() const
66 {
67     return _parents;
68 }
69 
stochasticChildren()70 list<StochasticNode*> const *Node::stochasticChildren()
71 {
72     return _stoch_children;
73 }
74 
deterministicChildren()75 list<DeterministicNode*> const *Node::deterministicChildren()
76 {
77     return _dtrm_children;
78 }
79 
isInitialized(Node const * node,unsigned int n)80 static bool isInitialized(Node const *node, unsigned int n)
81 {
82     double const *value = node->value(n);
83     for (unsigned int i = 0; i < node->length(); ++i) {
84 	if (value[i] == JAGS_NA)
85 	    return false;
86     }
87     return true;
88 }
89 
initialize(unsigned int n)90 bool Node::initialize(unsigned int n)
91 {
92     // Test whether node is already initialized and, if so, skip it
93     if (isInitialized(this, n))
94         return true;
95 
96     // Check that parents are initialized
97     for (unsigned int i = 0; i < _parents.size(); ++i) {
98         if (!isInitialized(_parents[i], n)) {
99 	    return false; // Uninitialized parent
100         }
101     }
102 
103     deterministicSample(n);
104 
105     return true;
106 }
107 
nchain() const108 unsigned int Node::nchain() const
109 {
110   return _nchain;
111 }
112 
countChains(vector<Node const * > const & parameters)113 unsigned int countChains(vector<Node const *> const &parameters)
114 {
115     unsigned int nchain = parameters.empty() ? 0 : parameters[0]->nchain();
116 
117     for (unsigned int i = 1; i < parameters.size(); ++i) {
118 	if (parameters[i]->nchain() != nchain) {
119 	    nchain = 0;
120 	    break;
121 	}
122     }
123 
124     return nchain;
125 }
126 
setValue(double const * value,unsigned int length,unsigned int chain)127 void Node::setValue(double const *value, unsigned int length, unsigned int chain)
128 {
129    if (length != _length)
130       throw NodeError(this, "Length mismatch in Node::setValue");
131    if (chain >= _nchain)
132       throw NodeError(this, "Invalid chain in Node::setValue");
133 
134    copy(value, value + _length, _data + chain * _length);
135 }
136 
swapValue(unsigned int chain1,unsigned int chain2)137 void Node::swapValue(unsigned int chain1, unsigned int chain2)
138 {
139     double *value1 = _data + chain1 * _length;
140     double *value2 = _data + chain2 * _length;
141     for (unsigned int i = 0; i < _length; ++i) {
142 	double v = value1[i];
143 	value1[i] = value2[i];
144 	value2[i] = v;
145     }
146 }
147 
value(unsigned int chain) const148 double const *Node::value(unsigned int chain) const
149 {
150     return _data + chain * _length;
151 }
152 
dim() const153 vector<unsigned int> const &Node::dim() const
154 {
155     return _dim;
156 }
157 
length() const158 unsigned int Node::length() const
159 {
160     return _length;
161 }
162 
addChild(DeterministicNode * node) const163 void Node::addChild(DeterministicNode *node) const
164 {
165     _dtrm_children->push_back(node);
166 }
167 
addChild(StochasticNode * node) const168 void Node::addChild(StochasticNode *node) const
169 {
170     _stoch_children->push_back(node);
171 }
172 
removeChild(DeterministicNode * node) const173 void Node::removeChild(DeterministicNode *node) const
174 {
175     /*
176        Removes the given node from the list of deterministic children.
177 
178        When Model::~Model is called, all nodes are deleted in reverse
179        order of construction. In this case, the element of _dtrm_node
180        to remove is the last one. For efficiency, we therefore search
181        for the element of _dtrm_node to remove starting at the
182        end. (NB Searching from the beginning results in quadratic
183        complexity in the size of _dtrm_node, which can cause real
184        efficiency problems when deleting a model)
185     */
186 
187     list<DeterministicNode*>::reverse_iterator p =
188 	find(_dtrm_children->rbegin(), _dtrm_children->rend(), node);
189     if (p != _dtrm_children->rend()) {
190 	//The erase function only accepts iterators. Some shennanigans
191 	//are required to convert the reverse iterator correctly.
192 	_dtrm_children->erase((++p).base());
193     }
194 }
195 
removeChild(StochasticNode * node) const196 void Node::removeChild(StochasticNode *node) const
197 {
198     /* See comments in removeChild for DeterministicNodes */
199 
200     list<StochasticNode*>::reverse_iterator p =
201 	find(_stoch_children->rbegin(), _stoch_children->rend(), node);
202     if (p != _stoch_children->rend()) {
203 	_stoch_children->erase((++p).base());
204     }
205 }
206 
207 } //namespace jags
208