1 #include <config.h>
2 #include <graph/GraphMarks.h>
3 #include <graph/Graph.h>
4 #include <graph/Node.h>
5
6 #include <vector>
7 #include <set>
8 #include <stdexcept>
9 #include <utility>
10 #include <list>
11
12 using std::map;
13 using std::vector;
14 using std::set;
15 using std::logic_error;
16 using std::pair;
17 using std::list;
18
19 namespace jags {
20
21
22 /*
23 This is a helper class for the markAncestors member function. A
24 GMIterator moves across a vector of Nodes. A dereferencing
25 operator and a prefix increment operator are provided so that
26 it behaves like an STL iterator.
27
28 Unlike an STL iterator, however, a GMIterator stores its end
29 position. The atEnd member function can be used to test
30 whether the iterator has moved beyond the last element of the
31 vector.
32 */
33 class GMIterator {
34 vector<Node const*>::const_iterator _begin, _end;
35 public:
GMIterator(vector<Node const * > const & v)36 GMIterator(vector<Node const*> const &v)
37 : _begin(v.begin()), _end(v.end()) {}
operator *() const38 inline Node const * operator*() const { return *_begin; }
atEnd() const39 inline bool atEnd() const { return _begin == _end; }
operator ++()40 inline void operator++() { ++_begin; }
41 };
42
GraphMarks(Graph const & graph)43 GraphMarks::GraphMarks(Graph const &graph)
44 : _graph(graph)
45 {
46 }
47
~GraphMarks()48 GraphMarks::~GraphMarks()
49 {}
50
graph() const51 Graph const &GraphMarks::graph() const
52 {
53 return _graph;
54 }
55
mark(Node const * node,int m)56 void GraphMarks::mark(Node const *node, int m)
57 {
58 if (!_graph.contains(node)) {
59 throw logic_error("Attempt to set mark of node not in graph");
60 }
61 if (m == 0) {
62 _marks.erase(node);
63 }
64 else {
65 _marks[node] = m;
66 }
67 }
68
mark(Node const * node) const69 int GraphMarks::mark(Node const *node) const
70 {
71 if (!_graph.contains(node)) {
72 throw logic_error("Attempt to get mark of node not in Graph");
73 }
74
75 map<Node const*, int>::const_iterator i = _marks.find(node);
76 if (i == _marks.end()) {
77 return 0;
78 }
79 else {
80 return i->second;
81 }
82 }
83
clear()84 void GraphMarks::clear()
85 {
86 _marks.clear();
87 }
88
markParents(Node const * node,int m)89 void GraphMarks::markParents(Node const *node, int m)
90 {
91 if (!_graph.contains(node)) {
92 throw logic_error("Can't mark parents of node: not in Graph");
93 }
94 else {
95 vector<Node const *> const &parents = node->parents();
96 for (vector<Node const *>::const_iterator p = parents.begin();
97 p != parents.end(); ++p)
98 {
99 if (_graph.contains(*p)) {
100 _marks[*p] = m;
101 }
102 }
103 }
104 }
105
106 //FIXME
107 //Used by MixtureSampler factory
108 void
markParents(Node const * node,bool (* test)(Node const *),int m)109 GraphMarks::markParents(Node const *node, bool (*test)(Node const *), int m)
110 {
111
112 if (!_graph.contains(node)) {
113 throw logic_error("Can't mark parents of node: not in Graph");
114 }
115
116 vector<Node const *> const &parents = node->parents();
117 for (vector<Node const*>::const_iterator p = parents.begin();
118 p != parents.end(); ++p)
119 {
120 Node const *parent = *p;
121 if (_graph.contains(parent)) {
122 if (test(parent)) {
123 _marks[parent] = m;
124 }
125 else {
126 markParents(parent, test, m);
127 }
128 }
129 }
130 }
131
markAncestors(vector<Node const * > const & nodes,int m)132 void GraphMarks::markAncestors(vector<Node const *> const &nodes, int m)
133 {
134 set<Node const*> visited; //visited nodes
135 vector<Node const*> ancestors; //ancestor nodes
136
137 /*
138 Do a depth-first search of the graph to find all the ancestors
139 of the given Nodes in the graph. The set "visited" keeps track
140 of previously visited nodes for efficiency. Ancestors are
141 pushed back on to the vector "ancestors" in the order they are
142 found.
143
144 We could do this with a recursive helper function, but it is
145 safer to iterate. So we keep our own stack of GMIterators to
146 record the current position on the graph.
147
148 The GMIterator class is defined above.
149 */
150
151 vector<GMIterator> stack;
152 stack.push_back(GMIterator(nodes));
153
154 while (!stack.empty()) {
155
156 for (GMIterator &p = stack.back(); !p.atEnd(); ++p) {
157 if (visited.count(*p) == 0 && _graph.contains(*p)) {
158 visited.insert(*p);
159 ancestors.push_back(*p);
160 stack.push_back(GMIterator((*p)->parents()));
161 break;
162 }
163 }
164
165 if (stack.back().atEnd()) {
166 stack.pop_back();
167 }
168 }
169
170 /* Now set the marks of all ancestors */
171 for(vector<Node const*>::const_iterator p = ancestors.begin();
172 p != ancestors.end(); ++p)
173 {
174 if (m == 0) {
175 _marks.erase(*p);
176 }
177 else {
178 _marks[*p] = m;
179 }
180 }
181
182 }
183
184 } //namespace jags
185