/**
 *
 *  Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6)
 *   {prenom.nom}_at_lip6.fr
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

// IMPROVE_INFERENCE_API(T): Python-side helpers shared by every inference engine T.
//   - setEvidence / updateEvidence / setTargets are written in Python (shadow
//     features) on top of eraseAllEvidence, addEvidence, chgEvidence, hasEvidence,
//     eraseAllTargets and addTarget. The empty C++ stubs in %extend only exist so
//     that SWIG has a method to attach each shadow implementation to.
//   - hardEvidenceNodes / softEvidenceNodes / targets return Python sets.
//   - evidenceImpact takes Python objects (node ids or names, converted through
//     PyAgrumHelper); the native C++ overloads are hidden with %ignore.
// NOTE: inside %define bodies, avoid writing the macro parameter names in text
// (they are substituted) and avoid the character used for stringification.
%define IMPROVE_INFERENCE_API(classname...)
%feature("shadow") gum::classname::setEvidence %{
def setEvidence(self, evidces):
    """
    Erase all the evidences and apply addEvidence(key,value) for every pair in evidces.

    Parameters
    ----------
    evidces : dict
      a dict of evidences

    Raises
    ------
    TypeError
      If evidces is not a dict
    gum.InvalidArgument
      If one value is not a value for the node
    gum.InvalidArgument
      If the size of a value is different from the domain size of the node
    gum.FatalError
      If one value is a vector of 0s
    gum.UndefinedElement
      If one node does not belong to the Bayesian network
    """
    if not isinstance(evidces, dict):
        raise TypeError("setEvidence parameter must be a dict, not %s"%(type(evidces)))
    self.eraseAllEvidence()
    for k,v in evidces.items():
        self.addEvidence(k,v)
%}

%feature("shadow") gum::classname::updateEvidence %{
def updateEvidence(self, evidces):
    """
    Apply chgEvidence(key,value) for every pair in evidces whose node already has
    an evidence, and addEvidence(key,value) for the others. Evidences on nodes that
    do not appear in evidces are kept.

    Parameters
    ----------
    evidces : dict
      a dict of evidences

    Raises
    ------
    TypeError
      If evidces is not a dict
    gum.InvalidArgument
      If one value is not a value for the node
    gum.InvalidArgument
      If the size of a value is different from the domain size of the node
    gum.FatalError
      If one value is a vector of 0s
    gum.UndefinedElement
      If one node does not belong to the Bayesian network
    """
    if not isinstance(evidces, dict):
        raise TypeError("updateEvidence parameter must be a dict, not %s"%(type(evidces)))

    for k,v in evidces.items():
        if self.hasEvidence(k):
            self.chgEvidence(k,v)
        else:
            self.addEvidence(k,v)
%}

%feature("shadow") gum::classname::setTargets %{
def setTargets(self, targets):
    """
    Remove all the targets and add the ones in parameter.

    Parameters
    ----------
    targets : set
      a set of targets

    Raises
    ------
    TypeError
      If targets is not a set
    gum.UndefinedElement
      If one target is not in the Bayes net
    """
    if not isinstance(targets, set):
        raise TypeError("setTargets parameter must be a set, not %s"%(type(targets)))

    self.eraseAllTargets()
    for k in targets:
        self.addTarget(k)
%}

%ignore gum::classname::evidenceImpact(NodeId target, const NodeSet& evs);
%ignore gum::classname::evidenceImpact(const std::string& target, const std::vector<std::string>& evs);

// these void class extensions are rewritten by "shadow" declarations
%extend gum::classname {
  void setEvidence(PyObject *evidces) {}
  void updateEvidence(PyObject *evidces) {}
  void setTargets(PyObject* targets) {}

  // the three following methods convert the C++ NodeSets into Python sets
  PyObject* hardEvidenceNodes() {
    return PyAgrumHelper::PySetFromNodeSet(self->hardEvidenceNodes());
  }
  PyObject* softEvidenceNodes() {
    return PyAgrumHelper::PySetFromNodeSet(self->softEvidenceNodes());
  }
  PyObject* targets() {
    return PyAgrumHelper::PySetFromNodeSet(self->targets());
  }

  // target and evs are Python objects (ids or names) resolved against the BN
  Potential<double> evidenceImpact(PyObject* target,PyObject *evs) {
    gum::NodeId itarget=PyAgrumHelper::nodeIdFromNameOrIndex(target,self->BN().variableNodeMap());
    gum::NodeSet soe;
    PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(soe,evs,self->BN().variableNodeMap());
    return self->evidenceImpact(itarget,soe);
  }
}
%enddef

IMPROVE_INFERENCE_API(LazyPropagation<double>)
IMPROVE_INFERENCE_API(ShaferShenoyInference<double>)
IMPROVE_INFERENCE_API(VariableElimination<double>)
IMPROVE_INFERENCE_API(GibbsSampling<double>)
IMPROVE_INFERENCE_API(ImportanceSampling<double>)
IMPROVE_INFERENCE_API(WeightedSampling<double>)
IMPROVE_INFERENCE_API(MonteCarloSampling<double>)
IMPROVE_INFERENCE_API(LoopyBeliefPropagation<double>)
IMPROVE_INFERENCE_API(LoopySamplingInference<double,gum::GibbsSampling>)
IMPROVE_INFERENCE_API(LoopySamplingInference<double,gum::ImportanceSampling>)
IMPROVE_INFERENCE_API(LoopySamplingInference<double,gum::WeightedSampling>)
IMPROVE_INFERENCE_API(LoopySamplingInference<double,gum::MonteCarloSampling>)

// IMPROVE_JOINT_INFERENCE_API(T): joint-target helpers for exact engines.
// Every method takes Python objects (sequences or sets of node ids or names),
// converts them to a gum::NodeSet through PyAgrumHelper and forwards to the C++
// method. jointPosterior, addJointTarget, eraseJointTarget and isJointTarget
// raise gum::InvalidArgument when the argument is not a Python set or frozenset
// (PyAnySet_Check). jointTargets returns a Python list of Python sets.
%define IMPROVE_JOINT_INFERENCE_API(classname)
%ignore classname::evidenceJointImpact(const NodeSet& target, const NodeSet& evs);
%ignore classname::jointMutualInformation(const NodeSet &targets);
%extend classname {
  double jointMutualInformation(PyObject* targets) {
    gum::NodeSet sot;
    PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(sot,targets,self->BN().variableNodeMap());
    return self->jointMutualInformation(sot);
  }

  Potential<double> evidenceJointImpact(PyObject* targets,PyObject *evs) {
    gum::NodeSet sot;
    gum::NodeSet soe;
    PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(sot,targets,self->BN().variableNodeMap());
    PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(soe,evs,self->BN().variableNodeMap());
    return self->evidenceJointImpact(sot,soe);
  }

  Potential<double> jointPosterior(PyObject *targets) {
    if (! PyAnySet_Check(targets)) {
      GUM_ERROR(gum::InvalidArgument,"The argument must be a set")
    }
    gum::NodeSet nodeset;
    PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(nodeset,targets,self->BN().variableNodeMap());
    return self->jointPosterior(nodeset);
  }

  void addJointTarget( PyObject* targets ) {
    if (! PyAnySet_Check(targets)) {
      GUM_ERROR(gum::InvalidArgument,"The argument must be a set")
    }
    gum::NodeSet nodeset;
    PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(nodeset,targets,self->BN().variableNodeMap());

    self->gum::JointTargetedInference<double>::addJointTarget(nodeset);
  }

  void eraseJointTarget( PyObject* targets ) {
    if (! PyAnySet_Check(targets)) {
      GUM_ERROR(gum::InvalidArgument,"The argument must be a set")
    }
    gum::NodeSet nodeset;
    PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(nodeset,targets,self->BN().variableNodeMap());
    self->gum::JointTargetedInference<double>::eraseJointTarget(nodeset);
  }

  bool isJointTarget( PyObject* targets ) {
    if (! PyAnySet_Check(targets)) {
      GUM_ERROR(gum::InvalidArgument,"The argument must be a set")
    }
    gum::NodeSet nodeset;
    PyAgrumHelper::populateNodeSetFromPySequenceOfIntOrString(nodeset,targets,self->BN().variableNodeMap());
    return self->gum::JointTargetedInference<double>::isJointTarget(nodeset);
  }

  // returns a new Python list holding one Python set per joint target,
  // or nullptr (with the Python error set) if building the list fails
  PyObject* jointTargets() const {
    PyObject* q = PyList_New( 0 );
    if ( q == nullptr ) return nullptr;

    for ( const auto& ns : self->gum::JointTargetedInference<double>::jointTargets()) {
      // PySetFromNodeSet is assumed to return a new reference: targets() hands its
      // result straight back to Python. PyList_Append does not steal that
      // reference, so it must be released here or every set would leak.
      PyObject* pyset = PyAgrumHelper::PySetFromNodeSet(ns);
      if ( pyset == nullptr ) {
        Py_DECREF( q );
        return nullptr;
      }
      const int status = PyList_Append( q, pyset );
      Py_DECREF( pyset );
      if ( status != 0 ) {
        Py_DECREF( q );
        return nullptr;
      }
    }
    return q;
  }
}
%enddef
IMPROVE_JOINT_INFERENCE_API(gum::LazyPropagation<double>)
IMPROVE_JOINT_INFERENCE_API(gum::ShaferShenoyInference<double>)
IMPROVE_JOINT_INFERENCE_API(gum::VariableElimination<double>)


// create a reference to python BN into python inference
// IMPROVE_EXACT_INFERENCE_API(T): after the proxy of T<double> is built, store its
// first positional argument (the BN) in self._model so the Python BN object is
// referenced by the engine. The junction tree returned by junctionTree() gets a
// back-reference to the engine in val._engine (presumably to keep the engine alive
// while the tree is used from Python).
%define IMPROVE_EXACT_INFERENCE_API(classname)
%pythonappend gum::classname<double>::classname %{
        self._model=args[0]
%}
%pythonappend gum::classname<double>::junctionTree %{
        val._engine=self
%}
%enddef
IMPROVE_EXACT_INFERENCE_API(LazyPropagation)
IMPROVE_EXACT_INFERENCE_API(ShaferShenoyInference)
IMPROVE_EXACT_INFERENCE_API(VariableElimination)


// IMPROVE_APPROX_INFERENCE_API(C,T): same BN reference as above for approximate
// engines; C is the unqualified creation-method name and T the full template type.
// NOTE(review): the appended code reads a variable named bn, i.e. it assumes the
// generated Python proxy names its first parameter bn -- confirm against the
// generated module.
%define IMPROVE_APPROX_INFERENCE_API(constructor,classname...)
%pythonappend gum::classname::constructor %{
        self._model=bn #BN
%}
%enddef
IMPROVE_APPROX_INFERENCE_API(GibbsSampling,GibbsSampling<double>)
IMPROVE_APPROX_INFERENCE_API(ImportanceSampling,ImportanceSampling<double>)
IMPROVE_APPROX_INFERENCE_API(WeightedSampling,WeightedSampling<double>)
IMPROVE_APPROX_INFERENCE_API(MonteCarloSampling,MonteCarloSampling<double>)
IMPROVE_APPROX_INFERENCE_API(LoopyBeliefPropagation,LoopyBeliefPropagation<double>)
IMPROVE_APPROX_INFERENCE_API(LoopySamplingInference,LoopySamplingInference<double,gum::GibbsSampling>)
IMPROVE_APPROX_INFERENCE_API(LoopySamplingInference,LoopySamplingInference<double,gum::ImportanceSampling>)
IMPROVE_APPROX_INFERENCE_API(LoopySamplingInference,LoopySamplingInference<double,gum::WeightedSampling>)
IMPROVE_APPROX_INFERENCE_API(LoopySamplingInference,LoopySamplingInference<double,gum::MonteCarloSampling>)