1 #include <config.h>
2 #include <model/SymTab.h>
3 #include <model/Model.h>
4 #include <graph/Node.h>
5 #include <util/nainf.h>
6 #include <util/dim.h>
7
8 #include <string>
9 #include <stdexcept>
10 #include <utility>
11 #include <set>
12
13 using std::vector;
14 using std::map;
15 using std::pair;
16 using std::string;
17 using std::runtime_error;
18 using std::logic_error;
19 using std::set;
20
21 namespace jags {
22
SymTab(Model * model)23 SymTab::SymTab(Model *model)
24 : _model(model)
25 {
26 }
27
~SymTab()28 SymTab::~SymTab() {
29
30 map<string, NodeArray*>::iterator p;
31 for (p = _varTable.begin(); p != _varTable.end(); ++p) {
32 delete p->second;
33 }
34 }
35
addVariable(string const & name,vector<unsigned int> const & dim)36 void SymTab::addVariable(string const &name, vector<unsigned int> const &dim)
37 {
38 if (_varTable.find(name) != _varTable.end()) {
39 string msg("Name ");
40 msg.append(name);
41 msg.append(" already in use in symbol table");
42 throw runtime_error(msg);
43 }
44
45 if (isFlat(dim)) {
46 string msg = string("Cannot create variable ") + name +
47 " with zero dimension";
48 throw runtime_error(msg);
49 }
50 NodeArray *array = new NodeArray(name, dim, _model->nchain());
51 _varTable[name] = array;
52 }
53
getVariable(string const & name) const54 NodeArray* SymTab::getVariable(string const &name) const
55 {
56 map<string, NodeArray*>::const_iterator p = _varTable.find(name);
57
58 if (p == _varTable.end()) {
59 return 0;
60 }
61 else {
62 return p->second;
63 }
64 }
65
writeData(map<string,SArray> const & data_table)66 void SymTab::writeData(map<string, SArray> const &data_table)
67 {
68 for(map<string, SArray>::const_iterator p(data_table.begin());
69 p != data_table.end(); ++p) {
70 NodeArray *array = getVariable(p->first);
71 if (array) {
72 if (array->range().dim(false) != p->second.dim(false)) {
73 string msg("Dimension mismatch in values supplied for ");
74 msg.append(p->first);
75 throw runtime_error(msg);
76 }
77 array->setData(p->second, _model);
78 }
79 }
80 }
81
82
writeValues(map<string,SArray> const & data_table,unsigned int chain)83 void SymTab::writeValues(map<string, SArray> const &data_table,
84 unsigned int chain)
85 {
86 for(map<string, SArray>::const_iterator p(data_table.begin());
87 p != data_table.end(); ++p) {
88 //set<Node*> psetnodes;
89 NodeArray *array = getVariable(p->first);
90 if (array) {
91 if (array->range().dim(false) != p->second.dim(false)) {
92 string msg("Dimension mismatch in values supplied for ");
93 msg.append(p->first);
94 throw runtime_error(msg);
95 }
96 //array->setValue(p->second, chain, psetnodes);
97 array->setValue(p->second, chain);
98 //setnodes.insert(psetnodes.begin(), psetnodes.end());
99 }
100 }
101 }
102
allMissing(SArray const & sarray)103 static bool allMissing(SArray const &sarray)
104 {
105 unsigned int N=sarray.length();
106 vector<double> const &v = sarray.value();
107 for (unsigned int i = 0; i < N; ++i) {
108 if (v[i] != JAGS_NA)
109 return false;
110 }
111 return true;
112 }
113
readValues(map<string,SArray> & data_table,unsigned int chain,bool (* condition)(Node const *)) const114 void SymTab::readValues(map<string, SArray> &data_table,
115 unsigned int chain,
116 bool (*condition)(Node const *)) const
117 {
118 if (chain > _model->nchain())
119 throw logic_error("Invalid chain in SymTab::readValues");
120 if (!condition)
121 throw logic_error("NULL condition in Symtab::readValues");
122
123 map<string, NodeArray*>::const_iterator p;
124 for (p = _varTable.begin(); p != _varTable.end(); ++p) {
125 /* Create a new SArray to hold the values from the symbol table */
126 SArray read_values(p->second->range().dim(false));
127 p->second->getValue(read_values, chain, condition);
128 /* Only write to the data table if we can find at least one
129 non-missing value */
130 if (!allMissing(read_values)) {
131 string const &name = p->first;
132 if (data_table.find(name) != data_table.end()) {
133 //Replace any existing entry
134 data_table.erase(name);
135 }
136 data_table.insert(pair<string,SArray>(name, read_values));
137 }
138 }
139 }
140
size() const141 unsigned int SymTab::size() const
142 {
143 return _varTable.size();
144 }
145
clear()146 void SymTab::clear()
147 {
148 _varTable.clear();
149 }
150
getName(Node const * node) const151 string SymTab::getName(Node const *node) const
152 {
153 map<string, NodeArray*>::const_iterator p;
154 for (p = _varTable.begin(); p != _varTable.end(); ++p) {
155 NodeArray *array = p->second;
156 Range node_range = array->getRange(node);
157 if (!isNULL(node_range)) {
158 if (node_range == array->range()) {
159 return p->first;
160 }
161 else {
162 return p->first + print(array->getRange(node));
163 }
164 }
165 }
166
167 //Name not in symbol table: calculate name from parents
168 vector<Node const *> const &parents = node->parents();
169 vector<string> parnames(parents.size());
170 for (unsigned int i = 0; i < parents.size(); ++i) {
171 parnames[i] = getName(parents[i]);
172 }
173 return node->deparse(parnames);
174 }
175
176 } //namespace jags
177