1# -*- encoding: UTF-8 -*-
2import unittest
3
4import numpy
5import math
6import pyAgrum as gum
7from pyAgrumTestSuite import pyAgrumTestCase, addTests
8
9
class BayesNetFragmentTestCase(pyAgrumTestCase):
  """Tests for gum.BayesNetFragment: installing/uninstalling nodes and CPTs,
  consistency checks, inference on fragments and conversion to a BayesNet."""

  def fill(self):
    """Return the reference BN used by most tests.

    Arcs: v1->v3, v1->v4, v2->v4, v2->v5, v3->v5, v3->v6, v4->v5 (7 arcs).
    v5 and v6 have 3 states; the other variables use the fastBN default.
    """
    return gum.fastBN("v1;v2;v3;v4;v5[3];v6[3]<-v3<-v1->v4<-v2->v5<-v3;v4->v5")

  def fill2(self, bn1):
    """Return a BN with the structure of ``fill()`` minus the arc v4->v5.

    Every CPT except v5's is copied from ``bn1``; v5 (parents v2 and v3 only)
    keeps the CPT created by fastBN.
    """
    bn2 = gum.fastBN("v1;v2;v3;v4;v5[3];v6[3]<-v3<-v1->v4<-v2->v5<-v3")
    for n in bn1.names():
      if n != "v5":
        bn2.cpt(n).fillWith(bn1.cpt(n))
    return bn2

  def _assertPosteriorsAlmostEqual(self, expected, actual, msg=None):
    """Assert that two posteriors hold the same values within 1e-5.

    The lengths are compared first because ``zip`` would otherwise silently
    ignore trailing values of the longer list.
    """
    expected_values = expected.tolist()
    actual_values = actual.tolist()
    self.assertEqual(len(expected_values), len(actual_values), msg=msg)
    for x1, x2 in zip(expected_values, actual_values):
      self.assertAlmostEqual(x1, x2, delta=1e-5, msg=msg)

  def testCreation(self):
    # only checks that a fragment can be built from a BN and from a fragment
    bn = self.fill()
    frag = gum.BayesNetFragment(bn)
    gum.BayesNetFragment(frag)

  def testInstallNodes(self):
    bn = self.fill()
    frag = gum.BayesNetFragment(bn)

    self.assertEqual(frag.size(), 0)
    frag.installNode("v1")
    self.assertEqual(frag.size(), 1)
    self.assertFalse(frag.empty())

    frag.installNode("v1")  # once again : no effect
    self.assertEqual(frag.size(), 1)
    self.assertEqual(frag.sizeArcs(), 0)

    frag.installNode("v6")  # second node : v1 and v6 are not adjacent, no arc
    self.assertEqual(frag.size(), 2)
    self.assertEqual(frag.sizeArcs(), 0)

    frag.installNode("v3")  # third node : brings arcs v1->v3 and v3->v6
    self.assertEqual(frag.size(), 3)
    self.assertEqual(frag.sizeArcs(), 2)

    frag.installAscendants("v6")  # v3 and v1 already installed : no effect
    self.assertEqual(frag.size(), 3)
    self.assertEqual(frag.sizeArcs(), 2)

    frag.installAscendants("v5")  # adds v2, v4, v5 : the whole BN
    self.assertEqual(frag.size(), 6)
    self.assertEqual(frag.sizeArcs(), 7)

    frag2 = gum.BayesNetFragment(bn)
    frag2.installAscendants("v5")  # every node but v6, every arc but v3->v6
    self.assertEqual(frag2.size(), 5)
    self.assertEqual(frag2.sizeArcs(), 6)

  def testUninstallNode(self):
    bn = self.fill()
    frag = gum.BayesNetFragment(bn)
    frag.installAscendants("v6")
    self.assertEqual(frag.size(), 3)
    self.assertEqual(frag.sizeArcs(), 2)

    # removing v3 also removes both arcs v1->v3 and v3->v6
    frag.uninstallNode("v3")
    self.assertEqual(frag.size(), 2)
    self.assertEqual(frag.sizeArcs(), 0)

  def testBayesNetMethods(self):
    bn = self.fill()
    frag = gum.BayesNetFragment(bn)

    self.assertTrue(frag.empty())
    frag.installNode("v1")
    self.assertFalse(frag.empty())
    frag.installNode("v6")

    self.assertEqual(frag.dag().sizeArcs(), 0)
    self.assertEqual(frag.size(), 2)
    self.assertEqual(frag.dim(), (3 - 1) + (2 - 1))
    # 10**log10(x) is not exactly x in floating point : compare with a tolerance
    self.assertAlmostEqual(
      pow(10, frag.log10DomainSize()), 2 * 3, delta=1e-5)

    frag.installAscendants("v6")

    self.assertEqual(frag.dag().sizeArcs(), 2)
    self.assertEqual(frag.size(), 3)
    self.assertEqual(frag.dim(), (2 * (3 - 1)) + (2 * (2 - 1)) + (2 - 1))
    self.assertAlmostEqual(
      pow(10, frag.log10DomainSize()), 2 * 2 * 3, delta=1e-5)

    inst = frag.completeInstantiation()
    inst.setFirst()
    self.assertEqual(str(inst), "<v1:0|v3:0|v6:0>")

    # the joint of the fragment {v1, v3, v6} must be the product of their CPTs
    while not inst.end():
      p = bn.cpt("v1").get(inst) * bn.cpt("v3").get(inst) * bn.cpt('v6').get(inst)
      # delta must be passed by keyword : the 3rd positional arg is `places`
      self.assertAlmostEqual(frag.jointProbability(inst), p, delta=1e-5)
      self.assertAlmostEqual(frag.log2JointProbability(inst),
                             math.log2(p), delta=1e-5)
      inst.inc()

  def testRelevantReasonning(self):
    # An inference on the whole bn with a hard evidence on v3 and an inference
    # on the relevant fragment, where v3 gets an equivalent local marginal,
    # should give the same posterior for v6.
    bn = self.fill()
    inf_complete = gum.LazyPropagation(bn)
    inf_complete.setEvidence({"v3": 1})
    inf_complete.makeInference()
    expected = inf_complete.posterior("v6")

    frag = gum.BayesNetFragment(bn)
    frag.installAscendants("v6")
    marg = gum.Potential().add(frag.variable("v3"))
    marg.fillWith([0, 1])
    frag.installMarginal("v3", marg)
    # the local marginal drops the arc v1->v3 : only v3->v6 remains
    self.assertEqual(frag.size(), 3)
    self.assertEqual(frag.sizeArcs(), 1)
    inf_frag = gum.LazyPropagation(frag)
    inf_frag.makeInference()

    self._assertPosteriorsAlmostEqual(expected, inf_frag.posterior("v6"))

  def testInstallCPTs(self):
    bn = self.fill()
    frag = gum.BayesNetFragment(bn)
    frag.installAscendants("v6")
    self.assertEqual(frag.size(), 3)
    self.assertEqual(frag.sizeArcs(), 2)
    for nod in frag.nodes():
      self.assertTrue(frag.checkConsistency(nod))
    self.assertTrue(frag.checkConsistency())

    frag.installNode("v5")
    # v1->v3->v6 and v3->v5, but v5 does not have all its parents (v2, v3, v4)
    with self.assertRaises(gum.NotFound):
      frag.variable("v4").name()
    with self.assertRaises(gum.NotFound):
      frag.variable(bn.idFromName("v2")).name()
    self.assertEqual(frag.size(), 4)
    self.assertEqual(frag.sizeArcs(), 3)
    self.assertFalse(frag.checkConsistency())
    self.assertFalse(frag.checkConsistency("v5"))
    for nod in frag.nodes():
      if frag.variable(nod).name() != "v5":
        self.assertTrue(frag.checkConsistency(nod))

    # a local marginal for v5 makes it parentless : arc v3->v5 is dropped
    newV5 = gum.Potential().add(frag.variable("v5"))
    newV5.fillWith([0, 0, 1])
    frag.installMarginal("v5", newV5)
    for nod in frag.nodes():
      self.assertTrue(frag.checkConsistency(nod))
    self.assertTrue(frag.checkConsistency())
    self.assertEqual(frag.size(), 4)
    self.assertEqual(frag.sizeArcs(), 2)

    frag.installAscendants("v4")
    self.assertFalse(frag.checkConsistency())
    self.assertEqual(frag.size(), 6)
    self.assertEqual(frag.sizeArcs(), 6)

    # back to the CPT of the referenced bn : all 7 arcs are present again
    frag.uninstallCPT("v5")
    for nod in frag.nodes():
      self.assertTrue(frag.checkConsistency(nod))
    self.assertTrue(frag.checkConsistency())
    self.assertEqual(frag.size(), 6)
    self.assertEqual(frag.sizeArcs(), 7)

    # v5's CPT still mentions v4, which is no longer in the fragment
    frag.uninstallNode("v4")
    self.assertFalse(frag.checkConsistency())
    self.assertEqual(frag.size(), 5)
    self.assertEqual(frag.sizeArcs(), 4)

    newV5bis = gum.Potential().add(frag.variable("v5")).add(
      frag.variable("v2")).add(frag.variable("v3"))
    frag.installCPT("v5", newV5bis)
    self.assertTrue(frag.checkConsistency())
    self.assertEqual(frag.size(), 5)
    self.assertEqual(frag.sizeArcs(), 4)

  def testInferenceWithLocalsCPT(self):
    # a fragment of bn where v5 gets a local CPT without v4 must behave like
    # bn2, whose structure has no arc v4->v5
    bn = self.fill()
    bn2 = self.fill2(bn)
    frag = gum.BayesNetFragment(bn)
    for i in bn.nodes():
      frag.installNode(i)
    self.assertTrue(frag.checkConsistency())
    self.assertEqual(frag.size(), 6)
    self.assertEqual(frag.sizeArcs(), 7)

    newV5 = gum.Potential().add(frag.variable("v5")).add(
      frag.variable("v2")).add(frag.variable("v3"))
    newV5.fillWith(bn2.cpt("v5"))
    frag.installCPT("v5", newV5)
    self.assertTrue(frag.checkConsistency())
    self.assertEqual(frag.size(), 6)
    self.assertEqual(frag.sizeArcs(), 6)

    ie2 = gum.LazyPropagation(bn2)
    ie2.makeInference()
    ie = gum.LazyPropagation(frag)
    ie.makeInference()

    for n in frag.names():
      self._assertPosteriorsAlmostEqual(
        ie2.posterior(n), ie.posterior(n), msg="For variable '{}'".format(n))

  def testCopyToBN(self):
    bn = gum.fastBN("A->B->C->D;E<-C<-F")
    self.assertEqual(repr(bn.cpt("B").variable(1)), repr(bn.variable("A")))

    frag = gum.BayesNetFragment(bn)

    # B alone is inconsistent (its parent A is missing) : cannot build a BN
    frag.installNode("B")
    self.assertFalse(frag.checkConsistency())
    with self.assertRaises(gum.OperationNotAllowed):
      frag.toBN()

    # check that nodes and potentials are shared by reference between bn and
    # frag, and that toBN() then creates fresh copies of them in minibn
    frag.installNode("A")
    self.assertTrue(frag.checkConsistency())
    self.assertEqual(repr(bn.variable("A")), repr(frag.variable("A")))
    self.assertEqual(repr(bn.variable("B")), repr(frag.variable("B")))
    self.assertEqual(str(bn.cpt("A")), str(frag.cpt("A")))
    self.assertEqual(str(bn.cpt("B")), str(frag.cpt("B")))
    self.assertEqual(repr(frag.cpt("B").variable(1)), repr(bn.variable("A")))
    self.assertEqual(repr(frag.cpt("B").variable(1)), repr(frag.variable("A")))

    minibn = frag.toBN()
    self.assertEqual(minibn.size(), 2)
    self.assertEqual(minibn.sizeArcs(), 1)
    self.assertNotEqual(repr(bn.variable("A")), repr(minibn.variable("A")))
    self.assertNotEqual(repr(bn.variable("B")), repr(minibn.variable("B")))
    self.assertEqual(str(bn.cpt("A")), str(minibn.cpt("A")))
    self.assertEqual(str(bn.cpt("B")), str(minibn.cpt("B")))
    self.assertEqual(repr(minibn.cpt("B").variable(1)), repr(minibn.variable("A")))
    self.assertNotEqual(repr(minibn.cpt("B").variable(1)), repr(frag.variable("A")))
243
244
ts = unittest.TestSuite()
addTests(ts, BayesNetFragmentTestCase)

if __name__ == "__main__":
  # allow running this module directly; importing it only builds `ts`
  unittest.TextTestRunner(verbosity=2).run(ts)
247