1 //! SIMD byte scanning logic.
2 //!
//! This module provides functions for walking through byte slices with SIMD,
//! calling a provided callback function on each special byte and its index.
5 //! The byteset is defined in `compute_lookup`.
6 //!
7 //! The idea is to load in a chunk of 16 bytes and perform a lookup into a set of
8 //! bytes on all the bytes in this chunk simultaneously. We produce a 16 bit bitmask
9 //! from this and call the callback on every index corresponding to a 1 in this mask
10 //! before moving on to the next chunk. This allows us to move quickly when there
11 //! are no or few matches.
12 //!
13 //! The table lookup is inspired by this [great overview]. However, since all of the
14 //! bytes we're interested in are ASCII, we don't quite need the full generality of
15 //! the universal algorithm and are hence able to skip a few instructions.
16 //!
17 //! [great overview]: http://0x80.pl/articles/simd-byte-lookup.html
18 
19 use crate::parse::{LookupTable, LoopInstruction, Options};
20 use core::arch::x86_64::*;
21 
/// Width in bytes of one SSE vector (`__m128i`): the number of input bytes
/// examined per SIMD step.
pub(crate) const VECTOR_SIZE: usize = core::mem::size_of::<__m128i>();
23 
24 /// Generates a lookup table containing the bitmaps for our
25 /// special marker bytes. This is effectively a 128 element 2d bitvector,
26 /// that can be indexed by a four bit row index (the lower nibble)
27 /// and a three bit column index (upper nibble).
compute_lookup(options: &Options) -> [u8; 16]28 pub(crate) fn compute_lookup(options: &Options) -> [u8; 16] {
29     let mut lookup = [0u8; 16];
30     let standard_bytes = [
31         b'\n', b'\r', b'*', b'_', b'&', b'\\', b'[', b']', b'<', b'!', b'`',
32     ];
33 
34     for &byte in &standard_bytes {
35         add_lookup_byte(&mut lookup, byte);
36     }
37     if options.contains(Options::ENABLE_TABLES) {
38         add_lookup_byte(&mut lookup, b'|');
39     }
40     if options.contains(Options::ENABLE_STRIKETHROUGH) {
41         add_lookup_byte(&mut lookup, b'~');
42     }
43     if options.contains(Options::ENABLE_SMART_PUNCTUATION) {
44         for &byte in &[b'.', b'-', b'"', b'\''] {
45             add_lookup_byte(&mut lookup, byte);
46         }
47     }
48 
49     lookup
50 }
51 
/// Marks `byte` as special in `lookup`.
///
/// The low nibble of `byte` selects the table entry (row), and the high nibble
/// selects the bit within that entry (column). Only ASCII bytes fit this
/// scheme: a high nibble of 8 or more would shift past the width of a `u8`.
/// Unchecked, that panics in debug builds and silently sets the wrong bit in
/// release builds.
///
/// # Panics
///
/// Panics if `byte` is not ASCII.
fn add_lookup_byte(lookup: &mut [u8; 16], byte: u8) {
    assert!(
        byte.is_ascii(),
        "SIMD lookup table only supports ASCII bytes, got {:#04x}",
        byte
    );
    lookup[(byte & 0x0f) as usize] |= 1 << (byte >> 4);
}
55 
/// Computes a bit mask for the given byteslice starting from the given index,
/// where the 16 least significant bits indicate (by value of 1) whether or not
/// there is a special character at that byte position. The least significant bit
/// corresponds to `bytes[ix]` and bit 15 corresponds to `bytes[ix + 15]`;
/// all higher bits of the returned `i32` are zero.
///
/// # Safety
///
/// The caller must ensure that `bytes.len() >= ix + VECTOR_SIZE`, because the
/// 16 bytes starting at `ix` are read with an unchecked unaligned load. The CPU
/// must also support SSSE3.
#[target_feature(enable = "ssse3")]
#[inline]
unsafe fn compute_mask(lut: &[u8; 16], bytes: &[u8], ix: usize) -> i32 {
    debug_assert!(bytes.len() >= ix + VECTOR_SIZE);

    // Entry `n` of the table holds the column bitmap for low nibble `n`.
    let bitmap = _mm_loadu_si128(lut.as_ptr() as *const __m128i);
    // Small lookup table to compute single bit bitshifts for 16 bytes at once.
    // Index `h` in 0..8 maps to `1 << h`. Indices 8..16 (the high nibbles of
    // non-ASCII bytes) map to -1 (all ones).
    let bitmask_lookup =
        _mm_setr_epi8(1, 2, 4, 8, 16, 32, 64, -128, -1, -1, -1, -1, -1, -1, -1, -1);

    // Load 16 input bytes from memory (unaligned load).
    let raw_ptr = bytes.as_ptr().add(ix) as *const __m128i;
    let input = _mm_loadu_si128(raw_ptr);
    // Compute the bitmap using the bottom nibble as an index into the lookup
    // table. `_mm_shuffle_epi8` zeroes every output lane whose control byte
    // has its most significant bit set, so non-ASCII bytes get a bitmap of 0
    // (they do NOT map to lookup[0]).
    let bitset = _mm_shuffle_epi8(bitmap, input);
    // Compute the high nibbles of the input. There is no 8-bit shift, so use a
    // 16-bit right shift by four, then mask with 0x0f to drop the bits that
    // leak in from the neighbouring byte.
    let higher_nibbles = _mm_and_si128(_mm_srli_epi16(input, 4), _mm_set1_epi8(0x0f));
    // Build the per-byte column mask `1 << high_nibble` via a table lookup.
    // Bytes with their most significant bit set are mapped to -1 (all ones).
    let bitmask = _mm_shuffle_epi8(bitmask_lookup, higher_nibbles);
    // Test the column bit of the bitmap by AND'ing the bitmap and the mask.
    let tmp = _mm_and_si128(bitset, bitmask);
    // A byte is special iff its bit survived the AND. SSE has no byte-wise
    // "not equal" compare, but for ASCII bytes `bitmask` has exactly one bit
    // set, so `tmp == bitmask` is equivalent to `tmp != 0`. Non-ASCII bytes
    // never match: their `tmp` is 0 while their `bitmask` is -1.
    let result = _mm_cmpeq_epi8(tmp, bitmask);

    // Gather the most significant bit of each byte lane into bits 0..16.
    _mm_movemask_epi8(result)
}
98 
99 /// Calls callback on byte indices and their value.
100 /// Breaks when callback returns LoopInstruction::BreakAtWith(ix, val). And skips the
101 /// number of bytes in callback return value otherwise.
102 /// Returns the final index and a possible break value.
iterate_special_bytes<F, T>( lut: &LookupTable, bytes: &[u8], ix: usize, callback: F, ) -> (usize, Option<T>) where F: FnMut(usize, u8) -> LoopInstruction<Option<T>>,103 pub(crate) fn iterate_special_bytes<F, T>(
104     lut: &LookupTable,
105     bytes: &[u8],
106     ix: usize,
107     callback: F,
108 ) -> (usize, Option<T>)
109 where
110     F: FnMut(usize, u8) -> LoopInstruction<Option<T>>,
111 {
112     if is_x86_feature_detected!("ssse3") && bytes.len() >= VECTOR_SIZE {
113         unsafe { simd_iterate_special_bytes(&lut.simd, bytes, ix, callback) }
114     } else {
115         crate::parse::scalar_iterate_special_bytes(&lut.scalar, bytes, ix, callback)
116     }
117 }
118 
119 /// Calls the callback function for every 1 in the given bitmask with
120 /// the index `offset + ix`, where `ix` is the position of the 1 in the mask.
121 /// Returns `Ok(ix)` to continue from index `ix`, `Err((end_ix, opt_val)` to break with
122 /// final index `end_ix` and optional value `opt_val`.
process_mask<F, T>( mut mask: i32, bytes: &[u8], mut offset: usize, callback: &mut F, ) -> Result<usize, (usize, Option<T>)> where F: FnMut(usize, u8) -> LoopInstruction<Option<T>>,123 unsafe fn process_mask<F, T>(
124     mut mask: i32,
125     bytes: &[u8],
126     mut offset: usize,
127     callback: &mut F,
128 ) -> Result<usize, (usize, Option<T>)>
129 where
130     F: FnMut(usize, u8) -> LoopInstruction<Option<T>>,
131 {
132     while mask != 0 {
133         let mask_ix = mask.trailing_zeros() as usize;
134         offset += mask_ix;
135         match callback(offset, *bytes.get_unchecked(offset)) {
136             LoopInstruction::ContinueAndSkip(skip) => {
137                 offset += skip + 1;
138                 mask >>= skip + 1 + mask_ix;
139             }
140             LoopInstruction::BreakAtWith(ix, val) => return Err((ix, val)),
141         }
142     }
143     Ok(offset)
144 }
145 
146 #[target_feature(enable = "ssse3")]
147 /// Important: only call this function when `bytes.len() >= 16`. Doing
148 /// so otherwise may exhibit undefined behaviour.
simd_iterate_special_bytes<F, T>( lut: &[u8; 16], bytes: &[u8], mut ix: usize, mut callback: F, ) -> (usize, Option<T>) where F: FnMut(usize, u8) -> LoopInstruction<Option<T>>,149 unsafe fn simd_iterate_special_bytes<F, T>(
150     lut: &[u8; 16],
151     bytes: &[u8],
152     mut ix: usize,
153     mut callback: F,
154 ) -> (usize, Option<T>)
155 where
156     F: FnMut(usize, u8) -> LoopInstruction<Option<T>>,
157 {
158     debug_assert!(bytes.len() >= VECTOR_SIZE);
159     let upperbound = bytes.len() - VECTOR_SIZE;
160 
161     while ix < upperbound {
162         let mask = compute_mask(lut, bytes, ix);
163         let block_start = ix;
164         ix = match process_mask(mask, bytes, ix, &mut callback) {
165             Ok(ix) => std::cmp::max(ix, VECTOR_SIZE + block_start),
166             Err((end_ix, val)) => return (end_ix, val),
167         };
168     }
169 
170     if bytes.len() > ix {
171         // shift off the bytes at start we have already scanned
172         let mask = compute_mask(lut, bytes, upperbound) >> ix - upperbound;
173         if let Err((end_ix, val)) = process_mask(mask, bytes, ix, &mut callback) {
174             return (end_ix, val);
175         }
176     }
177 
178     (bytes.len(), None)
179 }
180 
#[cfg(test)]
mod simd_test {
    use super::{iterate_special_bytes, LoopInstruction};
    use crate::Options;

    /// Runs `iterate_special_bytes` over `bytes` from index 0, with tables,
    /// footnotes, strikethrough and tasklists enabled. Skips `skip` bytes after
    /// every hit and asserts that the visited indices equal `expected`.
    fn check_expected_indices(bytes: &[u8], expected: &[usize], skip: usize) {
        let opts = Options::ENABLE_TABLES
            | Options::ENABLE_FOOTNOTES
            | Options::ENABLE_STRIKETHROUGH
            | Options::ENABLE_TASKLISTS;
        let lut = crate::parse::create_lut(&opts);

        let mut visited = Vec::new();
        iterate_special_bytes::<_, i32>(&lut, bytes, 0, |ix, _byte_ty| {
            visited.push(ix);
            LoopInstruction::ContinueAndSkip(skip)
        });

        assert_eq!(visited.as_slice(), expected);
    }

    #[test]
    fn simple_no_match() {
        check_expected_indices(b"abcdef0123456789", &[], 0);
    }

    #[test]
    fn simple_match() {
        check_expected_indices(b"*bcd&f0123456789", &[0, 4], 0);
    }

    #[test]
    fn single_open_fish() {
        check_expected_indices(b"<", &[0], 0);
    }

    #[test]
    fn long_match() {
        check_expected_indices(b"0123456789abcde~*bcd&f0", &[15, 16, 20], 0);
    }

    #[test]
    fn border_skip() {
        check_expected_indices(b"0123456789abcde~~~~d&f0", &[15, 20], 3);
    }

    #[test]
    fn exhaustive_search() {
        let specials = b"\n\r*_~|&\\[]<!`";

        for &special in specials.iter() {
            for filler in (0..=u8::MAX).filter(|b| !specials.contains(b)) {
                // Two matches embedded in an 18-byte buffer of a single
                // non-special filler byte.
                let mut buf = [filler; 18];
                buf[3] = special;
                buf[6] = special;

                check_expected_indices(&buf, &[3, 6], 0);
            }
        }
    }
}
249