1 #include <config.h>
2 #include <module/Module.h>
3 #include <compiler/Compiler.h>
4 #include <model/Model.h>
5 #include <function/DFunction.h>
6 #include <function/PFunction.h>
7 #include <function/QFunction.h>
8 #include <function/ScalarLogDensity.h>
9 #include <function/VectorLogDensity.h>
10 #include <function/ArrayLogDensity.h>
11 #include <distribution/RScalarDist.h>
12 
13 #include <algorithm>
14 
15 using std::vector;
16 using std::list;
17 using std::string;
18 using std::find;
19 using std::pair;
20 
21 namespace jags {
22 
Module(string const & name)23 Module::Module(string const &name)
24     : _name(name), _loaded(false)
25 {
26     modules().push_back(this);
27 }
28 
~Module()29 Module::~Module()
30 {
31     //FIXME: Could be causing windows segfault??
32     unload();
33     list<Module*>::iterator p = find(modules().begin(), modules().end(), this);
34     if (p != modules().end()) {
35 	modules().erase(p);
36     }
37 }
38 
39 
insert(ScalarFunction * func)40 void Module::insert(ScalarFunction *func)
41 {
42     _functions.push_back(func);
43     _fp_list.push_back(FunctionPtr(func));
44 }
45 
insert(LinkFunction * func)46 void Module::insert(LinkFunction *func)
47 {
48     _functions.push_back(func);
49     _fp_list.push_back(FunctionPtr(func));
50 
51 }
52 
insert(VectorFunction * func)53 void Module::insert(VectorFunction *func)
54 {
55     _functions.push_back(func);
56     _fp_list.push_back(FunctionPtr(func));
57 
58 }
59 
insert(ArrayFunction * func)60 void Module::insert(ArrayFunction *func)
61 {
62     _functions.push_back(func);
63     _fp_list.push_back(FunctionPtr(func));
64 
65 }
66 
insert(RScalarDist * dist)67 void Module::insert(RScalarDist *dist)
68 {
69     _distributions.push_back(dist);
70     _dp_list.push_back(DistPtr(dist));
71 
72     insert(new ScalarLogDensity(dist));
73 
74     insert(new DFunction(dist));
75     insert(new PFunction(dist));
76     insert(new QFunction(dist));
77 }
78 
insert(ScalarDist * dist)79 void Module::insert(ScalarDist *dist)
80 {
81     _distributions.push_back(dist);
82     _dp_list.push_back(DistPtr(dist));
83 
84     insert(new ScalarLogDensity(dist));
85 }
86 
insert(VectorDist * dist)87 void Module::insert(VectorDist *dist)
88 {
89     _distributions.push_back(dist);
90     _dp_list.push_back(DistPtr(dist));
91 
92     insert(new VectorLogDensity(dist));
93 }
94 
insert(ArrayDist * dist)95 void Module::insert(ArrayDist *dist)
96 {
97     _distributions.push_back(dist);
98     _dp_list.push_back(DistPtr(dist));
99 
100     insert(new ArrayLogDensity(dist));
101 }
102 
insert(ScalarDist * dist,ScalarFunction * func)103 void Module::insert(ScalarDist *dist, ScalarFunction *func)
104 {
105     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
106     insert(dist);
107     insert(func);
108 }
109 
insert(ScalarDist * dist,LinkFunction * func)110 void Module::insert(ScalarDist *dist, LinkFunction *func)
111 {
112     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
113     insert(dist);
114     insert(func);
115 }
116 
insert(ScalarDist * dist,VectorFunction * func)117 void Module::insert(ScalarDist *dist, VectorFunction *func)
118 {
119     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
120     insert(dist);
121     insert(func);
122 }
123 
insert(ScalarDist * dist,ArrayFunction * func)124 void Module::insert(ScalarDist *dist, ArrayFunction *func)
125 {
126     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
127     insert(dist);
128     insert(func);
129 }
130 
131 //
132 
insert(VectorDist * dist,ScalarFunction * func)133 void Module::insert(VectorDist *dist, ScalarFunction *func)
134 {
135     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
136     insert(dist);
137     insert(func);
138 }
139 
insert(VectorDist * dist,LinkFunction * func)140 void Module::insert(VectorDist *dist, LinkFunction *func)
141 {
142     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
143     insert(dist);
144     insert(func);
145 }
146 
insert(VectorDist * dist,VectorFunction * func)147 void Module::insert(VectorDist *dist, VectorFunction *func)
148 {
149     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
150     insert(dist);
151     insert(func);
152 }
153 
insert(VectorDist * dist,ArrayFunction * func)154 void Module::insert(VectorDist *dist, ArrayFunction *func)
155 {
156     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
157     insert(dist);
158     insert(func);
159 }
160 
161 //
162 
insert(ArrayDist * dist,ScalarFunction * func)163 void Module::insert(ArrayDist *dist, ScalarFunction *func)
164 {
165     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
166     insert(dist);
167     insert(func);
168 }
169 
insert(ArrayDist * dist,LinkFunction * func)170 void Module::insert(ArrayDist *dist, LinkFunction *func)
171 {
172     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
173     insert(dist);
174     insert(func);
175 }
176 
insert(ArrayDist * dist,VectorFunction * func)177 void Module::insert(ArrayDist *dist, VectorFunction *func)
178 {
179     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
180     insert(dist);
181     insert(func);
182 }
183 
insert(ArrayDist * dist,ArrayFunction * func)184 void Module::insert(ArrayDist *dist, ArrayFunction *func)
185 {
186     _obs_functions.push_back(pair<DistPtr,FunctionPtr>(dist,func));
187     insert(dist);
188     insert(func);
189 }
190 
insert(SamplerFactory * fac)191 void Module::insert(SamplerFactory *fac)
192 {
193     _sampler_factories.push_back(fac);
194 }
195 
insert(RNGFactory * fac)196 void Module::insert(RNGFactory *fac)
197 {
198     _rng_factories.push_back(fac);
199 }
200 
insert(MonitorFactory * fac)201 void Module::insert(MonitorFactory *fac)
202 {
203     _monitor_factories.push_back(fac);
204 }
205 
load()206 void Module::load()
207 {
208     if (_loaded)
209 	return;
210 
211     for (unsigned int i = 0; i < _monitor_factories.size(); ++i) {
212 	pair<MonitorFactory*,bool> p(_monitor_factories[i], true);
213 	Model::monitorFactories().push_front(p);
214     }
215     for (unsigned int i = 0; i < _rng_factories.size(); ++i) {
216 	pair<RNGFactory*, bool> p(_rng_factories[i], true);
217 	Model::rngFactories().push_front(p);
218     }
219     for (unsigned int i = 0; i < _sampler_factories.size(); ++i) {
220 	pair<SamplerFactory*, bool> p(_sampler_factories[i], true);
221 	Model::samplerFactories().push_front(p);
222     }
223     for (unsigned int i = 0; i < _dp_list.size(); ++i) {
224 	Compiler::distTab().insert(_dp_list[i]);
225     }
226     for (unsigned int i = 0; i < _fp_list.size(); ++i) {
227 	Compiler::funcTab().insert(_fp_list[i]);
228     }
229     for (unsigned int i = 0; i < _obs_functions.size(); ++i) {
230 	Compiler::obsFuncTab().insert(_obs_functions[i].first,
231 				      _obs_functions[i].second);
232     }
233 
234     _loaded = true;
235     loadedModules().push_back(this);
236 }
237 
unload()238 void Module::unload()
239 {
240     if (!_loaded)
241 	return;
242 
243     loadedModules().remove(this);
244     _loaded = false;
245 
246     for (unsigned int i = 0; i < _fp_list.size(); ++i) {
247 	Compiler::funcTab().erase(_fp_list[i]);
248     }
249     for (unsigned int i = 0; i < _obs_functions.size(); ++i) {
250 	Compiler::obsFuncTab().erase(_obs_functions[i].first,
251 				     _obs_functions[i].second);
252     }
253     for (unsigned int i = 0; i < _distributions.size(); ++i) {
254 	Compiler::distTab().erase(_dp_list[i]);
255     }
256 
257     list<pair<RNGFactory *, bool> > &rngf = Model::rngFactories();
258     for (unsigned int i = 0; i < _rng_factories.size(); ++i) {
259 	RNGFactory *f = _rng_factories[i];
260 	rngf.remove(pair<RNGFactory *, bool>(f, true));
261 	rngf.remove(pair<RNGFactory *, bool>(f, false));
262     }
263 
264     list<pair<SamplerFactory *, bool> > &sf = Model::samplerFactories();
265     for (unsigned int i = 0; i < _sampler_factories.size(); ++i) {
266 	SamplerFactory *f = _sampler_factories[i];
267 	sf.remove(pair<SamplerFactory *, bool>(f, true));
268 	sf.remove(pair<SamplerFactory *, bool>(f, false));
269     }
270 
271     list<pair<MonitorFactory *, bool> > &mf = Model::monitorFactories();
272     for (unsigned int i = 0; i < _monitor_factories.size(); ++i) {
273 	MonitorFactory *f = _monitor_factories[i];
274 	mf.remove(pair<MonitorFactory *, bool>(f, true));
275 	mf.remove(pair<MonitorFactory *, bool>(f, false));
276     }
277 
278 }
279 
functions() const280 vector<Function*> const &Module::functions() const
281 {
282     return _functions;
283 }
284 
distributions() const285 vector<Distribution*> const &Module::distributions() const
286 {
287     return _distributions;
288 }
289 
samplerFactories() const290 vector<SamplerFactory*> const &Module::samplerFactories() const
291 {
292     return _sampler_factories;
293 }
294 
rngFactories() const295 vector<RNGFactory*> const &Module::rngFactories() const
296 {
297     return _rng_factories;
298 }
299 
monitorFactories() const300 vector<MonitorFactory*> const &Module::monitorFactories() const
301 {
302     return _monitor_factories;
303 }
304 
name() const305 string const &Module::name() const
306 {
307     return _name;
308 }
309 
modules()310 list<Module *> &Module::modules()
311 {
312     static list<Module*> *_modules = new list<Module*>;
313     return *_modules;
314 }
315 
loadedModules()316 list<Module *> &Module::loadedModules()
317 {
318     static list<Module*> *_modules = new list<Module*>;
319     return *_modules;
320 }
321 
322 } //namespace jags
323