1 //! 4x4 matrix inverse
2 // Code ported from the `packed_simd` crate
3 // Run this code with `cargo test --example matrix_inversion`
4 #![feature(array_chunks, portable_simd)]
5 use core_simd::simd::*;
6 use Which::*;
7 
8 // Gotta define our own 4x4 matrix since Rust doesn't ship multidim arrays yet :^)
9 #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
10 pub struct Matrix4x4([[f32; 4]; 4]);
11 
12 #[allow(clippy::too_many_lines)]
scalar_inv4x4(m: Matrix4x4) -> Option<Matrix4x4>13 pub fn scalar_inv4x4(m: Matrix4x4) -> Option<Matrix4x4> {
14     let m = m.0;
15 
16     #[rustfmt::skip]
17     let mut inv = [
18         // row 0:
19         [
20             // 0,0:
21             m[1][1] * m[2][2] * m[3][3] -
22             m[1][1] * m[2][3] * m[3][2] -
23             m[2][1] * m[1][2] * m[3][3] +
24             m[2][1] * m[1][3] * m[3][2] +
25             m[3][1] * m[1][2] * m[2][3] -
26             m[3][1] * m[1][3] * m[2][2],
27             // 0,1:
28            -m[0][1] * m[2][2] * m[3][3] +
29             m[0][1] * m[2][3] * m[3][2] +
30             m[2][1] * m[0][2] * m[3][3] -
31             m[2][1] * m[0][3] * m[3][2] -
32             m[3][1] * m[0][2] * m[2][3] +
33             m[3][1] * m[0][3] * m[2][2],
34             // 0,2:
35             m[0][1] * m[1][2] * m[3][3] -
36             m[0][1] * m[1][3] * m[3][2] -
37             m[1][1] * m[0][2] * m[3][3] +
38             m[1][1] * m[0][3] * m[3][2] +
39             m[3][1] * m[0][2] * m[1][3] -
40             m[3][1] * m[0][3] * m[1][2],
41             // 0,3:
42            -m[0][1] * m[1][2] * m[2][3] +
43             m[0][1] * m[1][3] * m[2][2] +
44             m[1][1] * m[0][2] * m[2][3] -
45             m[1][1] * m[0][3] * m[2][2] -
46             m[2][1] * m[0][2] * m[1][3] +
47             m[2][1] * m[0][3] * m[1][2],
48         ],
49         // row 1
50         [
51             // 1,0:
52            -m[1][0] * m[2][2] * m[3][3] +
53             m[1][0] * m[2][3] * m[3][2] +
54             m[2][0] * m[1][2] * m[3][3] -
55             m[2][0] * m[1][3] * m[3][2] -
56             m[3][0] * m[1][2] * m[2][3] +
57             m[3][0] * m[1][3] * m[2][2],
58             // 1,1:
59             m[0][0] * m[2][2] * m[3][3] -
60             m[0][0] * m[2][3] * m[3][2] -
61             m[2][0] * m[0][2] * m[3][3] +
62             m[2][0] * m[0][3] * m[3][2] +
63             m[3][0] * m[0][2] * m[2][3] -
64             m[3][0] * m[0][3] * m[2][2],
65             // 1,2:
66            -m[0][0] * m[1][2] * m[3][3] +
67             m[0][0] * m[1][3] * m[3][2] +
68             m[1][0] * m[0][2] * m[3][3] -
69             m[1][0] * m[0][3] * m[3][2] -
70             m[3][0] * m[0][2] * m[1][3] +
71             m[3][0] * m[0][3] * m[1][2],
72             // 1,3:
73             m[0][0] * m[1][2] * m[2][3] -
74             m[0][0] * m[1][3] * m[2][2] -
75             m[1][0] * m[0][2] * m[2][3] +
76             m[1][0] * m[0][3] * m[2][2] +
77             m[2][0] * m[0][2] * m[1][3] -
78             m[2][0] * m[0][3] * m[1][2],
79         ],
80         // row 2
81         [
82             // 2,0:
83             m[1][0] * m[2][1] * m[3][3] -
84             m[1][0] * m[2][3] * m[3][1] -
85             m[2][0] * m[1][1] * m[3][3] +
86             m[2][0] * m[1][3] * m[3][1] +
87             m[3][0] * m[1][1] * m[2][3] -
88             m[3][0] * m[1][3] * m[2][1],
89             // 2,1:
90            -m[0][0] * m[2][1] * m[3][3] +
91             m[0][0] * m[2][3] * m[3][1] +
92             m[2][0] * m[0][1] * m[3][3] -
93             m[2][0] * m[0][3] * m[3][1] -
94             m[3][0] * m[0][1] * m[2][3] +
95             m[3][0] * m[0][3] * m[2][1],
96             // 2,2:
97             m[0][0] * m[1][1] * m[3][3] -
98             m[0][0] * m[1][3] * m[3][1] -
99             m[1][0] * m[0][1] * m[3][3] +
100             m[1][0] * m[0][3] * m[3][1] +
101             m[3][0] * m[0][1] * m[1][3] -
102             m[3][0] * m[0][3] * m[1][1],
103             // 2,3:
104            -m[0][0] * m[1][1] * m[2][3] +
105             m[0][0] * m[1][3] * m[2][1] +
106             m[1][0] * m[0][1] * m[2][3] -
107             m[1][0] * m[0][3] * m[2][1] -
108             m[2][0] * m[0][1] * m[1][3] +
109             m[2][0] * m[0][3] * m[1][1],
110         ],
111         // row 3
112         [
113             // 3,0:
114            -m[1][0] * m[2][1] * m[3][2] +
115             m[1][0] * m[2][2] * m[3][1] +
116             m[2][0] * m[1][1] * m[3][2] -
117             m[2][0] * m[1][2] * m[3][1] -
118             m[3][0] * m[1][1] * m[2][2] +
119             m[3][0] * m[1][2] * m[2][1],
120             // 3,1:
121             m[0][0] * m[2][1] * m[3][2] -
122             m[0][0] * m[2][2] * m[3][1] -
123             m[2][0] * m[0][1] * m[3][2] +
124             m[2][0] * m[0][2] * m[3][1] +
125             m[3][0] * m[0][1] * m[2][2] -
126             m[3][0] * m[0][2] * m[2][1],
127             // 3,2:
128            -m[0][0] * m[1][1] * m[3][2] +
129             m[0][0] * m[1][2] * m[3][1] +
130             m[1][0] * m[0][1] * m[3][2] -
131             m[1][0] * m[0][2] * m[3][1] -
132             m[3][0] * m[0][1] * m[1][2] +
133             m[3][0] * m[0][2] * m[1][1],
134             // 3,3:
135             m[0][0] * m[1][1] * m[2][2] -
136             m[0][0] * m[1][2] * m[2][1] -
137             m[1][0] * m[0][1] * m[2][2] +
138             m[1][0] * m[0][2] * m[2][1] +
139             m[2][0] * m[0][1] * m[1][2] -
140             m[2][0] * m[0][2] * m[1][1],
141         ],
142     ];
143 
144     let det = m[0][0] * inv[0][0] + m[0][1] * inv[1][0] + m[0][2] * inv[2][0] + m[0][3] * inv[3][0];
145     if det == 0. {
146         return None;
147     }
148 
149     let det_inv = 1. / det;
150 
151     for row in &mut inv {
152         for elem in row.iter_mut() {
153             *elem *= det_inv;
154         }
155     }
156 
157     Some(Matrix4x4(inv))
158 }
159 
simd_inv4x4(m: Matrix4x4) -> Option<Matrix4x4>160 pub fn simd_inv4x4(m: Matrix4x4) -> Option<Matrix4x4> {
161     let m = m.0;
162     let m_0 = f32x4::from_array(m[0]);
163     let m_1 = f32x4::from_array(m[1]);
164     let m_2 = f32x4::from_array(m[2]);
165     let m_3 = f32x4::from_array(m[3]);
166 
167     const SHUFFLE01: [Which; 4] = [First(0), First(1), Second(0), Second(1)];
168     const SHUFFLE02: [Which; 4] = [First(0), First(2), Second(0), Second(2)];
169     const SHUFFLE13: [Which; 4] = [First(1), First(3), Second(1), Second(3)];
170     const SHUFFLE23: [Which; 4] = [First(2), First(3), Second(2), Second(3)];
171 
172     let tmp = simd_swizzle!(m_0, m_1, SHUFFLE01);
173     let row1 = simd_swizzle!(m_2, m_3, SHUFFLE01);
174 
175     let row0 = simd_swizzle!(tmp, row1, SHUFFLE02);
176     let row1 = simd_swizzle!(row1, tmp, SHUFFLE13);
177 
178     let tmp = simd_swizzle!(m_0, m_1, SHUFFLE23);
179     let row3 = simd_swizzle!(m_2, m_3, SHUFFLE23);
180     let row2 = simd_swizzle!(tmp, row3, SHUFFLE02);
181     let row3 = simd_swizzle!(row3, tmp, SHUFFLE13);
182 
183     let tmp = (row2 * row3).reverse().rotate_lanes_right::<2>();
184     let minor0 = row1 * tmp;
185     let minor1 = row0 * tmp;
186     let tmp = tmp.rotate_lanes_right::<2>();
187     let minor0 = (row1 * tmp) - minor0;
188     let minor1 = (row0 * tmp) - minor1;
189     let minor1 = minor1.rotate_lanes_right::<2>();
190 
191     let tmp = (row1 * row2).reverse().rotate_lanes_right::<2>();
192     let minor0 = (row3 * tmp) + minor0;
193     let minor3 = row0 * tmp;
194     let tmp = tmp.rotate_lanes_right::<2>();
195 
196     let minor0 = minor0 - row3 * tmp;
197     let minor3 = row0 * tmp - minor3;
198     let minor3 = minor3.rotate_lanes_right::<2>();
199 
200     let tmp = (row3 * row1.rotate_lanes_right::<2>())
201         .reverse()
202         .rotate_lanes_right::<2>();
203     let row2 = row2.rotate_lanes_right::<2>();
204     let minor0 = row2 * tmp + minor0;
205     let minor2 = row0 * tmp;
206     let tmp = tmp.rotate_lanes_right::<2>();
207     let minor0 = minor0 - row2 * tmp;
208     let minor2 = row0 * tmp - minor2;
209     let minor2 = minor2.rotate_lanes_right::<2>();
210 
211     let tmp = (row0 * row1).reverse().rotate_lanes_right::<2>();
212     let minor2 = minor2 + row3 * tmp;
213     let minor3 = row2 * tmp - minor3;
214     let tmp = tmp.rotate_lanes_right::<2>();
215     let minor2 = row3 * tmp - minor2;
216     let minor3 = minor3 - row2 * tmp;
217 
218     let tmp = (row0 * row3).reverse().rotate_lanes_right::<2>();
219     let minor1 = minor1 - row2 * tmp;
220     let minor2 = row1 * tmp + minor2;
221     let tmp = tmp.rotate_lanes_right::<2>();
222     let minor1 = row2 * tmp + minor1;
223     let minor2 = minor2 - row1 * tmp;
224 
225     let tmp = (row0 * row2).reverse().rotate_lanes_right::<2>();
226     let minor1 = row3 * tmp + minor1;
227     let minor3 = minor3 - row1 * tmp;
228     let tmp = tmp.rotate_lanes_right::<2>();
229     let minor1 = minor1 - row3 * tmp;
230     let minor3 = row1 * tmp + minor3;
231 
232     let det = row0 * minor0;
233     let det = det.rotate_lanes_right::<2>() + det;
234     let det = det.reverse().rotate_lanes_right::<2>() + det;
235 
236     if det.horizontal_sum() == 0. {
237         return None;
238     }
239     // calculate the reciprocal
240     let tmp = f32x4::splat(1.0) / det;
241     let det = tmp + tmp - det * tmp * tmp;
242 
243     let res0 = minor0 * det;
244     let res1 = minor1 * det;
245     let res2 = minor2 * det;
246     let res3 = minor3 * det;
247 
248     let mut m = m;
249 
250     m[0] = res0.to_array();
251     m[1] = res1.to_array();
252     m[2] = res2.to_array();
253     m[3] = res3.to_array();
254 
255     Some(Matrix4x4(m))
256 }
257 
258 #[cfg(test)]
259 #[rustfmt::skip]
260 mod tests {
261     use super::*;
262 
263     #[test]
test()264     fn test() {
265     let tests: &[(Matrix4x4, Option<Matrix4x4>)] = &[
266         // Identity:
267         (Matrix4x4([
268             [1., 0., 0., 0.],
269             [0., 1., 0., 0.],
270             [0., 0., 1., 0.],
271             [0., 0., 0., 1.],
272          ]),
273          Some(Matrix4x4([
274              [1., 0., 0., 0.],
275              [0., 1., 0., 0.],
276              [0., 0., 1., 0.],
277              [0., 0., 0., 1.],
278          ]))
279         ),
280         // None:
281         (Matrix4x4([
282             [1., 2., 3., 4.],
283             [12., 11., 10., 9.],
284             [5., 6., 7., 8.],
285             [16., 15., 14., 13.],
286         ]),
287          None
288         ),
289         // Other:
290         (Matrix4x4([
291             [1., 1., 1., 0.],
292             [0., 3., 1., 2.],
293             [2., 3., 1., 0.],
294             [1., 0., 2., 1.],
295         ]),
296          Some(Matrix4x4([
297              [-3., -0.5,   1.5,  1.0],
298              [ 1., 0.25, -0.25, -0.5],
299              [ 3., 0.25, -1.25, -0.5],
300              [-3., 0.0,    1.0,  1.0],
301          ]))
302         ),
303 
304 
305     ];
306 
307         for &(input, output) in tests {
308             assert_eq!(scalar_inv4x4(input), output);
309             assert_eq!(simd_inv4x4(input), output);
310         }
311     }
312 }
313 
main()314 fn main() {
315     // Empty main to make cargo happy
316 }
317