1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/linalg.h"
7 
8 #include <random>
9 
10 #include "gtest/gtest.h"
11 #include "lib/jxl/image_test_utils.h"
12 
13 namespace jxl {
14 namespace {
15 
16 template <typename T, typename Random>
RandomMatrix(const size_t xsize,const size_t ysize,Random & rng,const T vmin,const T vmax)17 Plane<T> RandomMatrix(const size_t xsize, const size_t ysize, Random& rng,
18                       const T vmin, const T vmax) {
19   Plane<T> A(xsize, ysize);
20   GeneratorRandom<T, Random> gen(&rng, vmin, vmax);
21   GenerateImage(gen, &A);
22   return A;
23 }
24 
25 template <typename T, typename Random>
RandomSymmetricMatrix(const size_t N,Random & rng,const T vmin,const T vmax)26 Plane<T> RandomSymmetricMatrix(const size_t N, Random& rng, const T vmin,
27                                const T vmax) {
28   Plane<T> A = RandomMatrix<T>(N, N, rng, vmin, vmax);
29   for (size_t i = 0; i < N; ++i) {
30     for (size_t j = 0; j < i; ++j) {
31       A.Row(j)[i] = A.Row(i)[j];
32     }
33   }
34   return A;
35 }
VerifyMatrixEqual(const ImageD & A,const ImageD & B,const double eps)36 void VerifyMatrixEqual(const ImageD& A, const ImageD& B, const double eps) {
37   ASSERT_EQ(A.xsize(), B.xsize());
38   ASSERT_EQ(A.ysize(), B.ysize());
39   for (size_t y = 0; y < A.ysize(); ++y) {
40     for (size_t x = 0; x < A.xsize(); ++x) {
41       ASSERT_NEAR(A.Row(y)[x], B.Row(y)[x], eps);
42     }
43   }
44 }
45 
VerifyOrthogonal(const ImageD & A,const double eps)46 void VerifyOrthogonal(const ImageD& A, const double eps) {
47   VerifyMatrixEqual(Identity<double>(A.xsize()), MatMul(Transpose(A), A), eps);
48 }
49 
VerifyTridiagonal(const ImageD & T,const double eps)50 void VerifyTridiagonal(const ImageD& T, const double eps) {
51   ASSERT_EQ(T.xsize(), T.ysize());
52   for (size_t i = 0; i < T.xsize(); ++i) {
53     for (size_t j = i + 2; j < T.xsize(); ++j) {
54       ASSERT_NEAR(T.Row(i)[j], 0.0, eps);
55       ASSERT_NEAR(T.Row(j)[i], 0.0, eps);
56     }
57   }
58 }
59 
VerifyUpperTriangular(const ImageD & R,const double eps)60 void VerifyUpperTriangular(const ImageD& R, const double eps) {
61   ASSERT_EQ(R.xsize(), R.ysize());
62   for (size_t i = 0; i < R.xsize(); ++i) {
63     for (size_t j = i + 1; j < R.xsize(); ++j) {
64       ASSERT_NEAR(R.Row(i)[j], 0.0, eps);
65     }
66   }
67 }
68 
// ConvertToTridiagonal(A, &T, &U) is checked by verifying that U is
// orthogonal, T is tridiagonal, and A is reconstructed as U * T * U^T.
TEST(LinAlgTest, ConvertToTridiagonal) {
  // The identity is already tridiagonal; both outputs should equal it.
  {
    ImageD I = Identity<double>(5);
    ImageD T, U;
    ConvertToTridiagonal(I, &T, &U);
    VerifyMatrixEqual(I, T, 1e-15);
    VerifyMatrixEqual(I, U, 1e-15);
  }
  // Hand-built symmetric matrix. The (0, 4)/(4, 0) coupling lies outside the
  // tridiagonal band, so the reduction has real work to do.
  {
    ImageD A = Identity<double>(5);
    A.Row(0)[1] = A.Row(1)[0] = 2.0;
    A.Row(0)[4] = A.Row(4)[0] = 3.0;
    A.Row(2)[3] = A.Row(3)[2] = 2.0;
    A.Row(3)[4] = A.Row(4)[3] = 2.0;
    ImageD T, U;
    ConvertToTridiagonal(A, &T, &U);
    VerifyOrthogonal(U, 1e-12);
    VerifyTridiagonal(T, 1e-12);
    VerifyMatrixEqual(A, MatMul(U, MatMul(T, Transpose(U))), 1e-12);
  }
  // Random symmetric matrices of every size in [2, 100). A default-seeded
  // engine keeps the sweep deterministic.
  std::mt19937_64 rng;
  for (size_t N = 2; N < 100; ++N) {
    // Tag any failure below with the matrix size that produced it.
    SCOPED_TRACE(::testing::Message() << "N = " << N);
    ImageD A = RandomSymmetricMatrix(N, rng, -1.0, 1.0);
    ImageD T, U;
    ConvertToTridiagonal(A, &T, &U);
    VerifyOrthogonal(U, 1e-12);
    VerifyTridiagonal(T, 1e-12);
    VerifyMatrixEqual(A, MatMul(U, MatMul(T, Transpose(U))), 1e-12);
  }
}
98 
// ConvertToDiagonal(A, &d, &U) is checked by verifying that U is orthogonal
// and that A is reconstructed as U * Diagonal(d) * U^T.
TEST(LinAlgTest, ConvertToDiagonal) {
  // The identity: U should be the identity and every entry of d should be 1.
  {
    ImageD I = Identity<double>(5);
    ImageD U, d;
    ConvertToDiagonal(I, &d, &U);
    VerifyMatrixEqual(I, U, 1e-15);
    for (size_t k = 0; k < 5; ++k) {
      ASSERT_NEAR(d.Row(0)[k], 1.0, 1e-15);
    }
  }
  // Hand-built symmetric matrix with a few off-diagonal couplings.
  {
    ImageD A = Identity<double>(5);
    A.Row(0)[1] = A.Row(1)[0] = 2.0;
    A.Row(2)[3] = A.Row(3)[2] = 2.0;
    A.Row(3)[4] = A.Row(4)[3] = 2.0;
    ImageD U, d;
    ConvertToDiagonal(A, &d, &U);
    VerifyOrthogonal(U, 1e-12);
    VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12);
  }
  // Random symmetric matrices of every size in [2, 100). A default-seeded
  // engine keeps the sweep deterministic.
  std::mt19937_64 rng;
  for (size_t N = 2; N < 100; ++N) {
    // Tag any failure below with the matrix size that produced it.
    SCOPED_TRACE(::testing::Message() << "N = " << N);
    ImageD A = RandomSymmetricMatrix(N, rng, -1.0, 1.0);
    ImageD U, d;
    ConvertToDiagonal(A, &d, &U);
    VerifyOrthogonal(U, 1e-12);
    VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12);
  }
}
128 
// ComputeQRFactorization(A, &Q, &R) is checked by verifying that Q is
// orthogonal, R passes VerifyUpperTriangular, and A equals Q * R.
TEST(LinAlgTest, ComputeQRFactorization) {
  // The identity should factor as Q = I, R = I.
  {
    ImageD I = Identity<double>(5);
    ImageD Q, R;
    ComputeQRFactorization(I, &Q, &R);
    VerifyMatrixEqual(I, Q, 1e-15);
    VerifyMatrixEqual(I, R, 1e-15);
  }
  // Random square matrices of every size in [2, 100). A default-seeded
  // engine keeps the sweep deterministic.
  std::mt19937_64 rng;
  for (size_t N = 2; N < 100; ++N) {
    // Tag any failure below with the matrix size that produced it.
    SCOPED_TRACE(::testing::Message() << "N = " << N);
    ImageD A = RandomMatrix(N, N, rng, -1.0, 1.0);
    ImageD Q, R;
    ComputeQRFactorization(A, &Q, &R);
    VerifyOrthogonal(Q, 1e-12);
    VerifyUpperTriangular(R, 1e-12);
    VerifyMatrixEqual(A, MatMul(Q, R), 1e-12);
  }
}
147 
148 }  // namespace
149 }  // namespace jxl
150