1 #include <config.h>
2 #include <graph/LogicalNode.h>
3 #include <function/FuncError.h>
4 #include <function/Function.h>
5 #include <graph/GraphMarks.h>
6 #include <graph/Graph.h>
7 #include <util/dim.h>
8
9 #include <stdexcept>
10 #include <vector>
11 #include <string>
12 #include <math.h>
13
14 using std::vector;
15 using std::string;
16 using std::set;
17 using std::logic_error;
18
19 namespace jags {
20
21 static vector<vector<double const *> >
mkParams(vector<Node const * > const & parents,unsigned int nchain)22 mkParams(vector<Node const*> const &parents, unsigned int nchain)
23 {
24 vector<vector<double const *> > ans(nchain);
25 for (unsigned int n = 0; n < nchain; ++n) {
26 ans[n].reserve(parents.size());
27 for (unsigned long j = 0; j < parents.size(); ++j) {
28 ans[n].push_back(parents[j]->value(n));
29 }
30
31 }
32 return ans;
33 }
34
35
LogicalNode(vector<unsigned int> const & dim,unsigned int nchain,vector<Node const * > const & parameters,Function const * function)36 LogicalNode::LogicalNode(vector<unsigned int> const &dim,
37 unsigned int nchain,
38 vector<Node const *> const ¶meters,
39 Function const *function)
40 : DeterministicNode(dim, nchain, parameters),
41 _func(function), _discrete(false),
42 _parameters(mkParams(parameters, nchain))
43 {
44 if (!checkNPar(function, parameters.size())) {
45 throw FuncError(function, "Incorrect number of arguments");
46 }
47 vector<bool> mask(parents().size());
48 for (unsigned long j = 0; j < parents().size(); ++j) {
49 mask[j] = parents()[j]->isDiscreteValued();
50 }
51 if (!_func->checkParameterDiscrete(mask)) {
52 throw FuncError(function, "Failed check for discrete-valued arguments");
53 }
54 _discrete = _func->isDiscreteValued(mask);
55 }
56
deparse(vector<string> const & parents) const57 string LogicalNode::deparse(vector<string> const &parents) const
58 {
59 string name = "(";
60 name.append(_func->deparse(parents));
61 name.append(")");
62
63 return name;
64 }
65
isClosed(set<Node const * > const & ancestors,ClosedFuncClass fc,bool fixed) const66 bool LogicalNode::isClosed(set<Node const *> const &ancestors,
67 ClosedFuncClass fc, bool fixed) const
68 {
69 vector<Node const *> const &par = parents();
70
71 vector<bool> mask(par.size());
72 vector<bool> fixed_mask;
73 unsigned int nmask = 0;
74 for (unsigned int i = 0; i < par.size(); ++i) {
75 mask[i] = ancestors.count(par[i]);
76 if (mask[i]) {
77 ++nmask;
78 }
79 if (fixed) {
80 fixed_mask.push_back(par[i]->isFixed());
81 }
82 }
83
84 if (nmask == 0) {
85 throw logic_error("Invalid mask in LogicalNode::isClosed");
86 }
87
88 switch(fc) {
89 case DNODE_ADDITIVE:
90 return _func->isAdditive(mask, fixed_mask);
91 break;
92 case DNODE_LINEAR:
93 return _func->isLinear(mask, fixed_mask);
94 break;
95 case DNODE_SCALE:
96 return _func->isScale(mask, fixed_mask);
97 break;
98 case DNODE_SCALE_MIX:
99 return (nmask == 1) && _func->isScale(mask, fixed_mask);
100 break;
101 case DNODE_POWER:
102 return _func->isPower(mask, fixed_mask);
103 break;
104 }
105
106 return false; //Wall
107 }
108
isDiscreteValued() const109 bool LogicalNode::isDiscreteValued() const
110 {
111 return _discrete;
112 }
113
114 } //namespace jags
115