1 /**
2  *
3  *  Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6)
4  *   {prenom.nom}_at_lip6.fr
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
%define IMPROVE_INFERENCE_API(classname...)
%feature("shadow") gum::classname::setEvidence %{
def setEvidence(self, evidces):
    """
    Erase all the evidences and apply addEvidence(key,value) for every pair in evidces.

    Parameters
    ----------
    evidces : dict
      a dict of evidences

    Raises
    ------
    TypeError
        If evidces is not a dict
    gum.InvalidArgument
        If one value is not a value for the node
    gum.InvalidArgument
        If the size of a value is different from the domain size of the node
    gum.FatalError
        If one value is a vector of 0s
    gum.UndefinedElement
        If one node does not belong to the Bayesian network
    """
    if not isinstance(evidces, dict):
        raise TypeError("setEvidence parameter must be a dict, not %s"%(type(evidces)))
    self.eraseAllEvidence()
    for k,v in evidces.items():
        self.addEvidence(k,v)
%}

%feature("shadow") gum::classname::updateEvidence %{
def updateEvidence(self, evidces):
    """
    Apply chgEvidence(key,value) for every pair in evidces whose node already has an evidence,
    and addEvidence(key,value) for the others. Evidences not mentioned in evidces are kept.

    Parameters
    ----------
    evidces : dict
      a dict of evidences

    Raises
    ------
    TypeError
        If evidces is not a dict
    gum.InvalidArgument
        If one value is not a value for the node
    gum.InvalidArgument
        If the size of a value is different from the domain size of the node
    gum.FatalError
        If one value is a vector of 0s
    gum.UndefinedElement
        If one node does not belong to the Bayesian network
    """
    if not isinstance(evidces, dict):
        raise TypeError("updateEvidence parameter must be a dict, not %s"%(type(evidces)))

    for k,v in evidces.items():
        if self.hasEvidence(k):
            self.chgEvidence(k,v)
        else:
            self.addEvidence(k,v)
%}

%feature("shadow") gum::classname::setTargets %{
def setTargets(self, targets):
    """
    Remove all the targets and add the ones in parameter.

    Parameters
    ----------
    targets : set
      a set (or frozenset) of targets

    Raises
    ------
    TypeError
        If targets is neither a set nor a frozenset
    gum.UndefinedElement
        If one target is not in the Bayes net
    """
    if not isinstance(targets, (set, frozenset)):
        raise TypeError("setTargets parameter must be a set, not %s"%(type(targets)))

    self.eraseAllTargets()
    for k in targets:
        self.addTarget(k)
%}

%ignore gum::classname::evidenceImpact(NodeId target, const NodeSet& evs);
%ignore gum::classname::evidenceImpact(const std::string& target, const std::vector<std::string>& evs);

// The three void extensions below only declare the Python-facing signatures:
// their (empty) bodies are replaced by the "shadow" Python definitions above.
%extend gum::classname {
    void setEvidence(PyObject *evidces) {}
    void updateEvidence(PyObject *evidces) {}
    void setTargets(PyObject* targets) {}

    // The following accessors convert the engine NodeSets into Python sets.
    PyObject* hardEvidenceNodes() {
      return PyAgrumHelper::PySetFromNodeSet(self->hardEvidenceNodes() ) ;
    }
    PyObject* softEvidenceNodes() {
      return PyAgrumHelper::PySetFromNodeSet(self->softEvidenceNodes() ) ;
    }
    PyObject* targets() {
      return PyAgrumHelper::PySetFromNodeSet(self->targets() );
    }

    // Python entry point replacing the two ignored C++ overloads: the target is
    // resolved from a name or an index and evs from a sequence of ints or strings,
    // both against the variable node map of the underlying BN.
    Potential<double> evidenceImpact(PyObject* target,PyObject *evs) {
      gum::NodeId itarget=PyAgrumHelper::nodeIdFromNameOrIndex(target,self->BN().variableNodeMap());
      gum::NodeSet soe;
      PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(soe,evs,self->BN().variableNodeMap());
      return self->evidenceImpact(itarget,soe);
    }
}
%enddef
130 
// Apply the Python-friendly evidence/target API to each inference engine
// (exact, sampling, loopy belief propagation and loopy sampling variants).
IMPROVE_INFERENCE_API(LazyPropagation<double>)
IMPROVE_INFERENCE_API(ShaferShenoyInference<double>)
IMPROVE_INFERENCE_API(VariableElimination<double>)
IMPROVE_INFERENCE_API(GibbsSampling<double>)
IMPROVE_INFERENCE_API(ImportanceSampling<double>)
IMPROVE_INFERENCE_API(WeightedSampling<double>)
IMPROVE_INFERENCE_API(MonteCarloSampling<double>)
IMPROVE_INFERENCE_API(LoopyBeliefPropagation<double>)
IMPROVE_INFERENCE_API(LoopySamplingInference<double,gum::GibbsSampling>)
IMPROVE_INFERENCE_API(LoopySamplingInference<double,gum::ImportanceSampling>)
IMPROVE_INFERENCE_API(LoopySamplingInference<double,gum::WeightedSampling>)
IMPROVE_INFERENCE_API(LoopySamplingInference<double,gum::MonteCarloSampling>)
143 
%define IMPROVE_JOINT_INFERENCE_API(classname)
// Hide the NodeSet-based overloads: the extension below provides versions
// taking Python objects instead.
%ignore classname::evidenceJointImpact(const NodeSet& target, const NodeSet& evs);
%ignore classname::jointMutualInformation(const NodeSet &targets);
%extend classname {
    // targets: a Python sequence of node ids and/or names.
    double jointMutualInformation(PyObject* targets) {
      gum::NodeSet sot;
      PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(sot,targets,self->BN().variableNodeMap());
      return self->jointMutualInformation(sot);
    }

    // targets, evs: Python sequences of node ids and/or names.
    Potential<double> evidenceJointImpact(PyObject* targets,PyObject *evs) {
      gum::NodeSet sot;
      gum::NodeSet soe;
      PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(sot,targets,self->BN().variableNodeMap());
      PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(soe,evs,self->BN().variableNodeMap());
      return self->evidenceJointImpact(sot,soe);
    }

    // The joint-target helpers below require a Python set (or frozenset) and
    // raise gum::InvalidArgument otherwise.
    Potential<double> jointPosterior(PyObject *targets) {
      if (! PyAnySet_Check(targets)) {
        GUM_ERROR(gum::InvalidArgument,"The argument must be a set")
      }
      gum::NodeSet nodeset;
      PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(nodeset,targets,self->BN().variableNodeMap());
      return self->jointPosterior(nodeset);
    }

    void addJointTarget( PyObject* targets ) {
      if (! PyAnySet_Check(targets)) {
        GUM_ERROR(gum::InvalidArgument,"The argument must be a set")
      }
      gum::NodeSet nodeset;
      PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(nodeset,targets,self->BN().variableNodeMap());

      self->gum::JointTargetedInference<double>::addJointTarget(nodeset);
    }

    void eraseJointTarget( PyObject* targets ) {
      if (! PyAnySet_Check(targets)) {
        GUM_ERROR(gum::InvalidArgument,"The argument must be a set")
      }
      gum::NodeSet nodeset;
      PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(nodeset,targets,self->BN().variableNodeMap());
      self->gum::JointTargetedInference<double>::eraseJointTarget(nodeset);
    }

    bool isJointTarget( PyObject* targets ) {
      if (! PyAnySet_Check(targets)) {
        GUM_ERROR(gum::InvalidArgument,"The argument must be a set")
      }
      gum::NodeSet nodeset;
      PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(nodeset,targets,self->BN().variableNodeMap());
      return self->gum::JointTargetedInference<double>::isJointTarget(nodeset);
    }

    // Returns a new Python list holding one Python set of node ids per joint target.
    // On failure, returns NULL with the Python error already set by the C API call.
    PyObject* jointTargets() const {
      PyObject* q = PyList_New( 0 );
      if ( q == nullptr ) return nullptr;

      for ( const auto& ns : self->JointTargetedInference<double>::jointTargets()) {
        PyObject* pyset = PyAgrumHelper::PySetFromNodeSet(ns);
        if ( pyset == nullptr ) {
          Py_DECREF( q );
          return nullptr;
        }
        // PyList_Append does not steal the reference: release ours once stored.
        const int status = PyList_Append( q, pyset );
        Py_DECREF( pyset );
        if ( status != 0 ) {
          Py_DECREF( q );
          return nullptr;
        }
      }
      return q;
    }
}
%enddef
// Joint-target helpers for the exact inference engines (fully qualified names here).
IMPROVE_JOINT_INFERENCE_API(gum::LazyPropagation<double>)
IMPROVE_JOINT_INFERENCE_API(gum::ShaferShenoyInference<double>)
IMPROVE_JOINT_INFERENCE_API(gum::VariableElimination<double>)
209 
210 
// Exact inference engines keep, on the Python side, a reference to the model
// given as first constructor argument (stored as _model), and every junction
// tree they return keeps a reference back to its engine (stored as _engine),
// presumably so that the referenced Python object outlives its users.
%define IMPROVE_EXACT_INFERENCE_API(classname)
// constructor hook: remember the model
%pythonappend gum::classname<double>::classname %{
  self._model = args[0]
%}
// junctionTree hook: tie the returned value to this engine
%pythonappend gum::classname<double>::junctionTree %{
  val._engine = self
%}
%enddef
// Engines are named without template arguments: the macro appends <double>.
IMPROVE_EXACT_INFERENCE_API(LazyPropagation)
IMPROVE_EXACT_INFERENCE_API(ShaferShenoyInference)
IMPROVE_EXACT_INFERENCE_API(VariableElimination)
223 
224 
%define IMPROVE_APPROX_INFERENCE_API(constructor,classname...)
// After construction, store the constructor argument named bn (the BN) as
// _model on the Python engine object.
%pythonappend gum::classname::constructor %{
  self._model = bn
%}
%enddef
// First argument: the constructor name (unqualified, without template arguments).
// Second argument: the full template instance that owns this constructor.
IMPROVE_APPROX_INFERENCE_API(GibbsSampling,GibbsSampling<double>)
IMPROVE_APPROX_INFERENCE_API(ImportanceSampling,ImportanceSampling<double>)
IMPROVE_APPROX_INFERENCE_API(WeightedSampling,WeightedSampling<double>)
IMPROVE_APPROX_INFERENCE_API(MonteCarloSampling,MonteCarloSampling<double>)
IMPROVE_APPROX_INFERENCE_API(LoopyBeliefPropagation,LoopyBeliefPropagation<double>)
IMPROVE_APPROX_INFERENCE_API(LoopySamplingInference,LoopySamplingInference<double,gum::GibbsSampling>)
IMPROVE_APPROX_INFERENCE_API(LoopySamplingInference,LoopySamplingInference<double,gum::ImportanceSampling>)
IMPROVE_APPROX_INFERENCE_API(LoopySamplingInference,LoopySamplingInference<double,gum::WeightedSampling>)
IMPROVE_APPROX_INFERENCE_API(LoopySamplingInference,LoopySamplingInference<double,gum::MonteCarloSampling>)
239