1 // Copyright (c) 2010-2021, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details.
11 
12 #include "mfem.hpp"
13 #include "unit_tests.hpp"
14 #include "linalg/dtensor.hpp"
15 
16 using namespace mfem;
17 
18 TEST_CASE("DenseMatrix init-list construction", "[DenseMatrix]")
19 {
20    double ContigData[6] = {6.0, 5.0, 4.0, 3.0, 2.0, 1.0};
21    DenseMatrix Contiguous(ContigData, 2, 3);
22 
23    DenseMatrix Nested(
24    {
25       {6.0, 4.0, 2.0},
26       {5.0, 3.0, 1.0}
27    });
28 
29    for (int i = 0; i < Contiguous.Height(); i++)
30    {
31       for (int j = 0; j < Contiguous.Width(); j++)
32       {
33          REQUIRE(Nested(i,j) == Contiguous(i,j));
34       }
35    }
36 }
37 
38 TEST_CASE("DenseMatrix LinearSolve methods",
39           "[DenseMatrix]")
40 {
41    SECTION("singular_system")
42    {
43       constexpr int N = 3;
44 
45       DenseMatrix A(N);
46       A.SetRow(0, 0.0);
47       A.SetRow(1, 0.0);
48       A.SetRow(2, 0.0);
49 
50       double X[3];
51 
52       REQUIRE_FALSE(LinearSolve(A,X));
53    }
54 
55    SECTION("1x1_system")
56    {
57       constexpr int N = 1;
58       DenseMatrix A(N);
59       A(0,0) = 2;
60 
61       double X[1] = { 12 };
62 
63       REQUIRE(LinearSolve(A,X));
64       REQUIRE(X[0] == MFEM_Approx(6));
65    }
66 
67    SECTION("2x2_system")
68    {
69       constexpr int N = 2;
70 
71       DenseMatrix A(N);
72       A(0,0) = 2.0; A(0,1) = 1.0;
73       A(1,0) = 3.0; A(1,1) = 4.0;
74 
75       double X[2] = { 1, 14 };
76 
77       REQUIRE(LinearSolve(A,X));
78       REQUIRE(X[0] == MFEM_Approx(-2));
79       REQUIRE(X[1] == MFEM_Approx(5));
80    }
81 
82    SECTION("3x3_system")
83    {
84       constexpr int N = 3;
85 
86       DenseMatrix A(N);
87       A(0,0) = 4; A(0,1) =  5; A(0,2) = -2;
88       A(1,0) = 7; A(1,1) = -1; A(1,2) =  2;
89       A(2,0) = 3; A(2,1) =  1; A(2,2) =  4;
90 
91       double X[3] = { -14, 42, 28 };
92 
93       REQUIRE(LinearSolve(A,X));
94       REQUIRE(X[0] == MFEM_Approx(4));
95       REQUIRE(X[1] == MFEM_Approx(-4));
96       REQUIRE(X[2] == MFEM_Approx(5));
97    }
98 
99 }
100 
101 TEST_CASE("DenseMatrix A*B^T methods",
102           "[DenseMatrix]")
103 {
104    double tol = 1e-12;
105 
106    double AtData[6] = {6.0, 5.0,
107                        4.0, 3.0,
108                        2.0, 1.0
109                       };
110    double BtData[12] = {1.0, 3.0, 5.0, 7.0,
111                         2.0, 4.0, 6.0, 8.0,
112                         1.0, 2.0, 3.0, 5.0
113                        };
114 
115    DenseMatrix A(AtData, 2, 3);
116    DenseMatrix B(BtData, 4, 3);
117    DenseMatrix C(2,4);
118 
119    SECTION("MultABt")
120    {
121       double BData[12] = {1.0, 2.0, 1.0,
122                           3.0, 4.0, 2.0,
123                           5.0, 6.0, 3.0,
124                           7.0, 8.0, 5.0
125                          };
126       DenseMatrix Bt(BData, 3, 4);
127 
128       double CtData[8] = {16.0, 12.0,
129                           38.0, 29.0,
130                           60.0, 46.0,
131                           84.0, 64.0
132                          };
133       DenseMatrix Cexact(CtData, 2, 4);
134 
135       MultABt(A, B, C);
136       C.Add(-1.0, Cexact);
137 
138       REQUIRE(C.MaxMaxNorm() < tol);
139 
140       Mult(A, Bt, Cexact);
141       MultABt(A, B, C);
142       C.Add(-1.0, Cexact);
143 
144       REQUIRE(C.MaxMaxNorm() < tol);
145    }
146    SECTION("MultADBt")
147    {
148       double DData[3] = {11.0, 7.0, 5.0};
149       Vector D(DData, 3);
150 
151       double CtData[8] = {132.0, 102.0,
152                           330.0, 259.0,
153                           528.0, 416.0,
154                           736.0, 578.0
155                          };
156       DenseMatrix Cexact(CtData, 2, 4);
157 
158       MultADBt(A, D, B, C);
159       C.Add(-1.0, Cexact);
160 
161       REQUIRE(C.MaxMaxNorm() < tol);
162    }
163    SECTION("AddMultABt")
164    {
165       double CtData[8] = {17.0, 17.0,
166                           40.0, 35.0,
167                           63.0, 53.0,
168                           88.0, 72.0
169                          };
170       DenseMatrix Cexact(CtData, 2, 4);
171 
172       C(0, 0) = 1.0; C(0, 1) = 2.0; C(0, 2) = 3.0; C(0, 3) = 4.0;
173       C(1, 0) = 5.0; C(1, 1) = 6.0; C(1, 2) = 7.0; C(1, 3) = 8.0;
174 
175       AddMultABt(A, B, C);
176       C.Add(-1.0, Cexact);
177 
178       REQUIRE(C.MaxMaxNorm() < tol);
179 
180       MultABt(A, B, C);
181       C *= -1.0;
182       AddMultABt(A, B, C);
183       REQUIRE(C.MaxMaxNorm() < tol);
184    }
185    SECTION("AddMultADBt")
186    {
187       double DData[3] = {11.0, 7.0, 5.0};
188       Vector D(DData, 3);
189 
190       double CtData[8] = {133.0, 107.0,
191                           332.0, 265.0,
192                           531.0, 423.0,
193                           740.0, 586.0
194                          };
195       DenseMatrix Cexact(CtData, 2, 4);
196 
197       C(0, 0) = 1.0; C(0, 1) = 2.0; C(0, 2) = 3.0; C(0, 3) = 4.0;
198       C(1, 0) = 5.0; C(1, 1) = 6.0; C(1, 2) = 7.0; C(1, 3) = 8.0;
199 
200       AddMultADBt(A, D, B, C);
201       C.Add(-1.0, Cexact);
202 
203       REQUIRE(C.MaxMaxNorm() < tol);
204 
205       MultADBt(A, D, B, C);
206       C *= -1.0;
207       AddMultADBt(A, D, B, C);
208       REQUIRE(C.MaxMaxNorm() < tol);
209 
210       DData[0] = 1.0; DData[1] = 1.0; DData[2] = 1.0;
211       MultABt(A, B, C);
212       C *= -1.0;
213       AddMultADBt(A, D, B, C);
214       REQUIRE(C.MaxMaxNorm() < tol);
215    }
216    SECTION("AddMult_a_ABt")
217    {
218       double a = 3.0;
219 
220       double CtData[8] = { 49.0,  41.0,
221                            116.0,  93.0,
222                            183.0, 145.0,
223                            256.0, 200.0
224                          };
225       DenseMatrix Cexact(CtData, 2, 4);
226 
227       C(0, 0) = 1.0; C(0, 1) = 2.0; C(0, 2) = 3.0; C(0, 3) = 4.0;
228       C(1, 0) = 5.0; C(1, 1) = 6.0; C(1, 2) = 7.0; C(1, 3) = 8.0;
229 
230       AddMult_a_ABt(a, A, B, C);
231       C.Add(-1.0, Cexact);
232 
233       REQUIRE(C.MaxMaxNorm() < tol);
234 
235       MultABt(A, B, C);
236       AddMult_a_ABt(-1.0, A, B, C);
237 
238       REQUIRE(C.MaxMaxNorm() < tol);
239    }
240 }
241 
242 
243 TEST_CASE("LUFactors RightSolve", "[DenseMatrix]")
244 {
245    double tol = 1e-12;
246 
247    // Zero on diagonal forces non-trivial pivot
248    double AData[9] = { 0.0, 0.0, 3.0, 2.0, 2.0, 2.0, 2.0, 0.0, 4.0 };
249    double BData[6] = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
250    int ipiv[3];
251 
252    DenseMatrix A(AData, 3, 3);
253    DenseMatrix B(BData, 2, 3);
254 
255    DenseMatrixInverse Af1(A);
256    DenseMatrix Ainv;
257    Af1.GetInverseMatrix(Ainv);
258 
259    LUFactors Af2(AData, ipiv);
260    Af2.Factor(3);
261 
262    DenseMatrix C(2,3);
263    Mult(B, Ainv, C);
264    Af2.RightSolve(3, 2, B.GetData());
265    C -= B;
266 
267    REQUIRE(C.MaxMaxNorm() < tol);
268 }
269 
270 TEST_CASE("DenseTensor LinearSolve methods",
271           "[DenseMatrix]")
272 {
273 
274    int N = 3;
275    DenseMatrix A(N);
276    A(0,0) = 4; A(0,1) =  5; A(0,2) = -2;
277    A(1,0) = 7; A(1,1) = -1; A(1,2) =  2;
278    A(2,0) = 3; A(2,1) =  1; A(2,2) =  4;
279 
280    double X[3] = { -14, 42, 28 };
281 
282    int NE = 10;
283    Vector X_batch(N*NE);
284    DenseTensor A_batch(N,N,NE);
285 
286    auto a_batch = mfem::Reshape(A_batch.HostWrite(),N,N,NE);
287    auto x_batch = mfem::Reshape(X_batch.HostWrite(),N,NE);
288    // Column major
289    for (int e=0; e<NE; ++e)
290    {
291 
292       for (int r=0; r<N; ++r)
293       {
294          for (int c=0; c<N; ++c)
295          {
296             a_batch(c, r, e) = A.GetData()[c+r*N];
297          }
298          x_batch(r,e) = X[r];
299       }
300    }
301 
302    Array<int> P;
303    BatchLUFactor(A_batch, P);
304    BatchLUSolve(A_batch, P, X_batch);
305 
306    auto xans_batch = mfem::Reshape(X_batch.HostRead(),N,NE);
307    REQUIRE(LinearSolve(A,X));
308    for (int e=0; e<NE; ++e)
309    {
310       for (int r=0; r<N; ++r)
311       {
312          REQUIRE(xans_batch(r,e) == MFEM_Approx(X[r]));
313       }
314    }
315 }
316 
317 TEST_CASE("DenseTensor copy", "[DenseMatrix][DenseTensor]")
318 {
319    DenseTensor t1(2,3,4);
320    for (int i=0; i<t1.TotalSize(); ++i)
321    {
322       t1.Data()[i] = i;
323    }
324    DenseTensor t2(t1);
325    DenseTensor t3;
326    t3 = t1;
327    REQUIRE(t2.SizeI() == t1.SizeI());
328    REQUIRE(t2.SizeJ() == t1.SizeJ());
329    REQUIRE(t2.SizeK() == t1.SizeK());
330 
331    REQUIRE(t3.SizeI() == t1.SizeI());
332    REQUIRE(t3.SizeJ() == t1.SizeJ());
333    REQUIRE(t3.SizeK() == t1.SizeK());
334 
335    REQUIRE(t2.Data() != t1.Data());
336    REQUIRE(t3.Data() != t1.Data());
337 
338    for (int i=0; i<t1.TotalSize(); ++i)
339    {
340       REQUIRE(t2.Data()[i] == t1.Data()[i]);
341       REQUIRE(t3.Data()[i] == t1.Data()[i]);
342    }
343 }
344