1 //! 4x4 matrix inverse
2 // Code ported from the `packed_simd` crate
3 // Run this code with `cargo test --example matrix_inversion`
4 #![feature(array_chunks, portable_simd)]
5 use core_simd::simd::*;
6 use Which::*;
7
/// A 4x4 matrix of `f32` values, held as a nested array of four `[f32; 4]` rows.
///
/// The wrapped array is private; the functions in this file unwrap it with `.0`.
// Gotta define our own 4x4 matrix since Rust doesn't ship multidim arrays yet :^)
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Matrix4x4([[f32; 4]; 4]);
11
12 #[allow(clippy::too_many_lines)]
scalar_inv4x4(m: Matrix4x4) -> Option<Matrix4x4>13 pub fn scalar_inv4x4(m: Matrix4x4) -> Option<Matrix4x4> {
14 let m = m.0;
15
16 #[rustfmt::skip]
17 let mut inv = [
18 // row 0:
19 [
20 // 0,0:
21 m[1][1] * m[2][2] * m[3][3] -
22 m[1][1] * m[2][3] * m[3][2] -
23 m[2][1] * m[1][2] * m[3][3] +
24 m[2][1] * m[1][3] * m[3][2] +
25 m[3][1] * m[1][2] * m[2][3] -
26 m[3][1] * m[1][3] * m[2][2],
27 // 0,1:
28 -m[0][1] * m[2][2] * m[3][3] +
29 m[0][1] * m[2][3] * m[3][2] +
30 m[2][1] * m[0][2] * m[3][3] -
31 m[2][1] * m[0][3] * m[3][2] -
32 m[3][1] * m[0][2] * m[2][3] +
33 m[3][1] * m[0][3] * m[2][2],
34 // 0,2:
35 m[0][1] * m[1][2] * m[3][3] -
36 m[0][1] * m[1][3] * m[3][2] -
37 m[1][1] * m[0][2] * m[3][3] +
38 m[1][1] * m[0][3] * m[3][2] +
39 m[3][1] * m[0][2] * m[1][3] -
40 m[3][1] * m[0][3] * m[1][2],
41 // 0,3:
42 -m[0][1] * m[1][2] * m[2][3] +
43 m[0][1] * m[1][3] * m[2][2] +
44 m[1][1] * m[0][2] * m[2][3] -
45 m[1][1] * m[0][3] * m[2][2] -
46 m[2][1] * m[0][2] * m[1][3] +
47 m[2][1] * m[0][3] * m[1][2],
48 ],
49 // row 1
50 [
51 // 1,0:
52 -m[1][0] * m[2][2] * m[3][3] +
53 m[1][0] * m[2][3] * m[3][2] +
54 m[2][0] * m[1][2] * m[3][3] -
55 m[2][0] * m[1][3] * m[3][2] -
56 m[3][0] * m[1][2] * m[2][3] +
57 m[3][0] * m[1][3] * m[2][2],
58 // 1,1:
59 m[0][0] * m[2][2] * m[3][3] -
60 m[0][0] * m[2][3] * m[3][2] -
61 m[2][0] * m[0][2] * m[3][3] +
62 m[2][0] * m[0][3] * m[3][2] +
63 m[3][0] * m[0][2] * m[2][3] -
64 m[3][0] * m[0][3] * m[2][2],
65 // 1,2:
66 -m[0][0] * m[1][2] * m[3][3] +
67 m[0][0] * m[1][3] * m[3][2] +
68 m[1][0] * m[0][2] * m[3][3] -
69 m[1][0] * m[0][3] * m[3][2] -
70 m[3][0] * m[0][2] * m[1][3] +
71 m[3][0] * m[0][3] * m[1][2],
72 // 1,3:
73 m[0][0] * m[1][2] * m[2][3] -
74 m[0][0] * m[1][3] * m[2][2] -
75 m[1][0] * m[0][2] * m[2][3] +
76 m[1][0] * m[0][3] * m[2][2] +
77 m[2][0] * m[0][2] * m[1][3] -
78 m[2][0] * m[0][3] * m[1][2],
79 ],
80 // row 2
81 [
82 // 2,0:
83 m[1][0] * m[2][1] * m[3][3] -
84 m[1][0] * m[2][3] * m[3][1] -
85 m[2][0] * m[1][1] * m[3][3] +
86 m[2][0] * m[1][3] * m[3][1] +
87 m[3][0] * m[1][1] * m[2][3] -
88 m[3][0] * m[1][3] * m[2][1],
89 // 2,1:
90 -m[0][0] * m[2][1] * m[3][3] +
91 m[0][0] * m[2][3] * m[3][1] +
92 m[2][0] * m[0][1] * m[3][3] -
93 m[2][0] * m[0][3] * m[3][1] -
94 m[3][0] * m[0][1] * m[2][3] +
95 m[3][0] * m[0][3] * m[2][1],
96 // 2,2:
97 m[0][0] * m[1][1] * m[3][3] -
98 m[0][0] * m[1][3] * m[3][1] -
99 m[1][0] * m[0][1] * m[3][3] +
100 m[1][0] * m[0][3] * m[3][1] +
101 m[3][0] * m[0][1] * m[1][3] -
102 m[3][0] * m[0][3] * m[1][1],
103 // 2,3:
104 -m[0][0] * m[1][1] * m[2][3] +
105 m[0][0] * m[1][3] * m[2][1] +
106 m[1][0] * m[0][1] * m[2][3] -
107 m[1][0] * m[0][3] * m[2][1] -
108 m[2][0] * m[0][1] * m[1][3] +
109 m[2][0] * m[0][3] * m[1][1],
110 ],
111 // row 3
112 [
113 // 3,0:
114 -m[1][0] * m[2][1] * m[3][2] +
115 m[1][0] * m[2][2] * m[3][1] +
116 m[2][0] * m[1][1] * m[3][2] -
117 m[2][0] * m[1][2] * m[3][1] -
118 m[3][0] * m[1][1] * m[2][2] +
119 m[3][0] * m[1][2] * m[2][1],
120 // 3,1:
121 m[0][0] * m[2][1] * m[3][2] -
122 m[0][0] * m[2][2] * m[3][1] -
123 m[2][0] * m[0][1] * m[3][2] +
124 m[2][0] * m[0][2] * m[3][1] +
125 m[3][0] * m[0][1] * m[2][2] -
126 m[3][0] * m[0][2] * m[2][1],
127 // 3,2:
128 -m[0][0] * m[1][1] * m[3][2] +
129 m[0][0] * m[1][2] * m[3][1] +
130 m[1][0] * m[0][1] * m[3][2] -
131 m[1][0] * m[0][2] * m[3][1] -
132 m[3][0] * m[0][1] * m[1][2] +
133 m[3][0] * m[0][2] * m[1][1],
134 // 3,3:
135 m[0][0] * m[1][1] * m[2][2] -
136 m[0][0] * m[1][2] * m[2][1] -
137 m[1][0] * m[0][1] * m[2][2] +
138 m[1][0] * m[0][2] * m[2][1] +
139 m[2][0] * m[0][1] * m[1][2] -
140 m[2][0] * m[0][2] * m[1][1],
141 ],
142 ];
143
144 let det = m[0][0] * inv[0][0] + m[0][1] * inv[1][0] + m[0][2] * inv[2][0] + m[0][3] * inv[3][0];
145 if det == 0. {
146 return None;
147 }
148
149 let det_inv = 1. / det;
150
151 for row in &mut inv {
152 for elem in row.iter_mut() {
153 *elem *= det_inv;
154 }
155 }
156
157 Some(Matrix4x4(inv))
158 }
159
/// Inverts a 4x4 matrix using four-lane `f32x4` vector arithmetic.
///
/// SIMD counterpart of [`scalar_inv4x4`]; the tests in this file check both functions
/// against the same expected outputs.
///
/// Returns `None` when the lane sum of the determinant vector compares equal to zero,
/// otherwise `Some` with the four result vectors stored as the rows of a `Matrix4x4`.
pub fn simd_inv4x4(m: Matrix4x4) -> Option<Matrix4x4> {
    let m = m.0;
    // Load each stored row into its own vector.
    let m_0 = f32x4::from_array(m[0]);
    let m_1 = f32x4::from_array(m[1]);
    let m_2 = f32x4::from_array(m[2]);
    let m_3 = f32x4::from_array(m[3]);

    // Lane-selection patterns for the two-input form of `simd_swizzle!`: `First(i)` and
    // `Second(i)` pick lane `i` of the first and second vector argument respectively.
    const SHUFFLE01: [Which; 4] = [First(0), First(1), Second(0), Second(1)];
    const SHUFFLE02: [Which; 4] = [First(0), First(2), Second(0), Second(2)];
    const SHUFFLE13: [Which; 4] = [First(1), First(3), Second(1), Second(3)];
    const SHUFFLE23: [Which; 4] = [First(2), First(3), Second(2), Second(3)];

    // Rearrange the input so each working vector holds one column of `m`
    // (writing mRC for m[R][C]):
    //   row0 = [m00, m10, m20, m30]
    //   row1 = [m21, m31, m01, m11]   (column 1 with its halves swapped)
    //   row2 = [m02, m12, m22, m32]
    //   row3 = [m23, m33, m03, m13]   (column 3 with its halves swapped)
    let tmp = simd_swizzle!(m_0, m_1, SHUFFLE01);
    let row1 = simd_swizzle!(m_2, m_3, SHUFFLE01);

    let row0 = simd_swizzle!(tmp, row1, SHUFFLE02);
    let row1 = simd_swizzle!(row1, tmp, SHUFFLE13);

    let tmp = simd_swizzle!(m_0, m_1, SHUFFLE23);
    let row3 = simd_swizzle!(m_2, m_3, SHUFFLE23);
    let row2 = simd_swizzle!(tmp, row3, SHUFFLE02);
    let row3 = simd_swizzle!(row3, tmp, SHUFFLE13);

    // Two lane permutations recur below:
    //   `.reverse().rotate_lanes_right::<2>()` maps [a, b, c, d] to [b, a, d, c]
    //   `.rotate_lanes_right::<2>()` alone     maps [a, b, c, d] to [c, d, a, b]
    // Each section forms one pairwise product vector, permutes it, and accumulates the
    // result into the `minorN` vectors; the statement order (and shadowing) matters.

    // Products of row2 * row3: start minor0 and minor1.
    let tmp = (row2 * row3).reverse().rotate_lanes_right::<2>();
    let minor0 = row1 * tmp;
    let minor1 = row0 * tmp;
    let tmp = tmp.rotate_lanes_right::<2>();
    let minor0 = (row1 * tmp) - minor0;
    let minor1 = (row0 * tmp) - minor1;
    let minor1 = minor1.rotate_lanes_right::<2>();

    // Products of row1 * row2: extend minor0 and start minor3.
    let tmp = (row1 * row2).reverse().rotate_lanes_right::<2>();
    let minor0 = (row3 * tmp) + minor0;
    let minor3 = row0 * tmp;
    let tmp = tmp.rotate_lanes_right::<2>();

    let minor0 = minor0 - row3 * tmp;
    let minor3 = row0 * tmp - minor3;
    let minor3 = minor3.rotate_lanes_right::<2>();

    // Products of row3 with the half-swapped row1: finish minor0 and start minor2.
    let tmp = (row3 * row1.rotate_lanes_right::<2>())
        .reverse()
        .rotate_lanes_right::<2>();
    // From here on `row2` refers to the half-swapped copy.
    let row2 = row2.rotate_lanes_right::<2>();
    let minor0 = row2 * tmp + minor0;
    let minor2 = row0 * tmp;
    let tmp = tmp.rotate_lanes_right::<2>();
    let minor0 = minor0 - row2 * tmp;
    let minor2 = row0 * tmp - minor2;
    let minor2 = minor2.rotate_lanes_right::<2>();

    // Products of row0 * row1: extend minor2 and minor3.
    let tmp = (row0 * row1).reverse().rotate_lanes_right::<2>();
    let minor2 = minor2 + row3 * tmp;
    let minor3 = row2 * tmp - minor3;
    let tmp = tmp.rotate_lanes_right::<2>();
    let minor2 = row3 * tmp - minor2;
    let minor3 = minor3 - row2 * tmp;

    // Products of row0 * row3: extend minor1 and minor2.
    let tmp = (row0 * row3).reverse().rotate_lanes_right::<2>();
    let minor1 = minor1 - row2 * tmp;
    let minor2 = row1 * tmp + minor2;
    let tmp = tmp.rotate_lanes_right::<2>();
    let minor1 = row2 * tmp + minor1;
    let minor2 = minor2 - row1 * tmp;

    // Products of row0 * row2: finish minor1 and minor3.
    let tmp = (row0 * row2).reverse().rotate_lanes_right::<2>();
    let minor1 = row3 * tmp + minor1;
    let minor3 = minor3 - row1 * tmp;
    let tmp = tmp.rotate_lanes_right::<2>();
    let minor1 = minor1 - row3 * tmp;
    let minor3 = row1 * tmp + minor3;

    // Determinant: lane-wise row0 * minor0, then two fold-and-add steps (halves swapped,
    // then neighbours swapped) so that every lane ends up holding the sum of all four
    // products.
    let det = row0 * minor0;
    let det = det.rotate_lanes_right::<2>() + det;
    let det = det.reverse().rotate_lanes_right::<2>() + det;

    // All four lanes carry the same value here, so the lane sum is four times that value;
    // the check rejects the input when that sum compares equal to zero.
    if det.horizontal_sum() == 0. {
        return None;
    }
    // calculate the reciprocal
    // `tmp` is the lane-wise quotient 1 / det. The following line applies one refinement
    // step of the form x' = 2x - d * x^2 (the Newton-Raphson update for a reciprocal).
    // NOTE(review): presumably inherited from a version that started from an approximate
    // reciprocal; confirm whether it is still wanted after an exact division.
    let tmp = f32x4::splat(1.0) / det;
    let det = tmp + tmp - det * tmp * tmp;

    // Scale every minor vector by the reciprocal determinant.
    let res0 = minor0 * det;
    let res1 = minor1 * det;
    let res2 = minor2 * det;
    let res3 = minor3 * det;

    // Reuse a mutable copy of the input array as storage for the four result rows.
    let mut m = m;

    m[0] = res0.to_array();
    m[1] = res1.to_array();
    m[2] = res2.to_array();
    m[3] = res3.to_array();

    Some(Matrix4x4(m))
}
257
#[cfg(test)]
#[rustfmt::skip]
mod tests {
    use super::*;

    /// Feeds the same table of inputs to both inversion routines and compares each result
    /// with the expected output (scalar first, then SIMD).
    #[test]
    fn test() {
        // Identity: its own inverse.
        let identity = Matrix4x4([
            [1., 0., 0., 0.],
            [0., 1., 0., 0.],
            [0., 0., 1., 0.],
            [0., 0., 0., 1.],
        ]);

        // None: both routines are expected to reject this input.
        let not_invertible = Matrix4x4([
            [1.,  2.,  3.,  4.],
            [12., 11., 10., 9.],
            [5.,  6.,  7.,  8.],
            [16., 15., 14., 13.],
        ]);

        // Other: a general matrix together with its known inverse.
        let general = Matrix4x4([
            [1., 1., 1., 0.],
            [0., 3., 1., 2.],
            [2., 3., 1., 0.],
            [1., 0., 2., 1.],
        ]);
        let general_inverse = Matrix4x4([
            [-3., -0.5,   1.5,   1.0],
            [ 1.,  0.25, -0.25, -0.5],
            [ 3.,  0.25, -1.25, -0.5],
            [-3.,  0.0,   1.0,   1.0],
        ]);

        let cases: [(Matrix4x4, Option<Matrix4x4>); 3] = [
            (identity,       Some(identity)),
            (not_invertible, None),
            (general,        Some(general_inverse)),
        ];

        for (input, expected) in cases.iter().copied() {
            assert_eq!(scalar_inv4x4(input), expected);
            assert_eq!(simd_inv4x4(input), expected);
        }
    }
}
313
/// Intentionally empty entry point.
///
/// This example is exercised through its tests (`cargo test --example matrix_inversion`);
/// `main` exists only so the example target builds.
fn main() {
    // Empty main to make cargo happy
}
317