1 #include <config.h>
2 #include <graph/LogicalNode.h>
3 #include <function/FuncError.h>
4 #include <function/Function.h>
5 #include <graph/GraphMarks.h>
6 #include <graph/Graph.h>
7 #include <util/dim.h>
8 
9 #include <stdexcept>
10 #include <vector>
11 #include <string>
12 #include <math.h>
13 
14 using std::vector;
15 using std::string;
16 using std::set;
17 using std::logic_error;
18 
19 namespace jags {
20 
21 static vector<vector<double const *> >
mkParams(vector<Node const * > const & parents,unsigned int nchain)22 mkParams(vector<Node const*> const &parents, unsigned int nchain)
23 {
24     vector<vector<double const *> > ans(nchain);
25     for (unsigned int n = 0; n < nchain; ++n) {
26 	ans[n].reserve(parents.size());
27 	for (unsigned long j = 0; j < parents.size(); ++j) {
28 	    ans[n].push_back(parents[j]->value(n));
29 	}
30 
31     }
32     return ans;
33 }
34 
35 
LogicalNode(vector<unsigned int> const & dim,unsigned int nchain,vector<Node const * > const & parameters,Function const * function)36 LogicalNode::LogicalNode(vector<unsigned int> const &dim,
37 			 unsigned int nchain,
38 			 vector<Node const *> const &parameters,
39 			 Function const *function)
40     : DeterministicNode(dim, nchain, parameters),
41       _func(function), _discrete(false),
42       _parameters(mkParams(parameters, nchain))
43 {
44     if (!checkNPar(function, parameters.size())) {
45 	throw FuncError(function, "Incorrect number of arguments");
46     }
47     vector<bool> mask(parents().size());
48     for (unsigned long j = 0; j < parents().size(); ++j) {
49         mask[j] = parents()[j]->isDiscreteValued();
50     }
51     if (!_func->checkParameterDiscrete(mask)) {
52 	throw FuncError(function, "Failed check for discrete-valued arguments");
53     }
54     _discrete = _func->isDiscreteValued(mask);
55 }
56 
deparse(vector<string> const & parents) const57 string LogicalNode::deparse(vector<string> const &parents) const
58 {
59     string name = "(";
60     name.append(_func->deparse(parents));
61     name.append(")");
62 
63     return name;
64 }
65 
isClosed(set<Node const * > const & ancestors,ClosedFuncClass fc,bool fixed) const66 bool LogicalNode::isClosed(set<Node const *> const &ancestors,
67 			   ClosedFuncClass fc, bool fixed) const
68 {
69     vector<Node const *> const &par = parents();
70 
71     vector<bool> mask(par.size());
72     vector<bool> fixed_mask;
73     unsigned int nmask = 0;
74     for (unsigned int i = 0; i < par.size(); ++i) {
75 	mask[i] = ancestors.count(par[i]);
76 	if (mask[i]) {
77 	    ++nmask;
78 	}
79 	if (fixed) {
80 	    fixed_mask.push_back(par[i]->isFixed());
81 	}
82     }
83 
84     if (nmask == 0) {
85         throw logic_error("Invalid mask in LogicalNode::isClosed");
86     }
87 
88     switch(fc) {
89     case DNODE_ADDITIVE:
90 	return _func->isAdditive(mask, fixed_mask);
91 	break;
92     case DNODE_LINEAR:
93 	return _func->isLinear(mask, fixed_mask);
94         break;
95     case DNODE_SCALE:
96 	return _func->isScale(mask, fixed_mask);
97 	break;
98     case DNODE_SCALE_MIX:
99         return (nmask == 1) && _func->isScale(mask, fixed_mask);
100 	break;
101     case DNODE_POWER:
102 	return _func->isPower(mask, fixed_mask);
103         break;
104     }
105 
106     return false; //Wall
107 }
108 
isDiscreteValued() const109 bool LogicalNode::isDiscreteValued() const
110 {
111     return _discrete;
112 }
113 
114 } //namespace jags
115