1 #include <config.h>
2 #include <model/SymTab.h>
3 #include <model/Model.h>
4 #include <graph/Node.h>
5 #include <util/nainf.h>
6 #include <util/dim.h>
7 
8 #include <string>
9 #include <stdexcept>
10 #include <utility>
11 #include <set>
12 
13 using std::vector;
14 using std::map;
15 using std::pair;
16 using std::string;
17 using std::runtime_error;
18 using std::logic_error;
19 using std::set;
20 
21 namespace jags {
22 
SymTab(Model * model)23 SymTab::SymTab(Model *model)
24     : _model(model)
25 {
26 }
27 
~SymTab()28 SymTab::~SymTab() {
29 
30     map<string, NodeArray*>::iterator p;
31     for (p = _varTable.begin(); p != _varTable.end(); ++p) {
32 	delete p->second;
33     }
34 }
35 
addVariable(string const & name,vector<unsigned int> const & dim)36 void SymTab::addVariable(string const &name, vector<unsigned int> const &dim)
37 {
38     if (_varTable.find(name) != _varTable.end()) {
39 	string msg("Name ");
40 	msg.append(name);
41 	msg.append(" already in use in symbol table");
42 	throw runtime_error(msg);
43     }
44 
45     if (isFlat(dim)) {
46 	string msg = string("Cannot create variable ") + name +
47 	    " with zero dimension";
48 	throw runtime_error(msg);
49     }
50     NodeArray *array = new NodeArray(name, dim, _model->nchain());
51     _varTable[name] = array;
52 }
53 
getVariable(string const & name) const54 NodeArray* SymTab::getVariable(string const &name) const
55 {
56   map<string, NodeArray*>::const_iterator p =  _varTable.find(name);
57 
58   if (p == _varTable.end()) {
59     return 0;
60   }
61   else {
62     return p->second;
63   }
64 }
65 
writeData(map<string,SArray> const & data_table)66 void SymTab::writeData(map<string, SArray> const &data_table)
67 {
68   for(map<string, SArray>::const_iterator p(data_table.begin());
69       p != data_table.end(); ++p) {
70     NodeArray *array = getVariable(p->first);
71     if (array) {
72       if (array->range().dim(false) != p->second.dim(false)) {
73 	string msg("Dimension mismatch in values supplied for ");
74 	msg.append(p->first);
75 	throw runtime_error(msg);
76       }
77       array->setData(p->second, _model);
78     }
79   }
80 }
81 
82 
writeValues(map<string,SArray> const & data_table,unsigned int chain)83 void SymTab::writeValues(map<string, SArray> const &data_table,
84 		         unsigned int chain)
85 {
86     for(map<string, SArray>::const_iterator p(data_table.begin());
87 	p != data_table.end(); ++p) {
88         //set<Node*> psetnodes;
89 	NodeArray *array = getVariable(p->first);
90 	if (array) {
91 	    if (array->range().dim(false) != p->second.dim(false)) {
92 		string msg("Dimension mismatch in values supplied for ");
93 		msg.append(p->first);
94 		throw runtime_error(msg);
95 	    }
96 	    //array->setValue(p->second, chain, psetnodes);
97 	    array->setValue(p->second, chain);
98             //setnodes.insert(psetnodes.begin(), psetnodes.end());
99 	}
100     }
101 }
102 
allMissing(SArray const & sarray)103 static bool allMissing(SArray const &sarray)
104 {
105     unsigned int N=sarray.length();
106     vector<double> const &v = sarray.value();
107     for (unsigned int i = 0; i < N; ++i) {
108 	if (v[i] != JAGS_NA)
109 	    return false;
110     }
111     return true;
112 }
113 
readValues(map<string,SArray> & data_table,unsigned int chain,bool (* condition)(Node const *)) const114 void SymTab::readValues(map<string, SArray> &data_table,
115 		        unsigned int chain,
116                         bool (*condition)(Node const *)) const
117 {
118     if (chain > _model->nchain())
119 	throw logic_error("Invalid chain in SymTab::readValues");
120     if (!condition)
121 	throw logic_error("NULL condition in Symtab::readValues");
122 
123     map<string, NodeArray*>::const_iterator p;
124     for (p = _varTable.begin(); p != _varTable.end(); ++p) {
125 	/* Create a new SArray to hold the values from the symbol table */
126 	SArray read_values(p->second->range().dim(false));
127 	p->second->getValue(read_values, chain, condition);
128 	/* Only write to the data table if we can find at least one
129 	   non-missing value */
130 	if (!allMissing(read_values)) {
131 	    string const &name = p->first;
132 	    if (data_table.find(name) != data_table.end()) {
133 		//Replace any existing entry
134 		data_table.erase(name);
135 	    }
136 	    data_table.insert(pair<string,SArray>(name, read_values));
137 	}
138     }
139 }
140 
size() const141 unsigned int SymTab::size() const
142 {
143   return _varTable.size();
144 }
145 
clear()146 void SymTab::clear()
147 {
148   _varTable.clear();
149 }
150 
getName(Node const * node) const151 string SymTab::getName(Node const *node) const
152 {
153     map<string, NodeArray*>::const_iterator p;
154     for (p = _varTable.begin(); p != _varTable.end(); ++p) {
155 	NodeArray *array = p->second;
156 	Range node_range = array->getRange(node);
157 	if (!isNULL(node_range)) {
158 	    if (node_range == array->range()) {
159 		return p->first;
160 	    }
161 	    else {
162 		return p->first + print(array->getRange(node));
163 	    }
164 	}
165     }
166 
167     //Name not in symbol table: calculate name from parents
168     vector<Node const *> const &parents = node->parents();
169     vector<string> parnames(parents.size());
170     for (unsigned int i = 0; i < parents.size(); ++i) {
171 	parnames[i] = getName(parents[i]);
172     }
173     return node->deparse(parnames);
174 }
175 
176 } //namespace jags
177