//! This module provides the matrix exponential (exp) function for square matrices.
2 //!
3 use crate::{
4     base::{
5         allocator::Allocator,
6         dimension::{Const, Dim, DimMin, DimMinimum},
7         DefaultAllocator,
8     },
9     convert, try_convert, ComplexField, OMatrix, RealField,
10 };
11 
12 use crate::num::Zero;
13 
14 // https://github.com/scipy/scipy/blob/c1372d8aa90a73d8a52f135529293ff4edb98fc8/scipy/sparse/linalg/matfuncs.py
/// Lazily caches the even powers of a square matrix and the norm-based quantities derived from
/// them, for use when evaluating Padé approximants of the matrix exponential.
///
/// Every `Option` field starts as `None` and is filled the first time it is needed.
struct ExpmPadeHelper<T, D>
where
    T: ComplexField,
    D: DimMin<D>,
    DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
{
    /// When `true`, the `dN_loose` methods delegate to the matching `dN_tight` method.
    use_exact_norm: bool,
    /// Identity matrix with the same shape as `a`.
    ident: OMatrix<T, D, D>,

    /// The matrix whose exponential is being computed.
    a: OMatrix<T, D, D>,
    /// Cached `a^2`.
    a2: Option<OMatrix<T, D, D>>,
    /// Cached `a^4`.
    a4: Option<OMatrix<T, D, D>>,
    /// Cached `a^6`.
    a6: Option<OMatrix<T, D, D>>,
    /// Cached `a^8`.
    a8: Option<OMatrix<T, D, D>>,
    /// Cached `a^10`.
    a10: Option<OMatrix<T, D, D>>,

    /// Cached `one_norm(a^4)^(1/4)`, filled by `d4_tight`.
    d4_exact: Option<T::RealField>,
    /// Cached `one_norm(a^6)^(1/6)`, filled by `d6_tight`.
    d6_exact: Option<T::RealField>,
    /// Cached `one_norm(a^8)^(1/8)`, filled by `d8_tight`.
    d8_exact: Option<T::RealField>,
    /// Cached `one_norm(a^10)^(1/10)`, filled by `d10_tight`.
    d10_exact: Option<T::RealField>,

    /// Value cached by `d4_loose` when `use_exact_norm` is `false`.
    d4_approx: Option<T::RealField>,
    /// Value cached by `d6_loose` when `use_exact_norm` is `false`.
    d6_approx: Option<T::RealField>,
    /// Value cached by `d8_loose` when `use_exact_norm` is `false`.
    d8_approx: Option<T::RealField>,
    /// Value cached by `d10_loose` when `use_exact_norm` is `false`.
    d10_approx: Option<T::RealField>,
}
41 
42 impl<T, D> ExpmPadeHelper<T, D>
43 where
44     T: ComplexField,
45     D: DimMin<D>,
46     DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
47 {
new(a: OMatrix<T, D, D>, use_exact_norm: bool) -> Self48     fn new(a: OMatrix<T, D, D>, use_exact_norm: bool) -> Self {
49         let (nrows, ncols) = a.shape_generic();
50         ExpmPadeHelper {
51             use_exact_norm,
52             ident: OMatrix::<T, D, D>::identity_generic(nrows, ncols),
53             a,
54             a2: None,
55             a4: None,
56             a6: None,
57             a8: None,
58             a10: None,
59             d4_exact: None,
60             d6_exact: None,
61             d8_exact: None,
62             d10_exact: None,
63             d4_approx: None,
64             d6_approx: None,
65             d8_approx: None,
66             d10_approx: None,
67         }
68     }
69 
calc_a2(&mut self)70     fn calc_a2(&mut self) {
71         if self.a2.is_none() {
72             self.a2 = Some(&self.a * &self.a);
73         }
74     }
75 
calc_a4(&mut self)76     fn calc_a4(&mut self) {
77         if self.a4.is_none() {
78             self.calc_a2();
79             let a2 = self.a2.as_ref().unwrap();
80             self.a4 = Some(a2 * a2);
81         }
82     }
83 
calc_a6(&mut self)84     fn calc_a6(&mut self) {
85         if self.a6.is_none() {
86             self.calc_a2();
87             self.calc_a4();
88             let a2 = self.a2.as_ref().unwrap();
89             let a4 = self.a4.as_ref().unwrap();
90             self.a6 = Some(a4 * a2);
91         }
92     }
93 
calc_a8(&mut self)94     fn calc_a8(&mut self) {
95         if self.a8.is_none() {
96             self.calc_a2();
97             self.calc_a6();
98             let a2 = self.a2.as_ref().unwrap();
99             let a6 = self.a6.as_ref().unwrap();
100             self.a8 = Some(a6 * a2);
101         }
102     }
103 
calc_a10(&mut self)104     fn calc_a10(&mut self) {
105         if self.a10.is_none() {
106             self.calc_a4();
107             self.calc_a6();
108             let a4 = self.a4.as_ref().unwrap();
109             let a6 = self.a6.as_ref().unwrap();
110             self.a10 = Some(a6 * a4);
111         }
112     }
113 
d4_tight(&mut self) -> T::RealField114     fn d4_tight(&mut self) -> T::RealField {
115         if self.d4_exact.is_none() {
116             self.calc_a4();
117             self.d4_exact = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
118         }
119         self.d4_exact.clone().unwrap()
120     }
121 
d6_tight(&mut self) -> T::RealField122     fn d6_tight(&mut self) -> T::RealField {
123         if self.d6_exact.is_none() {
124             self.calc_a6();
125             self.d6_exact = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
126         }
127         self.d6_exact.clone().unwrap()
128     }
129 
d8_tight(&mut self) -> T::RealField130     fn d8_tight(&mut self) -> T::RealField {
131         if self.d8_exact.is_none() {
132             self.calc_a8();
133             self.d8_exact = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
134         }
135         self.d8_exact.clone().unwrap()
136     }
137 
d10_tight(&mut self) -> T::RealField138     fn d10_tight(&mut self) -> T::RealField {
139         if self.d10_exact.is_none() {
140             self.calc_a10();
141             self.d10_exact = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
142         }
143         self.d10_exact.clone().unwrap()
144     }
145 
d4_loose(&mut self) -> T::RealField146     fn d4_loose(&mut self) -> T::RealField {
147         if self.use_exact_norm {
148             return self.d4_tight();
149         }
150 
151         if self.d4_exact.is_some() {
152             return self.d4_exact.clone().unwrap();
153         }
154 
155         if self.d4_approx.is_none() {
156             self.calc_a4();
157             self.d4_approx = Some(one_norm(self.a4.as_ref().unwrap()).powf(convert(0.25)));
158         }
159 
160         self.d4_approx.clone().unwrap()
161     }
162 
d6_loose(&mut self) -> T::RealField163     fn d6_loose(&mut self) -> T::RealField {
164         if self.use_exact_norm {
165             return self.d6_tight();
166         }
167 
168         if self.d6_exact.is_some() {
169             return self.d6_exact.clone().unwrap();
170         }
171 
172         if self.d6_approx.is_none() {
173             self.calc_a6();
174             self.d6_approx = Some(one_norm(self.a6.as_ref().unwrap()).powf(convert(1.0 / 6.0)));
175         }
176 
177         self.d6_approx.clone().unwrap()
178     }
179 
d8_loose(&mut self) -> T::RealField180     fn d8_loose(&mut self) -> T::RealField {
181         if self.use_exact_norm {
182             return self.d8_tight();
183         }
184 
185         if self.d8_exact.is_some() {
186             return self.d8_exact.clone().unwrap();
187         }
188 
189         if self.d8_approx.is_none() {
190             self.calc_a8();
191             self.d8_approx = Some(one_norm(self.a8.as_ref().unwrap()).powf(convert(1.0 / 8.0)));
192         }
193 
194         self.d8_approx.clone().unwrap()
195     }
196 
d10_loose(&mut self) -> T::RealField197     fn d10_loose(&mut self) -> T::RealField {
198         if self.use_exact_norm {
199             return self.d10_tight();
200         }
201 
202         if self.d10_exact.is_some() {
203             return self.d10_exact.clone().unwrap();
204         }
205 
206         if self.d10_approx.is_none() {
207             self.calc_a10();
208             self.d10_approx = Some(one_norm(self.a10.as_ref().unwrap()).powf(convert(1.0 / 10.0)));
209         }
210 
211         self.d10_approx.clone().unwrap()
212     }
213 
pade3(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)214     fn pade3(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
215         let b: [T; 4] = [convert(120.0), convert(60.0), convert(12.0), convert(1.0)];
216         self.calc_a2();
217         let a2 = self.a2.as_ref().unwrap();
218         let u = &self.a * (a2 * b[3].clone() + &self.ident * b[1].clone());
219         let v = a2 * b[2].clone() + &self.ident * b[0].clone();
220         (u, v)
221     }
222 
pade5(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)223     fn pade5(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
224         let b: [T; 6] = [
225             convert(30240.0),
226             convert(15120.0),
227             convert(3360.0),
228             convert(420.0),
229             convert(30.0),
230             convert(1.0),
231         ];
232         self.calc_a2();
233         self.calc_a6();
234         let u = &self.a
235             * (self.a4.as_ref().unwrap() * b[5].clone()
236                 + self.a2.as_ref().unwrap() * b[3].clone()
237                 + &self.ident * b[1].clone());
238         let v = self.a4.as_ref().unwrap() * b[4].clone()
239             + self.a2.as_ref().unwrap() * b[2].clone()
240             + &self.ident * b[0].clone();
241         (u, v)
242     }
243 
pade7(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)244     fn pade7(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
245         let b: [T; 8] = [
246             convert(17_297_280.0),
247             convert(8_648_640.0),
248             convert(1_995_840.0),
249             convert(277_200.0),
250             convert(25_200.0),
251             convert(1_512.0),
252             convert(56.0),
253             convert(1.0),
254         ];
255         self.calc_a2();
256         self.calc_a4();
257         self.calc_a6();
258         let u = &self.a
259             * (self.a6.as_ref().unwrap() * b[7].clone()
260                 + self.a4.as_ref().unwrap() * b[5].clone()
261                 + self.a2.as_ref().unwrap() * b[3].clone()
262                 + &self.ident * b[1].clone());
263         let v = self.a6.as_ref().unwrap() * b[6].clone()
264             + self.a4.as_ref().unwrap() * b[4].clone()
265             + self.a2.as_ref().unwrap() * b[2].clone()
266             + &self.ident * b[0].clone();
267         (u, v)
268     }
269 
pade9(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)270     fn pade9(&mut self) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
271         let b: [T; 10] = [
272             convert(17_643_225_600.0),
273             convert(8_821_612_800.0),
274             convert(2_075_673_600.0),
275             convert(302_702_400.0),
276             convert(30_270_240.0),
277             convert(2_162_160.0),
278             convert(110_880.0),
279             convert(3_960.0),
280             convert(90.0),
281             convert(1.0),
282         ];
283         self.calc_a2();
284         self.calc_a4();
285         self.calc_a6();
286         self.calc_a8();
287         let u = &self.a
288             * (self.a8.as_ref().unwrap() * b[9].clone()
289                 + self.a6.as_ref().unwrap() * b[7].clone()
290                 + self.a4.as_ref().unwrap() * b[5].clone()
291                 + self.a2.as_ref().unwrap() * b[3].clone()
292                 + &self.ident * b[1].clone());
293         let v = self.a8.as_ref().unwrap() * b[8].clone()
294             + self.a6.as_ref().unwrap() * b[6].clone()
295             + self.a4.as_ref().unwrap() * b[4].clone()
296             + self.a2.as_ref().unwrap() * b[2].clone()
297             + &self.ident * b[0].clone();
298         (u, v)
299     }
300 
pade13_scaled(&mut self, s: u64) -> (OMatrix<T, D, D>, OMatrix<T, D, D>)301     fn pade13_scaled(&mut self, s: u64) -> (OMatrix<T, D, D>, OMatrix<T, D, D>) {
302         let b: [T; 14] = [
303             convert(64_764_752_532_480_000.0),
304             convert(32_382_376_266_240_000.0),
305             convert(7_771_770_303_897_600.0),
306             convert(1_187_353_796_428_800.0),
307             convert(129_060_195_264_000.0),
308             convert(10_559_470_521_600.0),
309             convert(670_442_572_800.0),
310             convert(33_522_128_640.0),
311             convert(1_323_241_920.0),
312             convert(40_840_800.0),
313             convert(960_960.0),
314             convert(16_380.0),
315             convert(182.0),
316             convert(1.0),
317         ];
318         let s = s as f64;
319 
320         let mb = &self.a * convert::<f64, T>(2.0_f64.powf(-s));
321         self.calc_a2();
322         self.calc_a4();
323         self.calc_a6();
324         let mb2 = self.a2.as_ref().unwrap() * convert::<f64, T>(2.0_f64.powf(-2.0 * s.clone()));
325         let mb4 = self.a4.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-4.0 * s.clone()));
326         let mb6 = self.a6.as_ref().unwrap() * convert::<f64, T>(2.0.powf(-6.0 * s));
327 
328         let u2 = &mb6 * (&mb6 * b[13].clone() + &mb4 * b[11].clone() + &mb2 * b[9].clone());
329         let u = &mb
330             * (&u2
331                 + &mb6 * b[7].clone()
332                 + &mb4 * b[5].clone()
333                 + &mb2 * b[3].clone()
334                 + &self.ident * b[1].clone());
335         let v2 = &mb6 * (&mb6 * b[12].clone() + &mb4 * b[10].clone() + &mb2 * b[8].clone());
336         let v = v2
337             + &mb6 * b[6].clone()
338             + &mb4 * b[4].clone()
339             + &mb2 * b[2].clone()
340             + &self.ident * b[0].clone();
341         (u, v)
342     }
343 }
344 
/// Returns `n!`.
///
/// `0!` and `1!` are both `1` (the product over an empty range). The result must fit in a
/// `u128`; this holds for `n <= 34`.
fn factorial(n: u128) -> u128 {
    (2..=n).product()
}
351 
352 /// Compute the 1-norm of a non-negative integer power of a non-negative matrix.
onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: u64) -> T where T: RealField, D: Dim, DefaultAllocator: Allocator<T, D, D> + Allocator<T, D>,353 fn onenorm_matrix_power_nonm<T, D>(a: &OMatrix<T, D, D>, p: u64) -> T
354 where
355     T: RealField,
356     D: Dim,
357     DefaultAllocator: Allocator<T, D, D> + Allocator<T, D>,
358 {
359     let nrows = a.shape_generic().0;
360     let mut v = crate::OVector::<T, D>::repeat_generic(nrows, Const::<1>, convert(1.0));
361     let m = a.transpose();
362 
363     for _ in 0..p {
364         v = &m * v;
365     }
366 
367     v.max()
368 }
369 
ell<T, D>(a: &OMatrix<T, D, D>, m: u64) -> u64 where T: ComplexField, D: Dim, DefaultAllocator: Allocator<T, D, D> + Allocator<T, D> + Allocator<T::RealField, D> + Allocator<T::RealField, D, D>,370 fn ell<T, D>(a: &OMatrix<T, D, D>, m: u64) -> u64
371 where
372     T: ComplexField,
373     D: Dim,
374     DefaultAllocator: Allocator<T, D, D>
375         + Allocator<T, D>
376         + Allocator<T::RealField, D>
377         + Allocator<T::RealField, D, D>,
378 {
379     // 2m choose m = (2m)!/(m! * (2m-m)!)
380 
381     let a_abs = a.map(|x| x.abs());
382 
383     let a_abs_onenorm = onenorm_matrix_power_nonm(&a_abs, 2 * m + 1);
384 
385     if a_abs_onenorm == <T as ComplexField>::RealField::zero() {
386         return 0;
387     }
388 
389     let choose_2m_m =
390         factorial(2 * m as u128) / (factorial(m as u128) * factorial(2 * m as u128 - m as u128));
391     let abs_c_recip = choose_2m_m * factorial(2 * m as u128 + 1);
392     let alpha = a_abs_onenorm / one_norm(a);
393     let alpha: f64 = try_convert(alpha).unwrap() / abs_c_recip as f64;
394 
395     let u = 2_f64.powf(-53.0);
396     let log2_alpha_div_u = (alpha / u).log2();
397     let value = (log2_alpha_div_u / (2.0 * m as f64)).ceil();
398     if value > 0.0 {
399         value as u64
400     } else {
401         0
402     }
403 }
404 
solve_p_q<T, D>(u: OMatrix<T, D, D>, v: OMatrix<T, D, D>) -> OMatrix<T, D, D> where T: ComplexField, D: DimMin<D, Output = D>, DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,405 fn solve_p_q<T, D>(u: OMatrix<T, D, D>, v: OMatrix<T, D, D>) -> OMatrix<T, D, D>
406 where
407     T: ComplexField,
408     D: DimMin<D, Output = D>,
409     DefaultAllocator: Allocator<T, D, D> + Allocator<(usize, usize), DimMinimum<D, D>>,
410 {
411     let p = &u + &v;
412     let q = &v - &u;
413 
414     q.lu().solve(&p).unwrap()
415 }
416 
one_norm<T, D>(m: &OMatrix<T, D, D>) -> T::RealField where T: ComplexField, D: Dim, DefaultAllocator: Allocator<T, D, D>,417 fn one_norm<T, D>(m: &OMatrix<T, D, D>) -> T::RealField
418 where
419     T: ComplexField,
420     D: Dim,
421     DefaultAllocator: Allocator<T, D, D>,
422 {
423     let mut max = <T as ComplexField>::RealField::zero();
424 
425     for i in 0..m.ncols() {
426         let col = m.column(i);
427         max = max.max(
428             col.iter()
429                 .fold(<T as ComplexField>::RealField::zero(), |a, b| {
430                     a + b.clone().abs()
431                 }),
432         );
433     }
434 
435     max
436 }
437 
438 impl<T: ComplexField, D> OMatrix<T, D, D>
439 where
440     D: DimMin<D, Output = D>,
441     DefaultAllocator: Allocator<T, D, D>
442         + Allocator<(usize, usize), DimMinimum<D, D>>
443         + Allocator<T, D>
444         + Allocator<T::RealField, D>
445         + Allocator<T::RealField, D, D>,
446 {
447     /// Computes exponential of this matrix
448     #[must_use]
exp(&self) -> Self449     pub fn exp(&self) -> Self {
450         // Simple case
451         if self.nrows() == 1 {
452             return self.map(|v| v.exp());
453         }
454 
455         let mut helper = ExpmPadeHelper::new(self.clone(), true);
456 
457         let eta_1 = T::RealField::max(helper.d4_loose(), helper.d6_loose());
458         if eta_1 < convert(1.495_585_217_958_292e-2) && ell(&helper.a, 3) == 0 {
459             let (u, v) = helper.pade3();
460             return solve_p_q(u, v);
461         }
462 
463         let eta_2 = T::RealField::max(helper.d4_tight(), helper.d6_loose());
464         if eta_2 < convert(2.539_398_330_063_23e-1) && ell(&helper.a, 5) == 0 {
465             let (u, v) = helper.pade5();
466             return solve_p_q(u, v);
467         }
468 
469         let eta_3 = T::RealField::max(helper.d6_tight(), helper.d8_loose());
470         if eta_3 < convert(9.504_178_996_162_932e-1) && ell(&helper.a, 7) == 0 {
471             let (u, v) = helper.pade7();
472             return solve_p_q(u, v);
473         }
474         if eta_3 < convert(2.097_847_961_257_068e0) && ell(&helper.a, 9) == 0 {
475             let (u, v) = helper.pade9();
476             return solve_p_q(u, v);
477         }
478 
479         let eta_4 = T::RealField::max(helper.d8_loose(), helper.d10_loose());
480         let eta_5 = T::RealField::min(eta_3, eta_4);
481         let theta_13 = convert(4.25);
482 
483         let mut s = if eta_5 == T::RealField::zero() {
484             0
485         } else {
486             let l2 = try_convert((eta_5 / theta_13).log2().ceil()).unwrap();
487 
488             if l2 < 0.0 {
489                 0
490             } else {
491                 l2 as u64
492             }
493         };
494 
495         s += ell(
496             &(&helper.a * convert::<f64, T>(2.0_f64.powf(-(s as f64)))),
497             13,
498         );
499 
500         let (u, v) = helper.pade13_scaled(s);
501         let mut x = solve_p_q(u, v);
502 
503         for _ in 0..s {
504             x = &x * &x;
505         }
506         x
507     }
508 }
509 
#[cfg(test)]
mod tests {
    use crate::Matrix3;

    /// The largest absolute column sum is that of the third column: 7 + 4 + 8 = 19.
    #[test]
    fn one_norm() {
        #[rustfmt::skip]
        let matrix = Matrix3::new(
            -3.0, 5.0, 7.0,
             2.0, 6.0, 4.0,
             0.0, 2.0, 8.0,
        );
        let expected = 19.0;

        assert_eq!(super::one_norm(&matrix), expected);
    }
}
520