1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #ifndef LIB_JXL_LINALG_H_
7 #define LIB_JXL_LINALG_H_
8 
9 // Linear algebra.
10 
11 #include <stddef.h>
12 
13 #include <algorithm>
14 #include <cmath>
15 #include <vector>
16 
17 #include "lib/jxl/base/compiler_specific.h"
18 #include "lib/jxl/base/status.h"
19 #include "lib/jxl/image.h"
20 #include "lib/jxl/image_ops.h"
21 
22 namespace jxl {
23 
// Double-precision plane. The helpers in this file treat it as a dense matrix,
// and as a vector when ysize() == 1 (see DotProduct and Diagonal below).
using ImageD = Plane<double>;
25 
26 template <typename T>
DotProduct(const size_t N,const T * const JXL_RESTRICT a,const T * const JXL_RESTRICT b)27 inline T DotProduct(const size_t N, const T* const JXL_RESTRICT a,
28                     const T* const JXL_RESTRICT b) {
29   T sum = 0.0;
30   for (size_t k = 0; k < N; ++k) {
31     sum += a[k] * b[k];
32   }
33   return sum;
34 }
35 
36 template <typename T>
L2NormSquared(const size_t N,const T * const JXL_RESTRICT a)37 inline T L2NormSquared(const size_t N, const T* const JXL_RESTRICT a) {
38   return DotProduct(N, a, a);
39 }
40 
41 template <typename T>
L1Norm(const size_t N,const T * const JXL_RESTRICT a)42 inline T L1Norm(const size_t N, const T* const JXL_RESTRICT a) {
43   T sum = 0;
44   for (size_t k = 0; k < N; ++k) {
45     sum += a[k] >= 0 ? a[k] : -a[k];
46   }
47   return sum;
48 }
49 
DotProduct(const ImageD & a,const ImageD & b)50 inline double DotProduct(const ImageD& a, const ImageD& b) {
51   JXL_ASSERT(a.ysize() == 1);
52   JXL_ASSERT(b.ysize() == 1);
53   JXL_ASSERT(a.xsize() == b.xsize());
54   const double* const JXL_RESTRICT row_a = a.Row(0);
55   const double* const JXL_RESTRICT row_b = b.Row(0);
56   return DotProduct(a.xsize(), row_a, row_b);
57 }
58 
// Returns the transpose of A: a new plane of size A.ysize() x A.xsize() with
// out.Row(x)[y] == A.Row(y)[x].
inline ImageD Transpose(const ImageD& A) {
  ImageD out(A.ysize(), A.xsize());
  for (size_t y = 0; y < A.ysize(); ++y) {
    const double* JXL_RESTRICT src = A.Row(y);
    for (size_t x = 0; x < A.xsize(); ++x) {
      out.Row(x)[y] = src[x];
    }
  }
  return out;
}
69 
70 template <typename Tout, typename Tin1, typename Tin2>
MatMul(const Plane<Tin1> & A,const Plane<Tin2> & B)71 Plane<Tout> MatMul(const Plane<Tin1>& A, const Plane<Tin2>& B) {
72   JXL_ASSERT(A.ysize() == B.xsize());
73   Plane<Tout> out(A.xsize(), B.ysize());
74   for (size_t y = 0; y < B.ysize(); ++y) {
75     const Tin2* const JXL_RESTRICT row_b = B.Row(y);
76     Tout* const JXL_RESTRICT row_out = out.Row(y);
77     for (size_t x = 0; x < A.xsize(); ++x) {
78       row_out[x] = 0.0;
79       for (size_t k = 0; k < B.xsize(); ++k) {
80         row_out[x] += A.Row(k)[x] * row_b[k];
81       }
82     }
83   }
84   return out;
85 }
86 
87 template <typename T1, typename T2>
MatMul(const Plane<T1> & A,const Plane<T2> & B)88 ImageD MatMul(const Plane<T1>& A, const Plane<T2>& B) {
89   return MatMul<double, T1, T2>(A, B);
90 }
91 
92 template <typename T1, typename T2>
MatMulI(const Plane<T1> & A,const Plane<T2> & B)93 ImageI MatMulI(const Plane<T1>& A, const Plane<T2>& B) {
94   return MatMul<int, T1, T2>(A, B);
95 }
96 
// Computes c = a * b for dense row-major matrices: a is ha x wa, b is wa x wb,
// and the output c is ha x wb. Products are accumulated in double, and each
// sum is converted to T when stored into c.
template <typename T>
void MatMul(const T* a, const T* b, int ha, int wa, int wb, T* c) {
  // Column `col` of b is strided in memory. Gather it once into a contiguous
  // buffer so the inner products below read both operands sequentially.
  std::vector<T> column(wa);
  for (int col = 0; col < wb; col++) {
    for (int k = 0; k < wa; k++) column[k] = b[k * wb + col];
    for (int row = 0; row < ha; row++) {
      const T* const a_row = a + row * wa;
      double sum = 0;
      for (int k = 0; k < wa; k++) sum += a_row[k] * column[k];
      c[row * wb + col] = sum;
    }
  }
}
114 
// Element-wise c[i] = a[i] + factor * b[i] over h * w elements. c may alias a
// or b, since each output element depends only on the inputs at its own index.
template <typename T, typename F>
void MatAdd(const T* a, const T* b, F factor, int h, int w, T* c) {
  const int count = w * h;
  for (int idx = 0; idx < count; ++idx) {
    c[idx] = a[idx] + b[idx] * factor;
  }
}
122 
123 template <typename T>
Identity(const size_t N)124 inline Plane<T> Identity(const size_t N) {
125   Plane<T> out(N, N);
126   for (size_t i = 0; i < N; ++i) {
127     T* JXL_RESTRICT row = out.Row(i);
128     std::fill(row, row + N, 0);
129     row[i] = static_cast<T>(1.0);
130   }
131   return out;
132 }
133 
// Builds a square diagonal matrix from the vector d (a plane with
// ysize() == 1): out.Row(k)[k] == d.Row(0)[k], and every other element is 0.
inline ImageD Diagonal(const ImageD& d) {
  JXL_ASSERT(d.ysize() == 1);
  const size_t n = d.xsize();
  ImageD out(n, n);
  const double* JXL_RESTRICT values = d.Row(0);
  for (size_t row = 0; row < n; ++row) {
    double* JXL_RESTRICT dst = out.Row(row);
    for (size_t col = 0; col < n; ++col) {
      dst[col] = (col == row) ? values[row] : 0.0;
    }
  }
  return out;
}
145 
// Computes the coefficients of a Givens rotation: *c and *s satisfy
// c^2 + s^2 = 1 and
//   [c -s] [x] = [ * ]
//   [s  c] [y]   [ 0 ]
// i.e. applying the rotation to the vector (x, y) zeroes its second component.
void GivensRotation(double x, double y, double* c, double* s);
150 
// Applies a Givens rotation in place by right multiplication:
// U = U * Givens(i, j, c, s). As the name indicates, this is expected to touch
// only columns i and j of U (implementation not visible here).
void RotateMatrixCols(ImageD* JXL_RESTRICT U, int i, int j, double c, double s);
153 
// Tridiagonalization of a symmetric matrix. Given symmetric A, produces in *T
// a tri-diagonal matrix and in *U an orthogonal matrix such that
// A = U * T * Transpose(U).
void ConvertToTridiagonal(const ImageD& A, ImageD* JXL_RESTRICT T,
                          ImageD* JXL_RESTRICT U);
158 
// Eigendecomposition of a symmetric matrix. Given symmetric A, produces an
// orthogonal *U and a vector *diag such that
// A = U * Diagonal(diag) * Transpose(U). The entries of diag are therefore the
// eigenvalues of A and the columns of U the corresponding eigenvectors.
void ConvertToDiagonal(const ImageD& A, ImageD* JXL_RESTRICT diag,
                       ImageD* JXL_RESTRICT U);
162 
// QR factorization of a square matrix A: produces an orthogonal *Q and an
// upper-triangular *R such that A = Q * R.
void ComputeQRFactorization(const ImageD& A, ImageD* JXL_RESTRICT Q,
                            ImageD* JXL_RESTRICT R);
166 
167 // Inverts a 3x3 matrix in place
168 template <typename T>
Inv3x3Matrix(T * matrix)169 Status Inv3x3Matrix(T* matrix) {
170   // Intermediate computation is done in double precision.
171   double temp[9];
172   temp[0] = static_cast<double>(matrix[4]) * matrix[8] -
173             static_cast<double>(matrix[5]) * matrix[7];
174   temp[1] = static_cast<double>(matrix[2]) * matrix[7] -
175             static_cast<double>(matrix[1]) * matrix[8];
176   temp[2] = static_cast<double>(matrix[1]) * matrix[5] -
177             static_cast<double>(matrix[2]) * matrix[4];
178   temp[3] = static_cast<double>(matrix[5]) * matrix[6] -
179             static_cast<double>(matrix[3]) * matrix[8];
180   temp[4] = static_cast<double>(matrix[0]) * matrix[8] -
181             static_cast<double>(matrix[2]) * matrix[6];
182   temp[5] = static_cast<double>(matrix[2]) * matrix[3] -
183             static_cast<double>(matrix[0]) * matrix[5];
184   temp[6] = static_cast<double>(matrix[3]) * matrix[7] -
185             static_cast<double>(matrix[4]) * matrix[6];
186   temp[7] = static_cast<double>(matrix[1]) * matrix[6] -
187             static_cast<double>(matrix[0]) * matrix[7];
188   temp[8] = static_cast<double>(matrix[0]) * matrix[4] -
189             static_cast<double>(matrix[1]) * matrix[3];
190   double det = matrix[0] * temp[0] + matrix[1] * temp[3] + matrix[2] * temp[6];
191   if (std::abs(det) < 1e-10) {
192     return JXL_FAILURE("Matrix determinant is too close to 0");
193   }
194   double idet = 1.0 / det;
195   for (int i = 0; i < 9; i++) {
196     matrix[i] = temp[i] * idet;
197   }
198   return true;
199 }
200 
// Solves the system of linear equations A * x = b using the conjugate gradient
// method.
//
// `a` is an n x n matrix indexed as a[row * n + col]. It must be symmetric and
// positive definite (SPD). `b` and `x` have n elements each. `x` is read as
// the initial guess and overwritten with the solution.
//
// At most n iterations are run. Iteration stops early once the squared
// residual norm drops below 1e-20, which is an absolute threshold. If the step
// denominator p^T A p is not positive (impossible for an SPD matrix in exact
// arithmetic), the current x is kept as the answer.
template <typename T>
void ConjugateGradient(const T* a, int n, const T* b, T* x) {
  // Inner products and matrix-vector products accumulate in double and round
  // to T once per result, as MatMul() in this file does. They are computed
  // directly rather than through MatMul(), which allocates a scratch vector on
  // every call.
  const auto dot = [n](const T* u, const T* v) {
    double acc = 0;
    for (int k = 0; k < n; k++) acc += u[k] * v[k];
    return static_cast<T>(acc);
  };
  const auto mat_vec = [a, n](const T* v, T* out) {
    for (int row = 0; row < n; row++) {
      double acc = 0;
      for (int k = 0; k < n; k++) acc += a[row * n + k] * v[k];
      out[row] = static_cast<T>(acc);
    }
  };

  // Residual r = b - A * x.
  std::vector<T> r(n);
  mat_vec(x, r.data());
  for (int k = 0; k < n; k++) r[k] = b[k] - r[k];
  std::vector<T> p = r;  // Current search direction.
  std::vector<T> ap(n);  // A * p, reused across iterations.
  T rr = dot(r.data(), r.data());

  if (rr == 0) return;  // The initial values were already optimal.

  for (int i = 0; i < n; i++) {
    mat_vec(p.data(), ap.data());
    // Textbook step length rr / (p^T A p). The variant r^T A p is equal to it
    // only in exact arithmetic (through A-conjugacy of successive directions).
    // p^T A p stays positive for an SPD matrix even when rounding erodes that
    // conjugacy.
    const T pap = dot(p.data(), ap.data());
    // Breakdown caused by numerical issues, or a matrix that is not SPD: stop
    // and keep the current x, assuming it is close to the solution. The
    // negated comparison also catches NaN.
    if (!(pap > 0)) return;
    const T alpha = rr / pap;
    for (int k = 0; k < n; k++) {
      x[k] += alpha * p[k];
      r[k] -= alpha * ap[k];
    }

    const T rr2 = dot(r.data(), r.data());
    if (rr2 < 1e-20) break;

    const T beta = rr2 / rr;
    for (int k = 0; k < n; k++) p[k] = r[k] + beta * p[k];
    rr = rr2;
  }
}
236 
// Computes coefficients r that approximate the points p with a linear
// combination of the functions f (named FEM, "finite element method", here).
//
// f is an h x w row-major matrix: row y holds function y sampled at the w
// points. p has w values and r has h values. The coefficients minimize the
// squared error sum_x (sum_y r[y] * f[y * w + x] - p[x])^2. They are obtained
// by solving the normal equations (F * F^T) r = F * p with ConjugateGradient.
//
// r is read as the solver's initial guess and then overwritten with the
// result, so the caller must initialize it (e.g. to zeros). ConjugateGradient
// expects a symmetric positive definite system; F * F^T is one when the rows
// of f are linearly independent.
template <typename T>
void FEM(const T* f, int h, int w, const T* p, T* r) {
  // For each row of f, the half-open column range [first_nz, last_nz) that
  // contains all of its nonzero entries. An all-zero row gets [0, 0). Rows
  // with compact support then make the Gram products below cheap.
  std::vector<int> first_nz(h, 0);
  std::vector<int> last_nz(h, 0);
  for (int row = 0; row < h; row++) {
    const T* const f_row = f + row * w;
    int lo = 0;
    while (lo < w && f_row[lo] == 0) lo++;
    int hi = w;
    while (hi > 0 && f_row[hi - 1] == 0) hi--;
    if (hi > 0) {
      first_nz[row] = lo;
      last_nz[row] = hi;
    }
  }

  // Gram matrix G = F * F^T (h x h, symmetric). Only the overlap of two rows'
  // nonzero ranges can contribute to their inner product.
  std::vector<T> gram(h * h);
  for (int i = 0; i < h; i++) {
    for (int j = 0; j <= i; j++) {
      const int from = std::max(first_nz[j], first_nz[i]);
      const int to = std::min(last_nz[j], last_nz[i]);
      T acc = 0;
      for (int z = from; z < to; z++) {
        acc += f[j * w + z] * f[i * w + z];
      }
      gram[i * h + j] = acc;
      gram[j * h + i] = acc;
    }
  }

  // Right-hand side F * p: entry i is the inner product of row i of f with p.
  std::vector<T> rhs(h);
  for (int i = 0; i < h; i++) {
    const T* const f_row = f + i * w;
    T acc = 0;
    for (int x = 0; x < w; x++) {
      acc += f_row[x] * p[x];
    }
    rhs[i] = acc;
  }

  ConjugateGradient(gram.data(), h, rhs.data(), r);
}
291 
292 }  // namespace jxl
293 
294 #endif  // LIB_JXL_LINALG_H_
295