1 /** 2 * 3 * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(_at_LIP6) & Christophe GONZALES(_at_AMU) 4 * info_at_agrum_dot_org 5 * 6 * This library is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Lesser General Public License as published by 8 * the Free Software Foundation, either version 3 of the License, or 9 * (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 * GNU Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public License 17 * along with this library. If not, see <http://www.gnu.org/licenses/>. 18 * 19 */ 20 21 22 #include <sstream> 23 24 #include <gumtest/AgrumTestSuite.h> 25 #include <gumtest/testsuite_utils.h> 26 27 #include <agrum/tools/multidim/ICIModels/multiDimNoisyORNet.h> 28 #include <agrum/tools/multidim/potential.h> 29 #include <agrum/tools/variables/labelizedVariable.h> 30 #include <agrum/tools/variables/rangeVariable.h> 31 32 namespace gum_tests { 33 34 class MultiDimNoisyORNetTestSuite: public CxxTest::TestSuite { 35 public: testCreationNoisyOr()36 void testCreationNoisyOr() { 37 gum::LabelizedVariable a("a", "", 2), b("b", "", 2), c("c", "", 2), d("d", "", 2); 38 gum::MultiDimNoisyORNet< double > p(0.2f); 39 40 // trying to change weight for a non cause 41 TS_ASSERT_THROWS(p.causalWeight(b, 0.4f), gum::InvalidArgument) 42 TS_ASSERT_THROWS(p.causalWeight(d, 0.0f), gum::InvalidArgument) 43 44 // adding causes 45 TS_GUM_ASSERT_THROWS_NOTHING(p << a << b << c << d); 46 47 // trying to set 0 for causal weight 48 TS_ASSERT_THROWS(p.causalWeight(d, 0.0f), gum::OutOfBounds) 49 50 // doing the right stuff :) 51 p.causalWeight(b, 0.4f); 52 p.causalWeight(d, 0.7f); 53 54 TS_ASSERT_EQUALS(p.toString(), 55 "a:Labelized(<0,1>)=noisyORNet([0.2],b:Labelized(<0,1>)[0.4]c:Labelized(<0," 56 "1>)[1]d:Labelized(<0,1>)[0.7])"); 57 TS_ASSERT_EQUALS(p.realSize(), (gum::Size)4) 58 59 gum::MultiDimNoisyORNet< double > q(p); 60 TS_ASSERT_EQUALS(q.toString(), 61 "a:Labelized(<0,1>)=noisyORNet([0.2],b:Labelized(<0,1>)[0.4]c:Labelized(<0," 62 "1>)[1]d:Labelized(<0,1>)[0.7])"); 63 TS_ASSERT_EQUALS(p.realSize(), (gum::Size)4) 64 } 65 testCompatibleWithHardOR()66 void testCompatibleWithHardOR() { 67 gum::LabelizedVariable cold("Cold", "", 2); 68 gum::LabelizedVariable flu("Flu", "", 2); 69 gum::LabelizedVariable malaria("Malaria", "", 2); 70 gum::LabelizedVariable fever("Fever", "", 2); 71 72 gum::MultiDimNoisyORNet< double > p(0.0f); 73 p << fever << malaria << flu << cold; 74 p.causalWeight(cold, 1.0f); 75 p.causalWeight(flu, 1.0f); 76 p.causalWeight(malaria, 1.0f); 77 78 gum::Instantiation i(p); 79 float witness[] = {1.0f, 80 0.0f, 81 0.0f, 82 1.0f, 83 0.0f, 84 1.0f, 85 0.0f, 86 1.0f, 87 0.0f, 88 1.0f, 89 0.0f, 90 1.0f, 91 0.0f, 92 1.0f, 93 0.0f, 94 1.0f}; 95 96 int j = 0; 97 98 for (i.setFirst(); !i.end(); ++i, j++) { 99 TS_ASSERT_DELTA(p[i], witness[j], 1e-6) 100 } 101 } 102 testComputationInNoisyORNet()103 void testComputationInNoisyORNet() { 104 gum::LabelizedVariable cold("Cold", "", 2); 105 gum::LabelizedVariable flu("Flu", "", 2); 106 gum::LabelizedVariable malaria("Malaria", "", 2); 107 gum::LabelizedVariable fever("Fever", "", 2); 108 109 gum::MultiDimNoisyORNet< double > p(0.0f); 110 p << fever << malaria << flu << cold; 111 p.causalWeight(cold, 0.4f); 112 p.causalWeight(flu, 0.8f); 113 p.causalWeight(malaria, 0.9f); 114 115 gum::Instantiation i(p); 116 float witness[] = {1, 117 0, 118 0.1f, 119 0.9f, 120 0.2f, 121 0.8f, 122 0.02f, 123 0.98f, 124 0.6f, 125 0.4f, 126 0.06f, 127 0.94f, 128 0.12f, 129 0.88f, 130 0.012f, 131 0.988f}; 132 133 int j = 0; 134 135 for (i.setFirst(); !i.end(); ++i, j++) { 136 TS_ASSERT_DELTA(p[i], witness[j], 1e-6) 137 } 138 139 gum::MultiDimNoisyORNet< double > q(p); 140 141 j = 0; 142 143 for (i.setFirst(); !i.end(); ++i, j++) { 144 TS_ASSERT_DELTA(q[i], witness[j], 1e-6) 145 } 146 } 147 testComputationInNoisyORNet2()148 void testComputationInNoisyORNet2() { 149 gum::LabelizedVariable lazy("lazy", "", 2); 150 gum::LabelizedVariable degree("degree", "", 2); 151 gum::LabelizedVariable motivation("motivation", "", 2); 152 gum::LabelizedVariable requirement("requirement", "", 2); 153 gum::LabelizedVariable competition("competition", "", 2); 154 gum::LabelizedVariable unemployment("unemployment", "", 2); 155 156 gum::MultiDimNoisyORNet< double > p(0.0001f); 157 p << unemployment << competition << requirement << motivation << degree << lazy; 158 p.causalWeight(lazy, 0.1f); 159 p.causalWeight(degree, 0.3f); 160 p.causalWeight(motivation, 0.5f); 161 p.causalWeight(requirement, 0.7f); 162 p.causalWeight(competition, 0.9f); 163 164 gum::Instantiation i(p); 165 float witness[] 166 = {0.9999f, 0.0001f, 0.09999f, 0.90001f, 0.29997f, 0.70003f, 0.029997f, 0.970003f, 167 0.49995f, 0.50005f, 0.049995f, 0.950005f, 0.149985f, 0.850015f, 0.014999f, 0.985002f, 168 0.69993f, 0.30007f, 0.069993f, 0.930007f, 0.209979f, 0.790021f, 0.020998f, 0.979002f, 169 0.349965f, 0.650035f, 0.034997f, 0.965004f, 0.104990f, 0.895011f, 0.010499f, 0.989501f, 170 0.89991f, 0.10009f, 0.089991f, 0.910009f, 0.269973f, 0.730027f, 0.026997f, 0.973003f, 171 0.449955f, 0.550045f, 0.044996f, 0.955005f, 0.134987f, 0.865014f, 0.013499f, 0.986501f, 172 0.629937f, 0.370063f, 0.062994f, 0.937006f, 0.188981f, 0.811019f, 0.018898f, 0.981101f, 173 0.314969f, 0.685032f, 0.031497f, 0.968503f, 0.094491f, 0.905509f, 0.009449f, 0.990551f}; 174 175 int j = 0; 176 177 for (i.setFirst(); !i.end(); ++i, j++) { 178 TS_ASSERT_DELTA(p[i], witness[j], 1e-6) 179 } 180 181 gum::MultiDimNoisyORNet< double > q(p); 182 183 j = 0; 184 185 for (i.setFirst(); !i.end(); ++i, j++) { 186 TS_ASSERT_DELTA(q[i], witness[j], 1e-6) 187 } 188 } 189 }; 190 } // namespace gum_tests 191