1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #ifndef LIB_JXL_LINALG_H_
7 #define LIB_JXL_LINALG_H_
8
9 // Linear algebra.
10
11 #include <stddef.h>
12
13 #include <algorithm>
14 #include <cmath>
15 #include <vector>
16
17 #include "lib/jxl/base/compiler_specific.h"
18 #include "lib/jxl/base/status.h"
19 #include "lib/jxl/image.h"
20 #include "lib/jxl/image_ops.h"
21
22 namespace jxl {
23
24 using ImageD = Plane<double>;
25
26 template <typename T>
DotProduct(const size_t N,const T * const JXL_RESTRICT a,const T * const JXL_RESTRICT b)27 inline T DotProduct(const size_t N, const T* const JXL_RESTRICT a,
28 const T* const JXL_RESTRICT b) {
29 T sum = 0.0;
30 for (size_t k = 0; k < N; ++k) {
31 sum += a[k] * b[k];
32 }
33 return sum;
34 }
35
36 template <typename T>
L2NormSquared(const size_t N,const T * const JXL_RESTRICT a)37 inline T L2NormSquared(const size_t N, const T* const JXL_RESTRICT a) {
38 return DotProduct(N, a, a);
39 }
40
41 template <typename T>
L1Norm(const size_t N,const T * const JXL_RESTRICT a)42 inline T L1Norm(const size_t N, const T* const JXL_RESTRICT a) {
43 T sum = 0;
44 for (size_t k = 0; k < N; ++k) {
45 sum += a[k] >= 0 ? a[k] : -a[k];
46 }
47 return sum;
48 }
49
// Inner product of two row vectors stored as 1 x N images; both must have
// the same length.
inline double DotProduct(const ImageD& a, const ImageD& b) {
  JXL_ASSERT(a.ysize() == 1);
  JXL_ASSERT(b.ysize() == 1);
  JXL_ASSERT(a.xsize() == b.xsize());
  const double* const JXL_RESTRICT row_a = a.Row(0);
  const double* const JXL_RESTRICT row_b = b.Row(0);
  return DotProduct(a.xsize(), row_a, row_b);
}
58
// Returns the transpose of A: result.Row(x)[y] == A.Row(y)[x].
inline ImageD Transpose(const ImageD& A) {
  ImageD result(A.ysize(), A.xsize());
  for (size_t c = 0; c < A.xsize(); ++c) {
    double* const JXL_RESTRICT out_row = result.Row(c);
    for (size_t r = 0; r < A.ysize(); ++r) {
      out_row[r] = A.Row(r)[c];
    }
  }
  return result;
}
69
// Generalized matrix product accumulated in Tout:
//   out.Row(y)[x] = sum_k A.Row(k)[x] * B.Row(y)[k]
// Requires A.ysize() == B.xsize(); the result plane is A.xsize() x B.ysize().
// (Interpreting Row(j)[i] as matrix element (i, j) — each plane row holds one
// matrix column — this is the standard product out = A * B.)
template <typename Tout, typename Tin1, typename Tin2>
Plane<Tout> MatMul(const Plane<Tin1>& A, const Plane<Tin2>& B) {
  JXL_ASSERT(A.ysize() == B.xsize());
  Plane<Tout> out(A.xsize(), B.ysize());
  for (size_t y = 0; y < B.ysize(); ++y) {
    const Tin2* const JXL_RESTRICT row_b = B.Row(y);
    Tout* const JXL_RESTRICT row_out = out.Row(y);
    for (size_t x = 0; x < A.xsize(); ++x) {
      row_out[x] = 0.0;
      // B.xsize() == A.ysize() here (asserted above).
      for (size_t k = 0; k < B.xsize(); ++k) {
        row_out[x] += A.Row(k)[x] * row_b[k];
      }
    }
  }
  return out;
}
86
// Convenience overload: matrix product computed in double precision,
// returned as an ImageD.
template <typename T1, typename T2>
ImageD MatMul(const Plane<T1>& A, const Plane<T2>& B) {
  return MatMul<double, T1, T2>(A, B);
}
91
// Convenience overload: matrix product accumulated in int, returned as an
// ImageI.
template <typename T1, typename T2>
ImageI MatMulI(const Plane<T1>& A, const Plane<T2>& B) {
  return MatMul<int, T1, T2>(A, B);
}
96
// Computes c = a * b for row-major matrices, where a is ha x wa, b is
// wa x wb and c is ha x wb. Accumulation is done in double precision.
template <typename T>
void MatMul(const T* a, const T* b, int ha, int wa, int wb, T* c) {
  // Stage one column of b at a time so the inner product walks contiguous
  // memory instead of striding through b (better use of cache lines).
  std::vector<T> col(wa);
  for (int j = 0; j < wb; j++) {
    for (int k = 0; k < wa; k++) {
      col[k] = b[k * wb + j];
    }
    for (int i = 0; i < ha; i++) {
      const T* const row = &a[i * wa];
      double acc = 0;
      for (int k = 0; k < wa; k++) {
        acc += row[k] * col[k];
      }
      c[i * wb + j] = acc;
    }
  }
}
114
// Computes c = a + factor * b elementwise over an h x w array (row-major,
// h * w elements each). In-place use (c aliasing a or b) is safe because
// each element is read before its slot is written.
template <typename T, typename F>
void MatAdd(const T* a, const T* b, F factor, int h, int w, T* c) {
  const int count = w * h;
  for (int idx = 0; idx < count; idx++) {
    c[idx] = a[idx] + b[idx] * factor;
  }
}
122
123 template <typename T>
Identity(const size_t N)124 inline Plane<T> Identity(const size_t N) {
125 Plane<T> out(N, N);
126 for (size_t i = 0; i < N; ++i) {
127 T* JXL_RESTRICT row = out.Row(i);
128 std::fill(row, row + N, 0);
129 row[i] = static_cast<T>(1.0);
130 }
131 return out;
132 }
133
// Builds a square diagonal matrix whose diagonal is the 1 x N row vector d;
// all off-diagonal entries are zero.
inline ImageD Diagonal(const ImageD& d) {
  JXL_ASSERT(d.ysize() == 1);
  const size_t n = d.xsize();
  ImageD result(n, n);
  const double* JXL_RESTRICT diag = d.Row(0);
  for (size_t i = 0; i < n; ++i) {
    double* JXL_RESTRICT row = result.Row(i);
    for (size_t j = 0; j < n; ++j) {
      row[j] = (i == j) ? diag[i] : 0.0;
    }
  }
  return result;
}
145
// Computes c, s such that c^2 + s^2 = 1 and
//   [c -s] [x]   [ * ]
//   [s  c] [y] = [ 0 ]
void GivensRotation(double x, double y, double* c, double* s);

// U = U * Givens(i, j, c, s): applies the Givens rotation to columns i and j
// of U in place.
void RotateMatrixCols(ImageD* JXL_RESTRICT U, int i, int j, double c, double s);

// A is symmetric, U is orthogonal, T is tri-diagonal and
// A = U * T * Transpose(U).
void ConvertToTridiagonal(const ImageD& A, ImageD* JXL_RESTRICT T,
                          ImageD* JXL_RESTRICT U);

// A is symmetric, U is orthogonal, and A = U * Diagonal(diag) * Transpose(U).
void ConvertToDiagonal(const ImageD& A, ImageD* JXL_RESTRICT diag,
                       ImageD* JXL_RESTRICT U);

// A is a square matrix, Q is orthogonal, R is upper triangular and A = Q * R.
void ComputeQRFactorization(const ImageD& A, ImageD* JXL_RESTRICT Q,
                            ImageD* JXL_RESTRICT R);
166
167 // Inverts a 3x3 matrix in place
168 template <typename T>
Inv3x3Matrix(T * matrix)169 Status Inv3x3Matrix(T* matrix) {
170 // Intermediate computation is done in double precision.
171 double temp[9];
172 temp[0] = static_cast<double>(matrix[4]) * matrix[8] -
173 static_cast<double>(matrix[5]) * matrix[7];
174 temp[1] = static_cast<double>(matrix[2]) * matrix[7] -
175 static_cast<double>(matrix[1]) * matrix[8];
176 temp[2] = static_cast<double>(matrix[1]) * matrix[5] -
177 static_cast<double>(matrix[2]) * matrix[4];
178 temp[3] = static_cast<double>(matrix[5]) * matrix[6] -
179 static_cast<double>(matrix[3]) * matrix[8];
180 temp[4] = static_cast<double>(matrix[0]) * matrix[8] -
181 static_cast<double>(matrix[2]) * matrix[6];
182 temp[5] = static_cast<double>(matrix[2]) * matrix[3] -
183 static_cast<double>(matrix[0]) * matrix[5];
184 temp[6] = static_cast<double>(matrix[3]) * matrix[7] -
185 static_cast<double>(matrix[4]) * matrix[6];
186 temp[7] = static_cast<double>(matrix[1]) * matrix[6] -
187 static_cast<double>(matrix[0]) * matrix[7];
188 temp[8] = static_cast<double>(matrix[0]) * matrix[4] -
189 static_cast<double>(matrix[1]) * matrix[3];
190 double det = matrix[0] * temp[0] + matrix[1] * temp[3] + matrix[2] * temp[6];
191 if (std::abs(det) < 1e-10) {
192 return JXL_FAILURE("Matrix determinant is too close to 0");
193 }
194 double idet = 1.0 / det;
195 for (int i = 0; i < 9; i++) {
196 matrix[i] = temp[i] * idet;
197 }
198 return true;
199 }
200
// Solves the system of linear equations A * X = B using the conjugate
// gradient method. Matrix a must be n*n, symmetric and positive definite.
// Vectors b and x must have n elements; x holds the starting guess on entry
// and the (approximate) solution on exit.
template <typename T>
void ConjugateGradient(const T* a, int n, const T* b, T* x) {
  // Residual r = b - A * x for the current guess.
  std::vector<T> r(n);
  MatMul(a, x, n, n, 1, r.data());
  MatAdd(b, r.data(), -1, n, 1, r.data());
  // The first search direction is the residual itself.
  std::vector<T> p = r;
  T rr;
  MatMul(r.data(), r.data(), 1, n, 1, &rr);  // inner product

  if (rr == 0) return;  // The initial values were already optimal

  // In exact arithmetic CG converges in at most n iterations.
  for (int i = 0; i < n; i++) {
    std::vector<T> ap(n);
    MatMul(a, p.data(), n, n, 1, ap.data());
    // Step length alpha = (r . r) / (r . A p). The denominator equals
    // (p . A p) in exact arithmetic because directions are A-conjugate.
    T alpha;
    MatMul(r.data(), ap.data(), 1, n, 1, &alpha);
    // Normally alpha couldn't be zero here but if numerical issues caused it,
    // return assuming the solution is close.
    if (alpha == 0) return;
    alpha = rr / alpha;
    MatAdd(x, p.data(), alpha, n, 1, x);                  // x += alpha * p
    MatAdd(r.data(), ap.data(), -alpha, n, 1, r.data());  // r -= alpha * A p

    T rr2;
    MatMul(r.data(), r.data(), 1, n, 1, &rr2);  // inner product
    // Stop once the squared residual norm is negligible.
    if (rr2 < 1e-20) break;

    // Next search direction: p = r + beta * p.
    T beta = rr2 / rr;
    MatAdd(r.data(), p.data(), beta, 1, n, p.data());
    rr = rr2;
  }
}
236
// Computes optimal coefficients r to approximate points p with linear
// combination of functions f. The matrix f has h rows and w columns, r has h
// values, p has w values. h is the number of functions, w the number of
// points. Uses the finite element method and minimizes mean square error.
template <typename T>
void FEM(const T* f, int h, int w, const T* p, T* r) {
  // Compute "Gramian" matrix G = F * F^T
  // Speed up multiplication by using non-zero intervals in sparse F.
  // start[y]/end[y] bound the half-open interval of non-zero entries in row y
  // of f; both stay 0 for an all-zero row, yielding an empty interval.
  std::vector<int> start(h);
  std::vector<int> end(h);
  for (int y = 0; y < h; y++) {
    start[y] = end[y] = 0;
    // First non-zero entry scanning from the left.
    for (int x = 0; x < w; x++) {
      if (f[y * w + x] != 0) {
        start[y] = x;
        break;
      }
    }
    // One past the last non-zero entry scanning from the right.
    for (int x = w - 1; x >= 0; x--) {
      if (f[y * w + x] != 0) {
        end[y] = x + 1;
        break;
      }
    }
  }

  // g[y * h + x] = <row x of F, row y of F>; only the overlap of the two
  // non-zero intervals can contribute to the sum.
  std::vector<T> g(h * h);
  for (int y = 0; y < h; y++) {
    for (int x = 0; x <= y; x++) {
      T v = 0;
      // Intersection of the two sparse intervals.
      int s = std::max(start[x], start[y]);
      int e = std::min(end[x], end[y]);
      for (int z = s; z < e; z++) {
        v += f[x * w + z] * f[y * w + z];
      }
      // Symmetric, so two values output at once
      g[y * h + x] = v;
      g[x * h + y] = v;
    }
  }

  // Right-hand side: b[y] = <row y of F, p>, the projection of the target
  // points onto each basis function.
  std::vector<T> b(h, 0);
  for (int y = 0; y < h; y++) {
    T v = 0;
    for (int x = 0; x < w; x++) {
      v += f[y * w + x] * p[x];
    }
    b[y] = v;
  }

  // Solve the normal equations G * r = b (G is symmetric, which
  // ConjugateGradient requires).
  ConjugateGradient(g.data(), h, b.data(), r);
}
291
292 } // namespace jxl
293
294 #endif // LIB_JXL_LINALG_H_
295