# -*- encoding: UTF-8 -*-
"""Unit tests for ``gum.BayesNetFragment``.

The reference network built by ``BayesNetFragmentTestCase.fill`` has the
arcs v1->v3, v1->v4, v2->v4, v2->v5, v3->v5, v3->v6 and v4->v5 (7 arcs);
v5 and v6 have 3 states, every other variable is binary.
"""
import math
import unittest

import numpy
import pyAgrum as gum

from pyAgrumTestSuite import pyAgrumTestCase, addTests


class BayesNetFragmentTestCase(pyAgrumTestCase):
    """Tests for building, editing and reasoning with Bayes net fragments."""

    def fill(self):
        """Return the reference network described in the module docstring."""
        return gum.fastBN("v1;v2;v3;v4;v5[3];v6[3]<-v3<-v1->v4<-v2->v5<-v3;v4->v5")

    def fill2(self, bn1):
        """Return a variant of ``fill()``'s network without the arc v4->v5.

        Every CPT except v5's is copied from *bn1*; v5's CPT (now with parents
        v2 and v3 only) is left as created by ``fastBN``.
        """
        bn2 = gum.fastBN("v1;v2;v3;v4;v5[3];v6[3]<-v3<-v1->v4<-v2->v5<-v3")
        for name in bn1.names():
            if name != "v5":
                bn2.cpt(name).fillWith(bn1.cpt(name))
        return bn2

    def testCreation(self):
        """A fragment can be built over a BN and over another fragment."""
        bn = self.fill()
        frag = gum.BayesNetFragment(bn)
        gum.BayesNetFragment(frag)

    def testInstallNodes(self):
        """Installing nodes/ascendants brings in exactly the arcs between them."""
        bn = self.fill()
        frag = gum.BayesNetFragment(bn)

        self.assertEqual(frag.size(), 0)
        frag.installNode("v1")
        self.assertEqual(frag.size(), 1)
        self.assertFalse(frag.empty())

        frag.installNode("v1")  # once again: no effect
        self.assertEqual(frag.size(), 1)
        self.assertEqual(frag.sizeArcs(), 0)

        # second node: its parent v3 is not installed, so still no arc
        frag.installNode("v6")
        self.assertEqual(frag.size(), 2)
        self.assertEqual(frag.sizeArcs(), 0)

        # third node: brings in the arcs v1->v3 and v3->v6
        frag.installNode("v3")
        self.assertEqual(frag.size(), 3)
        self.assertEqual(frag.sizeArcs(), 2)

        frag.installAscendants("v6")  # all ascendants already there
        self.assertEqual(frag.size(), 3)
        self.assertEqual(frag.sizeArcs(), 2)

        # adds v2, v4 and v5: the fragment now covers the whole network
        frag.installAscendants("v5")
        self.assertEqual(frag.size(), 6)
        self.assertEqual(frag.sizeArcs(), 7)

        # ascendants of v5 alone: every node but v6, every arc but v3->v6
        frag2 = gum.BayesNetFragment(bn)
        frag2.installAscendants("v5")
        self.assertEqual(frag2.size(), 5)
        self.assertEqual(frag2.sizeArcs(), 6)

    def testUninstallNode(self):
        """Uninstalling a node also removes the arcs incident to it."""
        bn = self.fill()
        frag = gum.BayesNetFragment(bn)
        frag.installAscendants("v6")
        self.assertEqual(frag.size(), 3)
        self.assertEqual(frag.sizeArcs(), 2)

        frag.uninstallNode("v3")
        self.assertEqual(frag.size(), 2)
        self.assertEqual(frag.sizeArcs(), 0)

    def testBayesNetMethods(self):
        """Structural and probabilistic BN queries work on a fragment."""
        bn = self.fill()
        frag = gum.BayesNetFragment(bn)

        self.assertTrue(frag.empty())
        frag.installNode("v1")
        self.assertFalse(frag.empty())
        frag.installNode("v6")

        self.assertEqual(frag.dag().sizeArcs(), 0)
        self.assertEqual(frag.size(), 2)
        self.assertEqual(frag.dim(), (3 - 1) + (2 - 1))
        # 10 ** log10(6) is not guaranteed to be exactly 6.0
        self.assertAlmostEqual(
            pow(10, frag.log10DomainSize()), 2 * 3, delta=1e-5)

        frag.installAscendants("v6")

        self.assertEqual(frag.dag().sizeArcs(), 2)
        self.assertEqual(frag.size(), 3)
        self.assertEqual(frag.dim(), (2 * (3 - 1)) + (2 * (2 - 1)) + (2 - 1))
        self.assertAlmostEqual(
            pow(10, frag.log10DomainSize()), 2 * 2 * 3, delta=1e-5)

        inst = frag.completeInstantiation()
        inst.setFirst()
        self.assertEqual(str(inst), "<v1:0|v3:0|v6:0>")

        while not inst.end():
            p = bn.cpt("v1").get(inst) * bn.cpt("v3").get(inst) * bn.cpt("v6").get(inst)
            # delta must be passed by keyword: the third positional argument
            # of assertAlmostEqual is `places`, which must be an int
            self.assertAlmostEqual(frag.jointProbability(inst), p, delta=1e-5)
            self.assertAlmostEqual(frag.log2JointProbability(inst),
                                   math.log2(p), delta=1e-5)
            inst.inc()

    def testRelevantReasonning(self):
        """Hard evidence on the full BN matches a local marginal in a fragment.

        An inference on the whole network with the hard evidence v3=1 and an
        inference on the ascendant fragment of v6, where v3's CPT is replaced
        by the marginal [0, 1], should give the same posterior for v6.
        """
        bn = self.fill()
        inf_complete = gum.LazyPropagation(bn)
        inf_complete.setEvidence({"v3": 1})
        inf_complete.makeInference()
        expected = inf_complete.posterior("v6").tolist()

        frag = gum.BayesNetFragment(bn)
        frag.installAscendants("v6")
        marg = gum.Potential().add(frag.variable("v3"))
        marg.fillWith([0, 1])
        frag.installMarginal("v3", marg)
        self.assertEqual(frag.size(), 3)
        # the local marginal drops the arc v1->v3: only v3->v6 remains
        self.assertEqual(frag.sizeArcs(), 1)
        inf_frag = gum.LazyPropagation(frag)
        inf_frag.makeInference()
        actual = inf_frag.posterior("v6").tolist()

        # zip() would silently truncate if the posteriors had different sizes
        self.assertEqual(len(expected), len(actual))
        for x1, x2 in zip(expected, actual):
            self.assertAlmostEqual(x1, x2, delta=1e-5)

    def testInstallCPTs(self):
        """Local CPTs/marginals change the fragment's arcs and consistency."""
        bn = self.fill()
        frag = gum.BayesNetFragment(bn)
        frag.installAscendants("v6")
        self.assertEqual(frag.size(), 3)
        self.assertEqual(frag.sizeArcs(), 2)
        for nod in frag.nodes():
            self.assertTrue(frag.checkConsistency(nod))
        self.assertTrue(frag.checkConsistency())

        frag.installNode("v5")
        # 1->3->6 and 3->5, but v5 does not have all its parents (v2, v3 and v4)
        with self.assertRaises(gum.NotFound):
            frag.variable("v4").name()
        with self.assertRaises(gum.NotFound):
            frag.variable(bn.idFromName("v2")).name()
        self.assertEqual(frag.size(), 4)
        self.assertEqual(frag.sizeArcs(), 3)
        self.assertFalse(frag.checkConsistency())
        self.assertFalse(frag.checkConsistency("v5"))
        for nod in frag.nodes():
            if frag.variable(nod).name() != "v5":
                self.assertTrue(frag.checkConsistency(nod))

        # a local marginal for v5 removes its incoming arc v3->v5
        newV5 = gum.Potential().add(frag.variable("v5"))
        newV5.fillWith([0, 0, 1])
        frag.installMarginal("v5", newV5)
        for nod in frag.nodes():
            self.assertTrue(frag.checkConsistency(nod))
        self.assertTrue(frag.checkConsistency())
        self.assertEqual(frag.size(), 4)
        self.assertEqual(frag.sizeArcs(), 2)

        frag.installAscendants("v4")
        self.assertFalse(frag.checkConsistency())
        self.assertEqual(frag.size(), 6)
        self.assertEqual(frag.sizeArcs(), 6)

        # back to the CPT of the underlying BN for v5
        frag.uninstallCPT("v5")
        for nod in frag.nodes():
            self.assertTrue(frag.checkConsistency(nod))
        self.assertTrue(frag.checkConsistency())
        self.assertEqual(frag.size(), 6)
        self.assertEqual(frag.sizeArcs(), 7)

        frag.uninstallNode("v4")
        self.assertFalse(frag.checkConsistency())
        self.assertEqual(frag.size(), 5)
        self.assertEqual(frag.sizeArcs(), 4)

        # a local CPT whose parents (v2, v3) are all installed restores consistency
        newV5bis = gum.Potential().add(frag.variable("v5")).add(
            frag.variable("v2")).add(frag.variable("v3"))
        frag.installCPT("v5", newV5bis)
        self.assertTrue(frag.checkConsistency())
        self.assertEqual(frag.size(), 5)
        self.assertEqual(frag.sizeArcs(), 4)

    def testInferenceWithLocalsCPT(self):
        """Inference on a fragment with a local CPT matches the equivalent BN."""
        bn = self.fill()
        bn2 = self.fill2(bn)
        frag = gum.BayesNetFragment(bn)
        for i in bn.nodes():
            frag.installNode(i)
        self.assertTrue(frag.checkConsistency())
        self.assertEqual(frag.size(), 6)
        self.assertEqual(frag.sizeArcs(), 7)

        # give v5 the same parents and CPT as in bn2 (drops the arc v4->v5)
        newV5 = gum.Potential().add(frag.variable("v5")).add(
            frag.variable("v2")).add(frag.variable("v3"))
        newV5.fillWith(bn2.cpt("v5"))
        frag.installCPT("v5", newV5)
        self.assertTrue(frag.checkConsistency())
        self.assertEqual(frag.size(), 6)
        self.assertEqual(frag.sizeArcs(), 6)

        ie2 = gum.LazyPropagation(bn2)
        ie2.makeInference()
        ie = gum.LazyPropagation(frag)
        ie.makeInference()

        for n in frag.names():
            expected = ie2.posterior(n).tolist()
            actual = ie.posterior(n).tolist()
            self.assertEqual(len(expected), len(actual),
                             msg="For variable '{}'".format(n))
            for x1, x2 in zip(expected, actual):
                self.assertAlmostEqual(
                    x1, x2, delta=1e-5, msg="For variable '{}'".format(n))

    def testCopyToBN(self):
        """``toBN()`` copies variables and CPTs of a consistent fragment."""
        bn = gum.fastBN("A->B->C->D;E<-C<-F")
        self.assertEqual(repr(bn.cpt("B").variable(1)), repr(bn.variable("A")))

        frag = gum.BayesNetFragment(bn)

        # B without its parent A: inconsistent, cannot be exported
        frag.installNode("B")
        self.assertFalse(frag.checkConsistency())
        with self.assertRaises(gum.OperationNotAllowed):
            frag.toBN()

        # checking that the nodes are shared (same variables) by bn and frag,
        # and that the potentials are identical
        frag.installNode("A")
        self.assertTrue(frag.checkConsistency())
        self.assertEqual(repr(bn.variable("A")), repr(frag.variable("A")))
        self.assertEqual(repr(bn.variable("B")), repr(frag.variable("B")))
        self.assertEqual(str(bn.cpt("A")), str(frag.cpt("A")))
        self.assertEqual(str(bn.cpt("B")), str(frag.cpt("B")))
        self.assertEqual(repr(frag.cpt("B").variable(1)), repr(bn.variable("A")))
        self.assertEqual(repr(frag.cpt("B").variable(1)), repr(frag.variable("A")))

        # minibn owns copies of the variables but the same CPT values
        minibn = frag.toBN()
        self.assertEqual(minibn.size(), 2)
        self.assertEqual(minibn.sizeArcs(), 1)
        self.assertNotEqual(repr(bn.variable("A")), repr(minibn.variable("A")))
        self.assertNotEqual(repr(bn.variable("B")), repr(minibn.variable("B")))
        self.assertEqual(str(bn.cpt("A")), str(minibn.cpt("A")))
        self.assertEqual(str(bn.cpt("B")), str(minibn.cpt("B")))
        self.assertEqual(repr(minibn.cpt("B").variable(1)), repr(minibn.variable("A")))
        self.assertNotEqual(repr(minibn.cpt("B").variable(1)), repr(frag.variable("A")))


ts = unittest.TestSuite()
addTests(ts, BayesNetFragmentTestCase)