//! This module provides the matrix exponential (`exp`) for square matrices.
2 //!
3 use crate::{
4 base::{
5 allocator::Allocator,
6 dimension::{Const, Dim, DimMin, DimMinimum},
7 DefaultAllocator,
8 },
9 convert, try_convert, ComplexField, OMatrix, RealField,
10 };
11
12 use crate::num::Zero;
13
// Reference implementation this module follows (SciPy's `expm`):
// https://github.com/scipy/scipy/blob/c1372d8aa90a73d8a52f135529293ff4edb98fc8/scipy/sparse/linalg/matfuncs.py
15 struct ExpmPadeHelper<T, D>
16 where
17 T: ComplexField,
18 D: DimMin<D>,
19 DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
20 {
21 use_exact_norm: bool,
22 ident: OMatrix<T, D, D>,
23
24 a: OMatrix<T, D, D>,
25 a2: Option<OMatrix<T, D, D>>,
26 a4: Option<OMatrix<T, D, D>>,
27 a6: Option<OMatrix<T, D, D>>,
28 a8: Option<OMatrix<T, D, D>>,
29 a10: Option<OMatrix<T, D, D>>,
30
31 d4_exact: Option<T::RealField>,
32 d6_exact: Option<T::RealField>,
33 d8_exact: Option<T::RealField>,
34 d10_exact: Option<T::RealField>,
35
36 d4_approx: Option<T::RealField>,
37 d6_approx: Option<T::RealField>,
38 d8_approx: Option<T::RealField>,
39 d10_approx: Option<T::RealField>,
40 }
41
42 impl<T, D> ExpmPadeHelper<T, D>
43 where
44 T: ComplexField,
45 D: DimMin<D>,
46 DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
47 {
new(a: OMatrix<T, D, D>, use_exact_norm: bool) -> Self48 fn new(a: OMatrix<T, D, D>, use_exact_norm: bool) -> Self {
49 let (nrows, ncols) = a.shape_generic();
50 ExpmPadeHelper {
51 use_exact_norm,
52 ident: OMatrix::<T, D, D>::identity_generic(nrows, ncols),
53 a,
54 a2: None,
55 a4: None,
56 a6: None,
57 a8: None,
58 a10: None,
59 d4_exact: None,
60 d6_exact: None,
61 d8_exact: None,
62 d10_exact: None,
63 d4_approx: None,
64 d6_approx: None,
65 d8_approx: None,
66 d10_approx: None,
67 }
68 }
69
calc_a2(&mut self)70 fn calc_a2(&mut self) {
71 if self.a2.is_none() {
72 self.a2 = Some(&self.a * &self.a);
73 }
74 }
75
calc_a4(&mut self)76 fn calc_a4(&mut self) {
77 if self.a4.is_none() {
78 self.calc_a2();
79 let a2 = self.a2.as_ref().unwrap();
80 self.a4 = Some(a2 * a2);
81 }
82 }
83
calc_a6(&mut self)84 fn calc_a6(&mut self) {
85 if self.a6.is_none() {
86 self.calc_a2();
87 self.calc_a4();
88 let a2 = self.a2.as_ref().unwrap();
89 let a4 = self.a4.as_ref().unwrap();
90 self.a6 = Some(a4 * a2);
91 }
92 }
93
calc_a8(&mut self)94 fn calc_a8(&mut self) {
95 if self.a8.is_none() {
96 self.calc_a2();
97 self.calc_a6();
98 let a2 = self.a2.as_ref().unwrap();
99 let a6 = self.a6.as_ref().unwrap();
100 self.a8 = Some(a6 * a2);
101 }
102 }
103
calc_a10(&mut self)104 fn calc_a10(&mut self) {
105 if self.a10.is_none() {
106 self.calc_a4();
107 self.calc_a6();
108 let a4 = self.a4.as_ref().unwrap();
109 let a6 = self.a6.as_ref().unwrap();
110 self.a10 = Some(a6 * a4);
111 }
112 }
113
d4_tight(&mut self) -> T::RealField114 fn d4_tight(&mut self) -> T::RealField {
115 if self.d4_exact.is_none() {
116 self.calc_a4();
117 self.d4_exact = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
118 }
119 self.d4_exact.clone().unwrap()
120 }
121
d6_tight(&mut self) -> T::RealField122 fn d6_tight(&mut self) -> T::RealField {
123 if self.d6_exact.is_none() {
124 self.calc_a6();
125 self.d6_exact = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
126 }
127 self.d6_exact.clone().unwrap()
128 }
129
d8_tight(&mut self) -> T::RealField130 fn d8_tight(&mut self) -> T::RealField {
131 if self.d8_exact.is_none() {
132 self.calc_a8();
133 self.d8_exact = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
134 }
135 self.d8_exact.clone().unwrap()
136 }
137
d10_tight(&mut self) -> T::RealField138 fn d10_tight(&mut self) -> T::RealField {
139 if self.d10_exact.is_none() {
140 self.calc_a10();
141 self.d10_exact = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
142 }
143 self.d10_exact.clone().unwrap()
144 }
145
d4_loose(&mut self) -> T::RealField146 fn d4_loose(&mut self) -> T::RealField {
147 if self.use_exact_norm {
148 return self.d4_tight();
149 }
150
151 if self.d4_exact.is_some() {
152 return self.d4_exact.clone().unwrap();
153 }
154
155 if self.d4_approx.is_none() {
156 self.calc_a4();
157 self.d4_approx = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
158 }
159
160 self.d4_approx.clone().unwrap()
161 }
162
d6_loose(&mut self) -> T::RealField163 fn d6_loose(&mut self) -> T::RealField {
164 if self.use_exact_norm {
165 return self.d6_tight();
166 }
167
168 if self.d6_exact.is_some() {
169 return self.d6_exact.clone().unwrap();
170 }
171
172 if self.d6_approx.is_none() {
173 self.calc_a6();
174 self.d6_approx = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
175 }
176
177 self.d6_approx.clone().unwrap()
178 }
179
d8_loose(&mut self) -> T::RealField180 fn d8_loose(&mut self) -> T::RealField {
181 if self.use_exact_norm {
182 return self.d8_tight();
183 }
184
185 if self.d8_exact.is_some() {
186 return self.d8_exact.clone().unwrap();
187 }
188
189 if self.d8_approx.is_none() {
190 self.calc_a8();
191 self.d8_approx = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
192 }
193
194 self.d8_approx.clone().unwrap()
195 }
196
d10_loose(&mut self) -> T::RealField197 fn d10_loose(&mut self) -> T::RealField {
198 if self.use_exact_norm {
199 return self.d10_tight();
200 }
201
202 if self.d10_exact.is_some() {
203 return self.d10_exact.clone().unwrap();
204 }
205
206 if self.d10_approx.is_none() {
207 self.calc_a10();
208 self.d10_approx = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
209 }
210
211 self.d10_approx.clone().unwrap()
212 }
213
pade3(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)214 fn pade3(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
215 let b: [T; 4] = [convert(120.0), convert(60.0), convert(12.0), convert(1.0)];
216 self.calc_a2();
217 let a2 = self.a2.as_ref().unwrap();
218 let u = &self.a * (a2 * b[3].clone() + &self.ident * b[1].clone());
219 let v = a2 * b[2].clone() + &self.ident * b[0].clone();
220 (u, v)
221 }
222
pade5(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)223 fn pade5(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
224 let b: [T; 6] = [
225 convert(30240.0),
226 convert(15120.0),
227 convert(3360.0),
228 convert(420.0),
229 convert(30.0),
230 convert(1.0),
231 ];
232 self.calc_a2();
233 self.calc_a6();
234 let u = &self.a
235 * (self.a4.as_ref().unwrap() * b[5].clone()
236 + self.a2.as_ref().unwrap() * b[3].clone()
237 + &self.ident * b[1].clone());
238 let v = self.a4.as_ref().unwrap() * b[4].clone()
239 + self.a2.as_ref().unwrap() * b[2].clone()
240 + &self.ident * b[0].clone();
241 (u, v)
242 }
243
pade7(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)244 fn pade7(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
245 let b: [T; 8] = [
246 convert(17_297_280.0),
247 convert(8_648_640.0),
248 convert(1_995_840.0),
249 convert(277_200.0),
250 convert(25_200.0),
251 convert(1_512.0),
252 convert(56.0),
253 convert(1.0),
254 ];
255 self.calc_a2();
256 self.calc_a4();
257 self.calc_a6();
258 let u = &self.a
259 * (self.a6.as_ref().unwrap() * b[7].clone()
260 + self.a4.as_ref().unwrap() * b[5].clone()
261 + self.a2.as_ref().unwrap() * b[3].clone()
262 + &self.ident * b[1].clone());
263 let v = self.a6.as_ref().unwrap() * b[6].clone()
264 + self.a4.as_ref().unwrap() * b[4].clone()
265 + self.a2.as_ref().unwrap() * b[2].clone()
266 + &self.ident * b[0].clone();
267 (u, v)
268 }
269
pade9(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)270 fn pade9(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
271 let b: [T; 10] = [
272 convert(17_643_225_600.0),
273 convert(8_821_612_800.0),
274 convert(2_075_673_600.0),
275 convert(302_702_400.0),
276 convert(30_270_240.0),
277 convert(2_162_160.0),
278 convert(110_880.0),
279 convert(3_960.0),
280 convert(90.0),
281 convert(1.0),
282 ];
283 self.calc_a2();
284 self.calc_a4();
285 self.calc_a6();
286 self.calc_a8();
287 let u = &self.a
288 * (self.a8.as_ref().unwrap() * b[9].clone()
289 + self.a6.as_ref().unwrap() * b[7].clone()
290 + self.a4.as_ref().unwrap() * b[5].clone()
291 + self.a2.as_ref().unwrap() * b[3].clone()
292 + &self.ident * b[1].clone());
293 let v = self.a8.as_ref().unwrap() * b[8].clone()
294 + self.a6.as_ref().unwrap() * b[6].clone()
295 + self.a4.as_ref().unwrap() * b[4].clone()
296 + self.a2.as_ref().unwrap() * b[2].clone()
297 + &self.ident * b[0].clone();
298 (u, v)
299 }
300
pade13_scaled(&mut self, s: u64) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)301 fn pade13_scaled(&mut self, s: u64) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
302 let b: [T; 14] = [
303 convert(64_764_752_532_480_000.0),
304 convert(32_382_376_266_240_000.0),
305 convert(7_771_770_303_897_600.0),
306 convert(1_187_353_796_428_800.0),
307 convert(129_060_195_264_000.0),
308 convert(10_559_470_521_600.0),
309 convert(670_442_572_800.0),
310 convert(33_522_128_640.0),
311 convert(1_323_241_920.0),
312 convert(40_840_800.0),
313 convert(960_960.0),
314 convert(16_380.0),
315 convert(182.0),
316 convert(1.0),
317 ];
318 let s = s as f64;
319
320 let mb = &self.a * convert::<f64, T>(2.0_f64.powf(-s));
321 self.calc_a2();
322 self.calc_a4();
323 self.calc_a6();
324 let mb2 = self.a2.as_ref().unwrap() * convert::<f64, T>(2.0_f64.powf(-2.0 * s.clone()));
325 let mb4 = self.a4.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-4.0 * s.clone()));
326 let mb6 = self.a6.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-6.0 * s));
327
328 let u2 = &mb6 * (&mb6 * b[13].clone() + &mb4 * b[11].clone() + &mb2 * b[9].clone());
329 let u = &mb
330 * (&u2
331 + &mb6 * b[7].clone()
332 + &mb4 * b[5].clone()
333 + &mb2 * b[3].clone()
334 + &self.ident * b[1].clone());
335 let v2 = &mb6 * (&mb6 * b[12].clone() + &mb4 * b[10].clone() + &mb2 * b[8].clone());
336 let v = v2
337 + &mb6 * b[6].clone()
338 + &mb4 * b[4].clone()
339 + &mb2 * b[2].clone()
340 + &self.ident * b[0].clone();
341 (u, v)
342 }
343 }
344
/// Returns `n!` as a `u128`.
///
/// `0! = 1`: the empty product. The previous recursive version only stopped at
/// `n == 1`, so `factorial(0)` underflowed `n - 1`.
///
/// # Panics
///
/// In debug builds, panics on overflow for `n > 34` (35! exceeds `u128::MAX`).
fn factorial(n: u128) -> u128 {
    (1..=n).product()
}
351
352 /// Compute the 1-norm of a non-negative integer power of a non-negative matrix.
onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: u64) -> T where T: RealField, D: Dim, DefaultAllocator: Allocator<T, D, D> + Allocator<T, D>,353 fn onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: u64) -> T
354 where
355 T: RealField,
356 D: Dim,
357 DefaultAllocator: Allocator<T, D, D> + Allocator<T, D>,
358 {
359 let nrows = a.shape_generic().0;
360 let mut v = crate::OVector::<T, D>::repeat_generic(nrows, Const::<1>, convert(1.0));
361 let m = a.transpose();
362
363 for _ in 0..p {
364 v = &m * v;
365 }
366
367 v.max()
368 }
369
ell<T, D>(a: &OMatrix<T, D, D>, m: u64) -> u64 where T: ComplexField, D: Dim, DefaultAllocator: Allocator<T, D, D> + Allocator<T, D> + Allocator<T::RealField, D> + Allocator<T::RealField, D, D>,370 fn ell<T, D>(a: &OMatrix<T, D, D>, m: u64) -> u64
371 where
372 T: ComplexField,
373 D: Dim,
374 DefaultAllocator: Allocator<T, D, D>
375 + Allocator<T, D>
376 + Allocator<T::RealField, D>
377 + Allocator<T::RealField, D, D>,
378 {
379 // 2m choose m = (2m)!/(m! * (2m-m)!)
380
381 let a_abs = a.map(|x| x.abs());
382
383 let a_abs_onenorm = onenorm_matrix_power_nonm(&a_abs, 2 * m + 1);
384
385 if a_abs_onenorm == <T as ComplexField>::RealField::zero() {
386 return 0;
387 }
388
389 let choose_2m_m =
390 factorial(2 * m as u128) / (factorial(m as u128) * factorial(2 * m as u128 - m as u128));
391 let abs_c_recip = choose_2m_m * factorial(2 * m as u128 + 1);
392 let alpha = a_abs_onenorm / one_norm(a);
393 let alpha: f64 = try_convert(alpha).unwrap() / abs_c_recip as f64;
394
395 let u = 2_f64.powf(-53.0);
396 let log2_alpha_div_u = (alpha / u).log2();
397 let value = (log2_alpha_div_u / (2.0 * m as f64)).ceil();
398 if value > 0.0 {
399 value as u64
400 } else {
401 0
402 }
403 }
404
solve_p_q<T, D>(u: OMatrix<T, D, D>, v: OMatrix<T, D, D>) -> OMatrix<T, D, D> where T: ComplexField, D: DimMin<D, Output = D>, DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,405 fn solve_p_q<T, D>(u: OMatrix<T, D, D>, v: OMatrix<T, D, D>) -> OMatrix<T, D, D>
406 where
407 T: ComplexField,
408 D: DimMin<D, Output = D>,
409 DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
410 {
411 let p = &u + &v;
412 let q = &v - &u;
413
414 q.lu().solve(&p).unwrap()
415 }
416
one_norm<T, D>(m: &OMatrix<T, D, D>) -> T::RealField where T: ComplexField, D: Dim, DefaultAllocator: Allocator<T, D, D>,417 fn one_norm<T, D>(m: &OMatrix<T, D, D>) -> T::RealField
418 where
419 T: ComplexField,
420 D: Dim,
421 DefaultAllocator: Allocator<T, D, D>,
422 {
423 let mut max = <T as ComplexField>::RealField::zero();
424
425 for i in 0..m.ncols() {
426 let col = m.column(i);
427 max = max.max(
428 col.iter()
429 .fold(<T as ComplexField>::RealField::zero(), |a, b| {
430 a + b.clone().abs()
431 }),
432 );
433 }
434
435 max
436 }
437
438 impl<T: ComplexField, D> OMatrix<T, D, D>
439 where
440 D: DimMin<D, Output = D>,
441 DefaultAllocator: Allocator<T, D, D>
442 + Allocator<(usize, usize), DimMinimum<D, D>>
443 + Allocator<T, D>
444 + Allocator<T::RealField, D>
445 + Allocator<T::RealField, D, D>,
446 {
447 /// Computes exponential of this matrix
448 #[must_use]
exp(&self) -> Self449 pub fn exp(&self) -> Self {
450 // Simple case
451 if self.nrows() == 1 {
452 return self.map(|v| v.exp());
453 }
454
455 let mut helper = ExpmPadeHelper::new(self.clone(), true);
456
457 let eta_1 = T::RealField::max(helper.d4_loose(), helper.d6_loose());
458 if eta_1 < convert(1.495_585_217_958_292e-2) && ell(&helper.a, 3) == 0 {
459 let (u, v) = helper.pade3();
460 return solve_p_q(u, v);
461 }
462
463 let eta_2 = T::RealField::max(helper.d4_tight(), helper.d6_loose());
464 if eta_2 < convert(2.539_398_330_063_23e-1) && ell(&helper.a, 5) == 0 {
465 let (u, v) = helper.pade5();
466 return solve_p_q(u, v);
467 }
468
469 let eta_3 = T::RealField::max(helper.d6_tight(), helper.d8_loose());
470 if eta_3 < convert(9.504_178_996_162_932e-1) && ell(&helper.a, 7) == 0 {
471 let (u, v) = helper.pade7();
472 return solve_p_q(u, v);
473 }
474 if eta_3 < convert(2.097_847_961_257_068e0) && ell(&helper.a, 9) == 0 {
475 let (u, v) = helper.pade9();
476 return solve_p_q(u, v);
477 }
478
479 let eta_4 = T::RealField::max(helper.d8_loose(), helper.d10_loose());
480 let eta_5 = T::RealField::min(eta_3, eta_4);
481 let theta_13 = convert(4.25);
482
483 let mut s = if eta_5 == T::RealField::zero() {
484 0
485 } else {
486 let l2 = try_convert((eta_5 / theta_13).log2().ceil()).unwrap();
487
488 if l2 < 0.0 {
489 0
490 } else {
491 l2 as u64
492 }
493 };
494
495 s += ell(
496 &(&helper.a * convert::<f64, T>(2.0_f64.powf(-(s as f64)))),
497 13,
498 );
499
500 let (u, v) = helper.pade13_scaled(s);
501 let mut x = solve_p_q(u, v);
502
503 for _ in 0..s {
504 x = &x * &x;
505 }
506 x
507 }
508 }
509
#[cfg(test)]
mod tests {
    use crate::Matrix3;

    /// The 1-norm is the largest absolute column sum; here the third column
    /// (7 + 4 + 8 = 19) dominates.
    #[test]
    fn one_norm() {
        let m = Matrix3::new(
            -3.0, 5.0, 7.0,
            2.0, 6.0, 4.0,
            0.0, 2.0, 8.0,
        );
        let norm = super::one_norm(&m);
        assert_eq!(norm, 19.0);
    }
}
520