#![allow(dead_code, unused_imports)]

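// Conversion functions are generated by the `convert_fn!` macro below. Each generated
// function dispatches between the hardware F16C intrinsic and the software fallback:
//   - with `use-intrinsics` and `std` on x86/x86_64, but without `f16c` as a
//     compile-time target feature, the choice is made at runtime via
//     `is_x86_feature_detected!`;
//   - with `use-intrinsics` and `f16c` enabled as a compile-time target feature, the
//     intrinsic is called unconditionally;
//   - otherwise the software fallback is used.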
macro_rules! convert_fn {
    (fn $name:ident($var:ident : $vartype:ty) -> $restype:ty {
            if feature("f16c") { $f16c:expr }
            else { $fallback:expr }}) => {
        #[inline]
        pub(crate) fn $name($var: $vartype) -> $restype {
            // Use CPU feature detection if using std
            #[cfg(all(
                feature = "use-intrinsics",
                feature = "std",
                any(target_arch = "x86", target_arch = "x86_64"),
                not(target_feature = "f16c")
            ))]
            {
                if is_x86_feature_detected!("f16c") {
                    $f16c
                } else {
                    $fallback
                }
            }
            // Use intrinsics directly when f16c is a compile-time target feature or when using no_std
            #[cfg(all(
                feature = "use-intrinsics",
                any(target_arch = "x86", target_arch = "x86_64"),
                target_feature = "f16c"
            ))]
            {
                $f16c
            }
            // Fallback to software
            #[cfg(any(
                not(feature = "use-intrinsics"),
                not(any(target_arch = "x86", target_arch = "x86_64")),
                all(not(feature = "std"), not(target_feature = "f16c"))
            ))]
            {
                $fallback
            }
        }
    };
}
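
// As a rough sketch (simplified, showing only the runtime-detection branch), the
// `f32_to_f16` invocation below expands to roughly:
//
//     #[inline]
//     pub(crate) fn f32_to_f16(f: f32) -> u16 {
//         if is_x86_feature_detected!("f16c") {
//             unsafe { x86::f32_to_f16_x86_f16c(f) }
//         } else {
//             f32_to_f16_fallback(f)
//         }
//     }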

convert_fn! {
    fn f32_to_f16(f: f32) -> u16 {
        if feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f) }
        } else {
            f32_to_f16_fallback(f)
        }
    }
}

convert_fn! {
    fn f64_to_f16(f: f64) -> u16 {
        if feature("f16c") {
            unsafe { x86::f32_to_f16_x86_f16c(f as f32) }
        } else {
            f64_to_f16_fallback(f)
        }
    }
}

convert_fn! {
    fn f16_to_f32(i: u16) -> f32 {
        if feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) }
        } else {
            f16_to_f32_fallback(i)
        }
    }
}

convert_fn! {
    fn f16_to_f64(i: u16) -> f64 {
        if feature("f16c") {
            unsafe { x86::f16_to_f32_x86_f16c(i) as f64 }
        } else {
            f16_to_f64_fallback(i)
        }
    }
}

// TODO: While the SIMD versions are faster, further improvement is possible by doing runtime
// feature detection once at the beginning of the slice conversion methods rather than per chunk

convert_fn! {
    fn f32x4_to_f16x4(f: &[f32]) -> [u16; 4] {
        if feature("f16c") {
            unsafe { x86::f32x4_to_f16x4_x86_f16c(f) }
        } else {
            f32x4_to_f16x4_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x4_to_f32x4(i: &[u16]) -> [f32; 4] {
        if feature("f16c") {
            unsafe { x86::f16x4_to_f32x4_x86_f16c(i) }
        } else {
            f16x4_to_f32x4_fallback(i)
        }
    }
}

convert_fn! {
    fn f64x4_to_f16x4(f: &[f64]) -> [u16; 4] {
        if feature("f16c") {
            unsafe { x86::f64x4_to_f16x4_x86_f16c(f) }
        } else {
            f64x4_to_f16x4_fallback(f)
        }
    }
}

convert_fn! {
    fn f16x4_to_f64x4(i: &[u16]) -> [f64; 4] {
        if feature("f16c") {
            unsafe { x86::f16x4_to_f64x4_x86_f16c(i) }
        } else {
            f16x4_to_f64x4_fallback(i)
        }
    }
}

/////////////// Fallbacks ////////////////

// In the below functions, round to nearest, with ties to even.
// Let us call the most significant bit that will be shifted out the round_bit.
//
// Round up if either
//  a) Removed part > tie.
//     (mantissa & round_bit) != 0 && (mantissa & (round_bit - 1)) != 0
//  b) Removed part == tie, and retained part is odd.
//     (mantissa & round_bit) != 0 && (mantissa & (2 * round_bit)) != 0
// (If removed part == tie and retained part is even, do not round up.)
// These two conditions can be combined into one:
//     (mantissa & round_bit) != 0 && (mantissa & ((round_bit - 1) | (2 * round_bit))) != 0
// which can be simplified into
//     (mantissa & round_bit) != 0 && (mantissa & (3 * round_bit - 1)) != 0
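//
// Worked example (illustrative, added for clarity): truncating a 23-bit f32 mantissa to
// 10 bits shifts out 13 bits, so round_bit = 0x1000 and 3 * round_bit - 1 = 0x2FFF, i.e.
// the removed bits below the round bit (0x0FFF) plus the lowest retained bit (0x2000),
// but not the round bit itself. Rounding up therefore happens when the round bit is set
// and either some lower removed bit is set (removed part > tie) or the lowest retained
// bit is set (tie, retained part odd).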

fn f32_to_f16_fallback(value: f32) -> u16 {
    // Convert to raw bytes
    let x = value.to_bits();

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7F80_0000u32;
    let man = x & 0x007F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7F80_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits)
        let nan_bit = if man == 0 { 0 } else { 0x0200u32 };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 23) as i32) - 127;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 14 - half_exp > 24 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0080_0000u32;
        let mut half_man = man >> (14 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (13 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 13;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_1000u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}

fn f64_to_f16_fallback(value: f64) -> u16 {
    // Convert to raw bytes, truncating the last 32-bits of mantissa; that precision will always
    // be lost on half-precision.
    let val = value.to_bits();
    let x = (val >> 32) as u32;

    // Extract IEEE754 components
    let sign = x & 0x8000_0000u32;
    let exp = x & 0x7FF0_0000u32;
    let man = x & 0x000F_FFFFu32;

    // Check for all exponent bits being set, which is Infinity or NaN
    if exp == 0x7FF0_0000u32 {
        // Set mantissa MSB for NaN (and also keep shifted mantissa bits).
        // We also have to check the last 32 bits.
        let nan_bit = if man == 0 && (val as u32 == 0) {
            0
        } else {
            0x0200u32
        };
        return ((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 10)) as u16;
    }

    // The number is normalized, start assembling half precision version
    let half_sign = sign >> 16;
    // Unbias the exponent, then bias for half precision
    let unbiased_exp = ((exp >> 20) as i64) - 1023;
    let half_exp = unbiased_exp + 15;

    // Check for exponent overflow, return +infinity
    if half_exp >= 0x1F {
        return (half_sign | 0x7C00u32) as u16;
    }

    // Check for underflow
    if half_exp <= 0 {
        // Check mantissa for what we can do
        if 10 - half_exp > 21 {
            // No rounding possibility, so this is a full underflow, return signed zero
            return half_sign as u16;
        }
        // Don't forget about hidden leading mantissa bit when assembling mantissa
        let man = man | 0x0010_0000u32;
        let mut half_man = man >> (11 - half_exp);
        // Check for rounding (see comment above functions)
        let round_bit = 1 << (10 - half_exp);
        if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
            half_man += 1;
        }
        // No exponent for subnormals
        return (half_sign | half_man) as u16;
    }

    // Rebias the exponent
    let half_exp = (half_exp as u32) << 10;
    let half_man = man >> 10;
    // Check for rounding (see comment above functions)
    let round_bit = 0x0000_0200u32;
    if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 {
        // Round it
        ((half_sign | half_exp | half_man) + 1) as u16
    } else {
        (half_sign | half_exp | half_man) as u16
    }
}

fn f16_to_f32_fallback(i: u16) -> f32 {
    // Check for signed zero
    if i & 0x7FFFu16 == 0 {
        return f32::from_bits((i as u32) << 16);
    }

    let half_sign = (i & 0x8000u16) as u32;
    let half_exp = (i & 0x7C00u16) as u32;
    let half_man = (i & 0x03FFu16) as u32;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u32 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return f32::from_bits((half_sign << 16) | 0x7F80_0000u32);
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return f32::from_bits((half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13));
        }
    }

    // Calculate single-precision components with adjusted exponent
    let sign = half_sign << 16;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i32) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = (half_man as u16).leading_zeros() - 6;

        // Rebias and adjust exponent
        let exp = (127 - 15 - e) << 23;
        let man = (half_man << (14 + e)) & 0x7F_FF_FFu32;
        return f32::from_bits(sign | exp | man);
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 127) as u32) << 23;
    let man = (half_man & 0x03FFu32) << 13;
    f32::from_bits(sign | exp | man)
}

fn f16_to_f64_fallback(i: u16) -> f64 {
    // Check for signed zero
    if i & 0x7FFFu16 == 0 {
        return f64::from_bits((i as u64) << 48);
    }

    let half_sign = (i & 0x8000u16) as u64;
    let half_exp = (i & 0x7C00u16) as u64;
    let half_man = (i & 0x03FFu16) as u64;

    // Check for an infinity or NaN when all exponent bits set
    if half_exp == 0x7C00u64 {
        // Check for signed infinity if mantissa is zero
        if half_man == 0 {
            return f64::from_bits((half_sign << 48) | 0x7FF0_0000_0000_0000u64);
        } else {
            // NaN, keep current mantissa but also set most significant mantissa bit
            return f64::from_bits((half_sign << 48) | 0x7FF8_0000_0000_0000u64 | (half_man << 42));
        }
    }

    // Calculate double-precision components with adjusted exponent
    let sign = half_sign << 48;
    // Unbias exponent
    let unbiased_exp = ((half_exp as i64) >> 10) - 15;

    // Check for subnormals, which will be normalized by adjusting exponent
    if half_exp == 0 {
        // Calculate how much to adjust the exponent by
        let e = (half_man as u16).leading_zeros() - 6;

        // Rebias and adjust exponent
        let exp = ((1023 - 15 - e) as u64) << 52;
        let man = (half_man << (43 + e)) & 0xF_FFFF_FFFF_FFFFu64;
        return f64::from_bits(sign | exp | man);
    }

    // Rebias exponent for a normalized normal
    let exp = ((unbiased_exp + 1023) as u64) << 52;
    let man = (half_man & 0x03FFu64) << 42;
    f64::from_bits(sign | exp | man)
}

#[inline]
fn f16x4_to_f32x4_fallback(v: &[u16]) -> [f32; 4] {
    debug_assert!(v.len() >= 4);

    [
        f16_to_f32_fallback(v[0]),
        f16_to_f32_fallback(v[1]),
        f16_to_f32_fallback(v[2]),
        f16_to_f32_fallback(v[3]),
    ]
}

#[inline]
fn f32x4_to_f16x4_fallback(v: &[f32]) -> [u16; 4] {
    debug_assert!(v.len() >= 4);

    [
        f32_to_f16_fallback(v[0]),
        f32_to_f16_fallback(v[1]),
        f32_to_f16_fallback(v[2]),
        f32_to_f16_fallback(v[3]),
    ]
}

#[inline]
fn f16x4_to_f64x4_fallback(v: &[u16]) -> [f64; 4] {
    debug_assert!(v.len() >= 4);

    [
        f16_to_f64_fallback(v[0]),
        f16_to_f64_fallback(v[1]),
        f16_to_f64_fallback(v[2]),
        f16_to_f64_fallback(v[3]),
    ]
}

#[inline]
fn f64x4_to_f16x4_fallback(v: &[f64]) -> [u16; 4] {
    debug_assert!(v.len() >= 4);

    [
        f64_to_f16_fallback(v[0]),
        f64_to_f16_fallback(v[1]),
        f64_to_f16_fallback(v[2]),
        f64_to_f16_fallback(v[3]),
    ]
}

/////////////// x86/x86_64 f16c ////////////////
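// Each wrapper below loads the scalar or 4-wide input into an SSE register via MaybeUninit,
// runs the F16C conversion intrinsic, and reads the relevant lanes back out through a
// pointer cast. The functions are `unsafe` because `target_feature(enable = "f16c")`
// requires the caller to guarantee that the f16c feature is actually available.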
#[cfg(all(
    feature = "use-intrinsics",
    any(target_arch = "x86", target_arch = "x86_64")
))]
mod x86 {
    use core::{mem::MaybeUninit, ptr};

    #[cfg(target_arch = "x86")]
    use core::arch::x86::{__m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT};
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::{
        __m128, __m128i, _mm_cvtph_ps, _mm_cvtps_ph, _MM_FROUND_TO_NEAREST_INT,
    };

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16_to_f32_x86_f16c(i: u16) -> f32 {
        let mut vec = MaybeUninit::<__m128i>::zeroed();
        vec.as_mut_ptr().cast::<u16>().write(i);
        let retval = _mm_cvtph_ps(vec.assume_init());
        *(&retval as *const __m128).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32_to_f16_x86_f16c(f: f32) -> u16 {
        let mut vec = MaybeUninit::<__m128>::zeroed();
        vec.as_mut_ptr().cast::<f32>().write(f);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16]) -> [f32; 4] {
        debug_assert!(v.len() >= 4);

        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtph_ps(vec.assume_init());
        *(&retval as *const __m128).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32]) -> [u16; 4] {
        debug_assert!(v.len() >= 4);

        let mut vec = MaybeUninit::<__m128>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16]) -> [f64; 4] {
        debug_assert!(v.len() >= 4);

        let mut vec = MaybeUninit::<__m128i>::zeroed();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtph_ps(vec.assume_init());
        let array = *(&retval as *const __m128).cast::<[f32; 4]>();
        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        [
            array[0] as f64,
            array[1] as f64,
            array[2] as f64,
            array[3] as f64,
        ]
    }

    #[target_feature(enable = "f16c")]
    #[inline]
    pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64]) -> [u16; 4] {
        debug_assert!(v.len() >= 4);

        // Let compiler vectorize this regular cast for now.
        // TODO: investigate auto-detecting sse2/avx convert features
        let v = [v[0] as f32, v[1] as f32, v[2] as f32, v[3] as f32];

        let mut vec = MaybeUninit::<__m128>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
        *(&retval as *const __m128i).cast()
    }
}
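// A few sanity checks for the software fallbacks, added as an illustrative sketch rather
// than an exhaustive test suite; they rely only on well-known binary16 encodings.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fallback_converts_simple_values() {
        // 1.0 is 0x3C00 and -2.0 is 0xC000 in binary16
        assert_eq!(f32_to_f16_fallback(1.0), 0x3C00);
        assert_eq!(f32_to_f16_fallback(-2.0), 0xC000);
        assert_eq!(f64_to_f16_fallback(1.0), 0x3C00);
        assert_eq!(f16_to_f32_fallback(0x3C00), 1.0);
        assert_eq!(f16_to_f64_fallback(0xC000), -2.0);
    }

    #[test]
    fn fallback_handles_special_values() {
        // 0x7C00 is +infinity, 0x7BFF is the largest finite half (65504)
        assert!(f16_to_f32_fallback(0x7C00).is_infinite());
        assert_eq!(f32_to_f16_fallback(65504.0), 0x7BFF);
        // Values beyond the binary16 range overflow to infinity
        assert_eq!(f32_to_f16_fallback(1.0e6), 0x7C00);
        // The smallest positive subnormal, 0x0001, is 2^-24 (f32 bits 0x3380_0000)
        assert_eq!(f16_to_f32_fallback(0x0001).to_bits(), 0x3380_0000);
    }
}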