1 use alga::general::ComplexField;
2
3 use crate::base::allocator::Allocator;
4 use crate::base::dimension::Dim;
5 use crate::base::storage::{Storage, StorageMut};
6 use crate::base::{DefaultAllocator, MatrixN, SquareMatrix};
7
8 use crate::linalg::lu;
9
10 impl<N: ComplexField, D: Dim, S: Storage<N, D, D>> SquareMatrix<N, D, S> {
11 /// Attempts to invert this matrix.
12 #[inline]
try_inverse(self) -> Option<MatrixN<N, D>> where DefaultAllocator: Allocator<N, D, D>13 pub fn try_inverse(self) -> Option<MatrixN<N, D>>
14 where DefaultAllocator: Allocator<N, D, D> {
15 let mut me = self.into_owned();
16 if me.try_inverse_mut() {
17 Some(me)
18 } else {
19 None
20 }
21 }
22 }
23
24 impl<N: ComplexField, D: Dim, S: StorageMut<N, D, D>> SquareMatrix<N, D, S> {
25 /// Attempts to invert this matrix in-place. Returns `false` and leaves `self` untouched if
26 /// inversion fails.
27 #[inline]
try_inverse_mut(&mut self) -> bool where DefaultAllocator: Allocator<N, D, D>28 pub fn try_inverse_mut(&mut self) -> bool
29 where DefaultAllocator: Allocator<N, D, D> {
30 assert!(self.is_square(), "Unable to invert a non-square matrix.");
31
32 let dim = self.shape().0;
33
34 unsafe {
35 match dim {
36 0 => true,
37 1 => {
38 let determinant = self.get_unchecked((0, 0)).clone();
39 if determinant.is_zero() {
40 false
41 } else {
42 *self.get_unchecked_mut((0, 0)) = N::one() / determinant;
43 true
44 }
45 }
46 2 => {
47 let m11 = *self.get_unchecked((0, 0));
48 let m12 = *self.get_unchecked((0, 1));
49 let m21 = *self.get_unchecked((1, 0));
50 let m22 = *self.get_unchecked((1, 1));
51
52 let determinant = m11 * m22 - m21 * m12;
53
54 if determinant.is_zero() {
55 false
56 } else {
57 *self.get_unchecked_mut((0, 0)) = m22 / determinant;
58 *self.get_unchecked_mut((0, 1)) = -m12 / determinant;
59
60 *self.get_unchecked_mut((1, 0)) = -m21 / determinant;
61 *self.get_unchecked_mut((1, 1)) = m11 / determinant;
62
63 true
64 }
65 }
66 3 => {
67 let m11 = *self.get_unchecked((0, 0));
68 let m12 = *self.get_unchecked((0, 1));
69 let m13 = *self.get_unchecked((0, 2));
70
71 let m21 = *self.get_unchecked((1, 0));
72 let m22 = *self.get_unchecked((1, 1));
73 let m23 = *self.get_unchecked((1, 2));
74
75 let m31 = *self.get_unchecked((2, 0));
76 let m32 = *self.get_unchecked((2, 1));
77 let m33 = *self.get_unchecked((2, 2));
78
79 let minor_m12_m23 = m22 * m33 - m32 * m23;
80 let minor_m11_m23 = m21 * m33 - m31 * m23;
81 let minor_m11_m22 = m21 * m32 - m31 * m22;
82
83 let determinant =
84 m11 * minor_m12_m23 - m12 * minor_m11_m23 + m13 * minor_m11_m22;
85
86 if determinant.is_zero() {
87 false
88 } else {
89 *self.get_unchecked_mut((0, 0)) = minor_m12_m23 / determinant;
90 *self.get_unchecked_mut((0, 1)) = (m13 * m32 - m33 * m12) / determinant;
91 *self.get_unchecked_mut((0, 2)) = (m12 * m23 - m22 * m13) / determinant;
92
93 *self.get_unchecked_mut((1, 0)) = -minor_m11_m23 / determinant;
94 *self.get_unchecked_mut((1, 1)) = (m11 * m33 - m31 * m13) / determinant;
95 *self.get_unchecked_mut((1, 2)) = (m13 * m21 - m23 * m11) / determinant;
96
97 *self.get_unchecked_mut((2, 0)) = minor_m11_m22 / determinant;
98 *self.get_unchecked_mut((2, 1)) = (m12 * m31 - m32 * m11) / determinant;
99 *self.get_unchecked_mut((2, 2)) = (m11 * m22 - m21 * m12) / determinant;
100
101 true
102 }
103 }
104 4 => {
105 let oself = self.clone_owned();
106 do_inverse4(&oself, self)
107 }
108 _ => {
109 let oself = self.clone_owned();
110 lu::try_invert_to(oself, self)
111 }
112 }
113 }
114 }
115 }
116
117 // NOTE: this is an extremely efficient, loop-unrolled matrix inverse from MESA (MIT licensed).
do_inverse4<N: ComplexField, D: Dim, S: StorageMut<N, D, D>>( m: &MatrixN<N, D>, out: &mut SquareMatrix<N, D, S>, ) -> bool where DefaultAllocator: Allocator<N, D, D>,118 fn do_inverse4<N: ComplexField, D: Dim, S: StorageMut<N, D, D>>(
119 m: &MatrixN<N, D>,
120 out: &mut SquareMatrix<N, D, S>,
121 ) -> bool
122 where
123 DefaultAllocator: Allocator<N, D, D>,
124 {
125 let m = m.data.as_slice();
126
127 out[(0, 0)] = m[5] * m[10] * m[15] - m[5] * m[11] * m[14] - m[9] * m[6] * m[15]
128 + m[9] * m[7] * m[14]
129 + m[13] * m[6] * m[11]
130 - m[13] * m[7] * m[10];
131
132 out[(1, 0)] = -m[1] * m[10] * m[15] + m[1] * m[11] * m[14] + m[9] * m[2] * m[15]
133 - m[9] * m[3] * m[14]
134 - m[13] * m[2] * m[11]
135 + m[13] * m[3] * m[10];
136
137 out[(2, 0)] = m[1] * m[6] * m[15] - m[1] * m[7] * m[14] - m[5] * m[2] * m[15]
138 + m[5] * m[3] * m[14]
139 + m[13] * m[2] * m[7]
140 - m[13] * m[3] * m[6];
141
142 out[(3, 0)] = -m[1] * m[6] * m[11] + m[1] * m[7] * m[10] + m[5] * m[2] * m[11]
143 - m[5] * m[3] * m[10]
144 - m[9] * m[2] * m[7]
145 + m[9] * m[3] * m[6];
146
147 out[(0, 1)] = -m[4] * m[10] * m[15] + m[4] * m[11] * m[14] + m[8] * m[6] * m[15]
148 - m[8] * m[7] * m[14]
149 - m[12] * m[6] * m[11]
150 + m[12] * m[7] * m[10];
151
152 out[(1, 1)] = m[0] * m[10] * m[15] - m[0] * m[11] * m[14] - m[8] * m[2] * m[15]
153 + m[8] * m[3] * m[14]
154 + m[12] * m[2] * m[11]
155 - m[12] * m[3] * m[10];
156
157 out[(2, 1)] = -m[0] * m[6] * m[15] + m[0] * m[7] * m[14] + m[4] * m[2] * m[15]
158 - m[4] * m[3] * m[14]
159 - m[12] * m[2] * m[7]
160 + m[12] * m[3] * m[6];
161
162 out[(3, 1)] = m[0] * m[6] * m[11] - m[0] * m[7] * m[10] - m[4] * m[2] * m[11]
163 + m[4] * m[3] * m[10]
164 + m[8] * m[2] * m[7]
165 - m[8] * m[3] * m[6];
166
167 out[(0, 2)] = m[4] * m[9] * m[15] - m[4] * m[11] * m[13] - m[8] * m[5] * m[15]
168 + m[8] * m[7] * m[13]
169 + m[12] * m[5] * m[11]
170 - m[12] * m[7] * m[9];
171
172 out[(1, 2)] = -m[0] * m[9] * m[15] + m[0] * m[11] * m[13] + m[8] * m[1] * m[15]
173 - m[8] * m[3] * m[13]
174 - m[12] * m[1] * m[11]
175 + m[12] * m[3] * m[9];
176
177 out[(2, 2)] = m[0] * m[5] * m[15] - m[0] * m[7] * m[13] - m[4] * m[1] * m[15]
178 + m[4] * m[3] * m[13]
179 + m[12] * m[1] * m[7]
180 - m[12] * m[3] * m[5];
181
182 out[(0, 3)] = -m[4] * m[9] * m[14] + m[4] * m[10] * m[13] + m[8] * m[5] * m[14]
183 - m[8] * m[6] * m[13]
184 - m[12] * m[5] * m[10]
185 + m[12] * m[6] * m[9];
186
187 out[(3, 2)] = -m[0] * m[5] * m[11] + m[0] * m[7] * m[9] + m[4] * m[1] * m[11]
188 - m[4] * m[3] * m[9]
189 - m[8] * m[1] * m[7]
190 + m[8] * m[3] * m[5];
191
192 out[(1, 3)] = m[0] * m[9] * m[14] - m[0] * m[10] * m[13] - m[8] * m[1] * m[14]
193 + m[8] * m[2] * m[13]
194 + m[12] * m[1] * m[10]
195 - m[12] * m[2] * m[9];
196
197 out[(2, 3)] = -m[0] * m[5] * m[14] + m[0] * m[6] * m[13] + m[4] * m[1] * m[14]
198 - m[4] * m[2] * m[13]
199 - m[12] * m[1] * m[6]
200 + m[12] * m[2] * m[5];
201
202 out[(3, 3)] = m[0] * m[5] * m[10] - m[0] * m[6] * m[9] - m[4] * m[1] * m[10]
203 + m[4] * m[2] * m[9]
204 + m[8] * m[1] * m[6]
205 - m[8] * m[2] * m[5];
206
207 let det = m[0] * out[(0, 0)] + m[1] * out[(0, 1)] + m[2] * out[(0, 2)] + m[3] * out[(0, 3)];
208
209 if !det.is_zero() {
210 let inv_det = N::one() / det;
211
212 for j in 0..4 {
213 for i in 0..4 {
214 out[(i, j)] *= inv_det;
215 }
216 }
217 true
218 } else {
219 false
220 }
221 }
222