1 /**
2  *
3  *   Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU)
4  *   info_at_agrum_dot_org
5  *
6  *  This library is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU Lesser General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public License
17  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 #include <sstream>
23 
24 #include <gumtest/AgrumTestSuite.h>
25 #include <gumtest/testsuite_utils.h>
26 
27 #include <agrum/tools/multidim/ICIModels/multiDimNoisyORNet.h>
28 #include <agrum/tools/multidim/potential.h>
29 #include <agrum/tools/variables/labelizedVariable.h>
30 #include <agrum/tools/variables/rangeVariable.h>
31 
32 namespace gum_tests {
33 
34   class MultiDimNoisyORNetTestSuite: public CxxTest::TestSuite {
35     public:
testCreationNoisyOr()36     void testCreationNoisyOr() {
37       gum::LabelizedVariable            a("a", "", 2), b("b", "", 2), c("c", "", 2), d("d", "", 2);
38       gum::MultiDimNoisyORNet< double > p(0.2f);
39 
40       // trying to change weight for a non cause
41       TS_ASSERT_THROWS(p.causalWeight(b, 0.4f), gum::InvalidArgument)
42       TS_ASSERT_THROWS(p.causalWeight(d, 0.0f), gum::InvalidArgument)
43 
44       // adding causes
45       TS_GUM_ASSERT_THROWS_NOTHING(p << a << b << c << d);
46 
47       // trying to set 0 for causal weight
48       TS_ASSERT_THROWS(p.causalWeight(d, 0.0f), gum::OutOfBounds)
49 
50       // doing the right stuff :)
51       p.causalWeight(b, 0.4f);
52       p.causalWeight(d, 0.7f);
53 
54       TS_ASSERT_EQUALS(p.toString(),
55                        "a:Labelized(<0,1>)=noisyORNet([0.2],b:Labelized(<0,1>)[0.4]c:Labelized(<0,"
56                        "1>)[1]d:Labelized(<0,1>)[0.7])");
57       TS_ASSERT_EQUALS(p.realSize(), (gum::Size)4)
58 
59       gum::MultiDimNoisyORNet< double > q(p);
60       TS_ASSERT_EQUALS(q.toString(),
61                        "a:Labelized(<0,1>)=noisyORNet([0.2],b:Labelized(<0,1>)[0.4]c:Labelized(<0,"
62                        "1>)[1]d:Labelized(<0,1>)[0.7])");
63       TS_ASSERT_EQUALS(p.realSize(), (gum::Size)4)
64     }
65 
testCompatibleWithHardOR()66     void testCompatibleWithHardOR() {
67       gum::LabelizedVariable cold("Cold", "", 2);
68       gum::LabelizedVariable flu("Flu", "", 2);
69       gum::LabelizedVariable malaria("Malaria", "", 2);
70       gum::LabelizedVariable fever("Fever", "", 2);
71 
72       gum::MultiDimNoisyORNet< double > p(0.0f);
73       p << fever << malaria << flu << cold;
74       p.causalWeight(cold, 1.0f);
75       p.causalWeight(flu, 1.0f);
76       p.causalWeight(malaria, 1.0f);
77 
78       gum::Instantiation i(p);
79       float              witness[] = {1.0f,
80                          0.0f,
81                          0.0f,
82                          1.0f,
83                          0.0f,
84                          1.0f,
85                          0.0f,
86                          1.0f,
87                          0.0f,
88                          1.0f,
89                          0.0f,
90                          1.0f,
91                          0.0f,
92                          1.0f,
93                          0.0f,
94                          1.0f};
95 
96       int j = 0;
97 
98       for (i.setFirst(); !i.end(); ++i, j++) {
99         TS_ASSERT_DELTA(p[i], witness[j], 1e-6)
100       }
101     }
102 
testComputationInNoisyORNet()103     void testComputationInNoisyORNet() {
104       gum::LabelizedVariable cold("Cold", "", 2);
105       gum::LabelizedVariable flu("Flu", "", 2);
106       gum::LabelizedVariable malaria("Malaria", "", 2);
107       gum::LabelizedVariable fever("Fever", "", 2);
108 
109       gum::MultiDimNoisyORNet< double > p(0.0f);
110       p << fever << malaria << flu << cold;
111       p.causalWeight(cold, 0.4f);
112       p.causalWeight(flu, 0.8f);
113       p.causalWeight(malaria, 0.9f);
114 
115       gum::Instantiation i(p);
116       float              witness[] = {1,
117                          0,
118                          0.1f,
119                          0.9f,
120                          0.2f,
121                          0.8f,
122                          0.02f,
123                          0.98f,
124                          0.6f,
125                          0.4f,
126                          0.06f,
127                          0.94f,
128                          0.12f,
129                          0.88f,
130                          0.012f,
131                          0.988f};
132 
133       int j = 0;
134 
135       for (i.setFirst(); !i.end(); ++i, j++) {
136         TS_ASSERT_DELTA(p[i], witness[j], 1e-6)
137       }
138 
139       gum::MultiDimNoisyORNet< double > q(p);
140 
141       j = 0;
142 
143       for (i.setFirst(); !i.end(); ++i, j++) {
144         TS_ASSERT_DELTA(q[i], witness[j], 1e-6)
145       }
146     }
147 
testComputationInNoisyORNet2()148     void testComputationInNoisyORNet2() {
149       gum::LabelizedVariable lazy("lazy", "", 2);
150       gum::LabelizedVariable degree("degree", "", 2);
151       gum::LabelizedVariable motivation("motivation", "", 2);
152       gum::LabelizedVariable requirement("requirement", "", 2);
153       gum::LabelizedVariable competition("competition", "", 2);
154       gum::LabelizedVariable unemployment("unemployment", "", 2);
155 
156       gum::MultiDimNoisyORNet< double > p(0.0001f);
157       p << unemployment << competition << requirement << motivation << degree << lazy;
158       p.causalWeight(lazy, 0.1f);
159       p.causalWeight(degree, 0.3f);
160       p.causalWeight(motivation, 0.5f);
161       p.causalWeight(requirement, 0.7f);
162       p.causalWeight(competition, 0.9f);
163 
164       gum::Instantiation i(p);
165       float              witness[]
166          = {0.9999f,   0.0001f,   0.09999f,  0.90001f,  0.29997f,  0.70003f,  0.029997f, 0.970003f,
167             0.49995f,  0.50005f,  0.049995f, 0.950005f, 0.149985f, 0.850015f, 0.014999f, 0.985002f,
168             0.69993f,  0.30007f,  0.069993f, 0.930007f, 0.209979f, 0.790021f, 0.020998f, 0.979002f,
169             0.349965f, 0.650035f, 0.034997f, 0.965004f, 0.104990f, 0.895011f, 0.010499f, 0.989501f,
170             0.89991f,  0.10009f,  0.089991f, 0.910009f, 0.269973f, 0.730027f, 0.026997f, 0.973003f,
171             0.449955f, 0.550045f, 0.044996f, 0.955005f, 0.134987f, 0.865014f, 0.013499f, 0.986501f,
172             0.629937f, 0.370063f, 0.062994f, 0.937006f, 0.188981f, 0.811019f, 0.018898f, 0.981101f,
173             0.314969f, 0.685032f, 0.031497f, 0.968503f, 0.094491f, 0.905509f, 0.009449f, 0.990551f};
174 
175       int j = 0;
176 
177       for (i.setFirst(); !i.end(); ++i, j++) {
178         TS_ASSERT_DELTA(p[i], witness[j], 1e-6)
179       }
180 
181       gum::MultiDimNoisyORNet< double > q(p);
182 
183       j = 0;
184 
185       for (i.setFirst(); !i.end(); ++i, j++) {
186         TS_ASSERT_DELTA(q[i], witness[j], 1e-6)
187       }
188     }
189   };
190 }   // namespace gum_tests
191