1 // Copyright (c) 2010-2021, Lawrence Livermore National Security, LLC. Produced 2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files 3 // LICENSE and NOTICE for details. LLNL-CODE-806117. 4 // 5 // This file is part of the MFEM library. For more information and source code 6 // availability visit https://mfem.org. 7 // 8 // MFEM is free software; you can redistribute it and/or modify it under the 9 // terms of the BSD-3 license. We welcome feedback and contributions, see file 10 // CONTRIBUTING.md for details. 11 12 #include "mfem.hpp" 13 #include "unit_tests.hpp" 14 #include "linalg/dtensor.hpp" 15 16 using namespace mfem; 17 18 TEST_CASE("DenseMatrix init-list construction", "[DenseMatrix]") 19 { 20 double ContigData[6] = {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}; 21 DenseMatrix Contiguous(ContigData, 2, 3); 22 23 DenseMatrix Nested( 24 { 25 {6.0, 4.0, 2.0}, 26 {5.0, 3.0, 1.0} 27 }); 28 29 for (int i = 0; i < Contiguous.Height(); i++) 30 { 31 for (int j = 0; j < Contiguous.Width(); j++) 32 { 33 REQUIRE(Nested(i,j) == Contiguous(i,j)); 34 } 35 } 36 } 37 38 TEST_CASE("DenseMatrix LinearSolve methods", 39 "[DenseMatrix]") 40 { 41 SECTION("singular_system") 42 { 43 constexpr int N = 3; 44 45 DenseMatrix A(N); 46 A.SetRow(0, 0.0); 47 A.SetRow(1, 0.0); 48 A.SetRow(2, 0.0); 49 50 double X[3]; 51 52 REQUIRE_FALSE(LinearSolve(A,X)); 53 } 54 55 SECTION("1x1_system") 56 { 57 constexpr int N = 1; 58 DenseMatrix A(N); 59 A(0,0) = 2; 60 61 double X[1] = { 12 }; 62 63 REQUIRE(LinearSolve(A,X)); 64 REQUIRE(X[0] == MFEM_Approx(6)); 65 } 66 67 SECTION("2x2_system") 68 { 69 constexpr int N = 2; 70 71 DenseMatrix A(N); 72 A(0,0) = 2.0; A(0,1) = 1.0; 73 A(1,0) = 3.0; A(1,1) = 4.0; 74 75 double X[2] = { 1, 14 }; 76 77 REQUIRE(LinearSolve(A,X)); 78 REQUIRE(X[0] == MFEM_Approx(-2)); 79 REQUIRE(X[1] == MFEM_Approx(5)); 80 } 81 82 SECTION("3x3_system") 83 { 84 constexpr int N = 3; 85 86 DenseMatrix A(N); 87 A(0,0) = 4; A(0,1) = 5; A(0,2) = -2; 88 A(1,0) = 7; A(1,1) = -1; A(1,2) = 2; 89 A(2,0) = 3; A(2,1) = 1; A(2,2) = 4; 90 91 double X[3] = { -14, 42, 28 }; 92 93 REQUIRE(LinearSolve(A,X)); 94 REQUIRE(X[0] == MFEM_Approx(4)); 95 REQUIRE(X[1] == MFEM_Approx(-4)); 96 REQUIRE(X[2] == MFEM_Approx(5)); 97 } 98 99 } 100 101 TEST_CASE("DenseMatrix A*B^T methods", 102 "[DenseMatrix]") 103 { 104 double tol = 1e-12; 105 106 double AtData[6] = {6.0, 5.0, 107 4.0, 3.0, 108 2.0, 1.0 109 }; 110 double BtData[12] = {1.0, 3.0, 5.0, 7.0, 111 2.0, 4.0, 6.0, 8.0, 112 1.0, 2.0, 3.0, 5.0 113 }; 114 115 DenseMatrix A(AtData, 2, 3); 116 DenseMatrix B(BtData, 4, 3); 117 DenseMatrix C(2,4); 118 119 SECTION("MultABt") 120 { 121 double BData[12] = {1.0, 2.0, 1.0, 122 3.0, 4.0, 2.0, 123 5.0, 6.0, 3.0, 124 7.0, 8.0, 5.0 125 }; 126 DenseMatrix Bt(BData, 3, 4); 127 128 double CtData[8] = {16.0, 12.0, 129 38.0, 29.0, 130 60.0, 46.0, 131 84.0, 64.0 132 }; 133 DenseMatrix Cexact(CtData, 2, 4); 134 135 MultABt(A, B, C); 136 C.Add(-1.0, Cexact); 137 138 REQUIRE(C.MaxMaxNorm() < tol); 139 140 Mult(A, Bt, Cexact); 141 MultABt(A, B, C); 142 C.Add(-1.0, Cexact); 143 144 REQUIRE(C.MaxMaxNorm() < tol); 145 } 146 SECTION("MultADBt") 147 { 148 double DData[3] = {11.0, 7.0, 5.0}; 149 Vector D(DData, 3); 150 151 double CtData[8] = {132.0, 102.0, 152 330.0, 259.0, 153 528.0, 416.0, 154 736.0, 578.0 155 }; 156 DenseMatrix Cexact(CtData, 2, 4); 157 158 MultADBt(A, D, B, C); 159 C.Add(-1.0, Cexact); 160 161 REQUIRE(C.MaxMaxNorm() < tol); 162 } 163 SECTION("AddMultABt") 164 { 165 double CtData[8] = {17.0, 17.0, 166 40.0, 35.0, 167 63.0, 53.0, 168 88.0, 72.0 169 }; 170 DenseMatrix Cexact(CtData, 2, 4); 171 172 C(0, 0) = 1.0; C(0, 1) = 2.0; C(0, 2) = 3.0; C(0, 3) = 4.0; 173 C(1, 0) = 5.0; C(1, 1) = 6.0; C(1, 2) = 7.0; C(1, 3) = 8.0; 174 175 AddMultABt(A, B, C); 176 C.Add(-1.0, Cexact); 177 178 REQUIRE(C.MaxMaxNorm() < tol); 179 180 MultABt(A, B, C); 181 C *= -1.0; 182 AddMultABt(A, B, C); 183 REQUIRE(C.MaxMaxNorm() < tol); 184 } 185 SECTION("AddMultADBt") 186 { 187 double DData[3] = {11.0, 7.0, 5.0}; 188 Vector D(DData, 3); 189 190 double CtData[8] = {133.0, 107.0, 191 332.0, 265.0, 192 531.0, 423.0, 193 740.0, 586.0 194 }; 195 DenseMatrix Cexact(CtData, 2, 4); 196 197 C(0, 0) = 1.0; C(0, 1) = 2.0; C(0, 2) = 3.0; C(0, 3) = 4.0; 198 C(1, 0) = 5.0; C(1, 1) = 6.0; C(1, 2) = 7.0; C(1, 3) = 8.0; 199 200 AddMultADBt(A, D, B, C); 201 C.Add(-1.0, Cexact); 202 203 REQUIRE(C.MaxMaxNorm() < tol); 204 205 MultADBt(A, D, B, C); 206 C *= -1.0; 207 AddMultADBt(A, D, B, C); 208 REQUIRE(C.MaxMaxNorm() < tol); 209 210 DData[0] = 1.0; DData[1] = 1.0; DData[2] = 1.0; 211 MultABt(A, B, C); 212 C *= -1.0; 213 AddMultADBt(A, D, B, C); 214 REQUIRE(C.MaxMaxNorm() < tol); 215 } 216 SECTION("AddMult_a_ABt") 217 { 218 double a = 3.0; 219 220 double CtData[8] = { 49.0, 41.0, 221 116.0, 93.0, 222 183.0, 145.0, 223 256.0, 200.0 224 }; 225 DenseMatrix Cexact(CtData, 2, 4); 226 227 C(0, 0) = 1.0; C(0, 1) = 2.0; C(0, 2) = 3.0; C(0, 3) = 4.0; 228 C(1, 0) = 5.0; C(1, 1) = 6.0; C(1, 2) = 7.0; C(1, 3) = 8.0; 229 230 AddMult_a_ABt(a, A, B, C); 231 C.Add(-1.0, Cexact); 232 233 REQUIRE(C.MaxMaxNorm() < tol); 234 235 MultABt(A, B, C); 236 AddMult_a_ABt(-1.0, A, B, C); 237 238 REQUIRE(C.MaxMaxNorm() < tol); 239 } 240 } 241 242 243 TEST_CASE("LUFactors RightSolve", "[DenseMatrix]") 244 { 245 double tol = 1e-12; 246 247 // Zero on diagonal forces non-trivial pivot 248 double AData[9] = { 0.0, 0.0, 3.0, 2.0, 2.0, 2.0, 2.0, 0.0, 4.0 }; 249 double BData[6] = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 }; 250 int ipiv[3]; 251 252 DenseMatrix A(AData, 3, 3); 253 DenseMatrix B(BData, 2, 3); 254 255 DenseMatrixInverse Af1(A); 256 DenseMatrix Ainv; 257 Af1.GetInverseMatrix(Ainv); 258 259 LUFactors Af2(AData, ipiv); 260 Af2.Factor(3); 261 262 DenseMatrix C(2,3); 263 Mult(B, Ainv, C); 264 Af2.RightSolve(3, 2, B.GetData()); 265 C -= B; 266 267 REQUIRE(C.MaxMaxNorm() < tol); 268 } 269 270 TEST_CASE("DenseTensor LinearSolve methods", 271 "[DenseMatrix]") 272 { 273 274 int N = 3; 275 DenseMatrix A(N); 276 A(0,0) = 4; A(0,1) = 5; A(0,2) = -2; 277 A(1,0) = 7; A(1,1) = -1; A(1,2) = 2; 278 A(2,0) = 3; A(2,1) = 1; A(2,2) = 4; 279 280 double X[3] = { -14, 42, 28 }; 281 282 int NE = 10; 283 Vector X_batch(N*NE); 284 DenseTensor A_batch(N,N,NE); 285 286 auto a_batch = mfem::Reshape(A_batch.HostWrite(),N,N,NE); 287 auto x_batch = mfem::Reshape(X_batch.HostWrite(),N,NE); 288 // Column major 289 for (int e=0; e<NE; ++e) 290 { 291 292 for (int r=0; r<N; ++r) 293 { 294 for (int c=0; c<N; ++c) 295 { 296 a_batch(c, r, e) = A.GetData()[c+r*N]; 297 } 298 x_batch(r,e) = X[r]; 299 } 300 } 301 302 Array<int> P; 303 BatchLUFactor(A_batch, P); 304 BatchLUSolve(A_batch, P, X_batch); 305 306 auto xans_batch = mfem::Reshape(X_batch.HostRead(),N,NE); 307 REQUIRE(LinearSolve(A,X)); 308 for (int e=0; e<NE; ++e) 309 { 310 for (int r=0; r<N; ++r) 311 { 312 REQUIRE(xans_batch(r,e) == MFEM_Approx(X[r])); 313 } 314 } 315 } 316 317 TEST_CASE("DenseTensor copy", "[DenseMatrix][DenseTensor]") 318 { 319 DenseTensor t1(2,3,4); 320 for (int i=0; i<t1.TotalSize(); ++i) 321 { 322 t1.Data()[i] = i; 323 } 324 DenseTensor t2(t1); 325 DenseTensor t3; 326 t3 = t1; 327 REQUIRE(t2.SizeI() == t1.SizeI()); 328 REQUIRE(t2.SizeJ() == t1.SizeJ()); 329 REQUIRE(t2.SizeK() == t1.SizeK()); 330 331 REQUIRE(t3.SizeI() == t1.SizeI()); 332 REQUIRE(t3.SizeJ() == t1.SizeJ()); 333 REQUIRE(t3.SizeK() == t1.SizeK()); 334 335 REQUIRE(t2.Data() != t1.Data()); 336 REQUIRE(t3.Data() != t1.Data()); 337 338 for (int i=0; i<t1.TotalSize(); ++i) 339 { 340 REQUIRE(t2.Data()[i] == t1.Data()[i]); 341 REQUIRE(t3.Data()[i] == t1.Data()[i]); 342 } 343 } 344