1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/linalg.h"
7
8 #include <random>
9
10 #include "gtest/gtest.h"
11 #include "lib/jxl/image_test_utils.h"
12
13 namespace jxl {
14 namespace {
15
16 template <typename T, typename Random>
RandomMatrix(const size_t xsize,const size_t ysize,Random & rng,const T vmin,const T vmax)17 Plane<T> RandomMatrix(const size_t xsize, const size_t ysize, Random& rng,
18 const T vmin, const T vmax) {
19 Plane<T> A(xsize, ysize);
20 GeneratorRandom<T, Random> gen(&rng, vmin, vmax);
21 GenerateImage(gen, &A);
22 return A;
23 }
24
25 template <typename T, typename Random>
RandomSymmetricMatrix(const size_t N,Random & rng,const T vmin,const T vmax)26 Plane<T> RandomSymmetricMatrix(const size_t N, Random& rng, const T vmin,
27 const T vmax) {
28 Plane<T> A = RandomMatrix<T>(N, N, rng, vmin, vmax);
29 for (size_t i = 0; i < N; ++i) {
30 for (size_t j = 0; j < i; ++j) {
31 A.Row(j)[i] = A.Row(i)[j];
32 }
33 }
34 return A;
35 }
VerifyMatrixEqual(const ImageD & A,const ImageD & B,const double eps)36 void VerifyMatrixEqual(const ImageD& A, const ImageD& B, const double eps) {
37 ASSERT_EQ(A.xsize(), B.xsize());
38 ASSERT_EQ(A.ysize(), B.ysize());
39 for (size_t y = 0; y < A.ysize(); ++y) {
40 for (size_t x = 0; x < A.xsize(); ++x) {
41 ASSERT_NEAR(A.Row(y)[x], B.Row(y)[x], eps);
42 }
43 }
44 }
45
VerifyOrthogonal(const ImageD & A,const double eps)46 void VerifyOrthogonal(const ImageD& A, const double eps) {
47 VerifyMatrixEqual(Identity<double>(A.xsize()), MatMul(Transpose(A), A), eps);
48 }
49
VerifyTridiagonal(const ImageD & T,const double eps)50 void VerifyTridiagonal(const ImageD& T, const double eps) {
51 ASSERT_EQ(T.xsize(), T.ysize());
52 for (size_t i = 0; i < T.xsize(); ++i) {
53 for (size_t j = i + 2; j < T.xsize(); ++j) {
54 ASSERT_NEAR(T.Row(i)[j], 0.0, eps);
55 ASSERT_NEAR(T.Row(j)[i], 0.0, eps);
56 }
57 }
58 }
59
VerifyUpperTriangular(const ImageD & R,const double eps)60 void VerifyUpperTriangular(const ImageD& R, const double eps) {
61 ASSERT_EQ(R.xsize(), R.ysize());
62 for (size_t i = 0; i < R.xsize(); ++i) {
63 for (size_t j = i + 1; j < R.xsize(); ++j) {
64 ASSERT_NEAR(R.Row(i)[j], 0.0, eps);
65 }
66 }
67 }
68
// Exercises ConvertToTridiagonal on three kinds of input: the identity, a
// hand-built symmetric matrix, and random symmetric matrices of size 2..99.
TEST(LinAlgTest, ConvertToTridiagonal) {
  {
    // For the identity input both outputs are expected to equal the identity.
    ImageD I = Identity<double>(5);
    ImageD T, U;
    ConvertToTridiagonal(I, &T, &U);
    VerifyMatrixEqual(I, T, 1e-15);
    VerifyMatrixEqual(I, U, 1e-15);
  }
  {
    // Symmetric matrix whose (0, 4) / (4, 0) entries lie outside the three
    // central diagonals, so the input is not already tridiagonal.
    ImageD A = Identity<double>(5);
    A.Row(0)[1] = A.Row(1)[0] = 2.0;
    A.Row(0)[4] = A.Row(4)[0] = 3.0;
    A.Row(2)[3] = A.Row(3)[2] = 2.0;
    A.Row(3)[4] = A.Row(4)[3] = 2.0;

    // Run the function under test directly on this fixed matrix and check the
    // same three properties as in the random loop below.
    ImageD T, U;
    ConvertToTridiagonal(A, &T, &U);
    VerifyOrthogonal(U, 1e-12);
    VerifyTridiagonal(T, 1e-12);
    VerifyMatrixEqual(A, MatMul(U, MatMul(T, Transpose(U))), 1e-12);

    // Also diagonalize the same matrix and check that U * Diagonal(d) * U^T
    // reproduces A.
    ImageD V, d;
    ConvertToDiagonal(A, &d, &V);
    VerifyOrthogonal(V, 1e-12);
    VerifyMatrixEqual(A, MatMul(V, MatMul(Diagonal(d), Transpose(V))), 1e-12);
  }
  // Default-constructed engine: the sequence of random matrices is the same on
  // every run.
  std::mt19937_64 rng;
  for (int N = 2; N < 100; ++N) {
    ImageD A = RandomSymmetricMatrix(N, rng, -1.0, 1.0);
    ImageD T, U;
    ConvertToTridiagonal(A, &T, &U);
    VerifyOrthogonal(U, 1e-12);
    VerifyTridiagonal(T, 1e-12);
    VerifyMatrixEqual(A, MatMul(U, MatMul(T, Transpose(U))), 1e-12);
  }
}
98
// Exercises ConvertToDiagonal on the identity, on a hand-built symmetric
// matrix, and on random symmetric matrices of size 2..99. In the latter two
// cases U must be orthogonal and U * Diagonal(d) * U^T must reproduce the
// input.
TEST(LinAlgTest, ConvertToDiagonal) {
  const double kExactTolerance = 1e-15;
  const double kTolerance = 1e-12;
  {
    // Identity input: U is expected to be the identity and the first five
    // entries of d.Row(0) are expected to be 1.
    ImageD identity = Identity<double>(5);
    ImageD U;
    ImageD d;
    ConvertToDiagonal(identity, &d, &U);
    VerifyMatrixEqual(identity, U, kExactTolerance);
    for (int k = 0; k < 5; ++k) {
      ASSERT_NEAR(d.Row(0)[k], 1.0, 1e-15);
    }
  }
  {
    // Identity plus three symmetric pairs of off-diagonal entries set to 2.
    ImageD A = Identity<double>(5);
    A.Row(0)[1] = A.Row(1)[0] = 2.0;
    A.Row(2)[3] = A.Row(3)[2] = 2.0;
    A.Row(3)[4] = A.Row(4)[3] = 2.0;
    ImageD U;
    ImageD d;
    ConvertToDiagonal(A, &d, &U);
    VerifyOrthogonal(U, kTolerance);
    const ImageD rebuilt = MatMul(U, MatMul(Diagonal(d), Transpose(U)));
    VerifyMatrixEqual(A, rebuilt, kTolerance);
  }
  // Default-constructed engine, so the random inputs repeat on every run.
  std::mt19937_64 rng;
  for (int N = 2; N <= 99; ++N) {
    ImageD A = RandomSymmetricMatrix(N, rng, -1.0, 1.0);
    ImageD U;
    ImageD d;
    ConvertToDiagonal(A, &d, &U);
    VerifyOrthogonal(U, kTolerance);
    const ImageD rebuilt = MatMul(U, MatMul(Diagonal(d), Transpose(U)));
    VerifyMatrixEqual(A, rebuilt, kTolerance);
  }
}
128
// Exercises ComputeQRFactorization on the identity and on random square
// matrices of size 2..99. For the random inputs Q must pass VerifyOrthogonal,
// R must pass VerifyUpperTriangular, and MatMul(Q, R) must reproduce the
// input.
TEST(LinAlgTest, ComputeQRFactorization) {
  const double kExactTolerance = 1e-15;
  const double kTolerance = 1e-12;
  {
    // Identity input: both factors are expected to equal the identity.
    ImageD identity = Identity<double>(5);
    ImageD Q;
    ImageD R;
    ComputeQRFactorization(identity, &Q, &R);
    VerifyMatrixEqual(identity, Q, kExactTolerance);
    VerifyMatrixEqual(identity, R, kExactTolerance);
  }
  // Default-constructed engine, so the random inputs repeat on every run.
  std::mt19937_64 rng;
  for (int N = 2; N <= 99; ++N) {
    ImageD A = RandomMatrix(N, N, rng, -1.0, 1.0);
    ImageD Q;
    ImageD R;
    ComputeQRFactorization(A, &Q, &R);
    VerifyOrthogonal(Q, kTolerance);
    VerifyUpperTriangular(R, kTolerance);
    const ImageD product = MatMul(Q, R);
    VerifyMatrixEqual(A, product, kTolerance);
  }
}
147
148 } // namespace
149 } // namespace jxl
150