// This is adapted from `fallback.rs` from rust-memchr. It's modified to
// answer the 'inverse' of memchr's query, i.e., to find the first byte not in
// the provided set. This is simple for the 1-byte case.
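//
// A quick sketch of the intended behavior (the tests below cover it
// exhaustively):
//
//     assert_eq!(inv_memchr(b' ', b"   xy  "), Some(3));  // first non-space
//     assert_eq!(inv_memrchr(b' ', b"   xy  "), Some(4)); // last non-space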

use core::cmp;
use core::usize;

#[cfg(target_pointer_width = "32")]
const USIZE_BYTES: usize = 4;

#[cfg(target_pointer_width = "64")]
const USIZE_BYTES: usize = 8;

// The number of bytes to process in one iteration of inv_memchr/inv_memrchr.
const LOOP_SIZE: usize = 2 * USIZE_BYTES;

/// Repeat the given byte into a word size number. That is, every 8 bits
/// is equivalent to the given byte. For example, if `b` is `\x4E` or
/// `01001110` in binary, then the returned value on a 32-bit system would be:
/// `01001110_01001110_01001110_01001110`.
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    (b as usize) * (usize::MAX / 255)
}
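
// As a worked example of `repeat_byte` (not from the original source): on a
// 64-bit target,
//
//     usize::MAX / 255  == 0x0101_0101_0101_0101
//     repeat_byte(0x4E) == 0x4E * 0x0101_0101_0101_0101
//                       == 0x4E4E_4E4E_4E4E_4E4E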

/// Returns the index of the first byte in `haystack` that is not `n1`.
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();
    let end_ptr = haystack[haystack.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        if haystack.len() < USIZE_BYTES {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

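        // Peek at the first word. If any byte in it differs from `n1`, the
        // first such byte lies in this word, so defer to the byte-at-a-time
        // search immediately.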
        let chunk = read_unaligned_usize(ptr);
        if (chunk ^ vn1) != 0 {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

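        // The first word is all `n1`. Advance `ptr` to the next word-aligned
        // address; the bytes skipped over were already covered by the
        // unaligned read above.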
        ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

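            // Read two words per iteration. XORing a word with `vn1` is zero
            // iff every byte equals `n1`, so a nonzero result means a byte
            // other than `n1` lies in one of these two words.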
            let a = *(ptr as *const usize);
            let b = *(ptr.add(USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.add(LOOP_SIZE);
        }
        forward_search(start_ptr, end_ptr, ptr, confirm)
    }
}

/// Returns the index of the last byte in `haystack` that is not `n1`.
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();
    let end_ptr = haystack[haystack.len()..].as_ptr();
    let mut ptr = end_ptr;

    unsafe {
        if haystack.len() < USIZE_BYTES {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

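        // Peek at the last word. If any byte in it differs from `n1`, the
        // last such byte lies in this word, so defer to the byte-at-a-time
        // search immediately.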
        let chunk = read_unaligned_usize(ptr.sub(USIZE_BYTES));
        if (chunk ^ vn1) != 0 {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

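        // The last word is all `n1`. Round `ptr` down to a word-aligned
        // address; the trailing bytes were already covered by the unaligned
        // read above.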
        ptr = (end_ptr as usize & !align) as *const u8;
        debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

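            // Read the two words that end at `ptr`. A nonzero XOR against
            // `vn1` means one of them contains a byte other than `n1`.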
            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.sub(loop_size);
        }
        reverse_search(start_ptr, end_ptr, ptr, confirm)
    }
}

#[inline(always)]
unsafe fn forward_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        if confirm(*ptr) {
            return Some(sub(ptr, start_ptr));
        }
        ptr = ptr.offset(1);
    }
    None
}

#[inline(always)]
unsafe fn reverse_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr > start_ptr {
        ptr = ptr.offset(-1);
        if confirm(*ptr) {
            return Some(sub(ptr, start_ptr));
        }
    }
    None
}

/// Read a `usize`-sized word from `ptr`, which need not be aligned.
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    (ptr as *const usize).read_unaligned()
}

/// Subtract `b` from `a` and return the difference. `a` should be greater than
/// or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    (a as usize) - (b as usize)
}

/// Safe wrapper around `forward_search`.
#[inline]
pub(crate) fn forward_search_bytes<F: Fn(u8) -> bool>(
    s: &[u8],
    confirm: F,
) -> Option<usize> {
    unsafe {
        let start = s.as_ptr();
        let end = start.add(s.len());
        forward_search(start, end, start, confirm)
    }
}

/// Safe wrapper around `reverse_search`.
#[inline]
pub(crate) fn reverse_search_bytes<F: Fn(u8) -> bool>(
    s: &[u8],
    confirm: F,
) -> Option<usize> {
    unsafe {
        let start = s.as_ptr();
        let end = start.add(s.len());
        reverse_search(start, end, end, confirm)
    }
}
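
// For example (hypothetical usage, not from the original source):
//
//     assert_eq!(forward_search_bytes(b"a b c", |b| b == b' '), Some(1));
//     assert_eq!(reverse_search_bytes(b"a b c", |b| b == b' '), Some(3));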

#[cfg(test)]
mod tests {
    use super::{inv_memchr, inv_memrchr};
    // search string, search byte, inv_memchr result, inv_memrchr result.
    // these are expanded into a much larger set of tests in build_tests.
    const TESTS: &[(&[u8], u8, usize, usize)] = &[
        (b"z", b'a', 0, 0),
        (b"zz", b'a', 0, 1),
        (b"aza", b'a', 1, 1),
        (b"zaz", b'a', 0, 2),
        (b"zza", b'a', 0, 1),
        (b"zaa", b'a', 0, 0),
        (b"zzz", b'a', 0, 2),
    ];

    type TestCase = (Vec<u8>, u8, Option<(usize, usize)>);

    fn build_tests() -> Vec<TestCase> {
        let mut result = vec![];
        for &(search, byte, fwd_pos, rev_pos) in TESTS {
            result.push((search.to_vec(), byte, Some((fwd_pos, rev_pos))));
            for i in 1..515 {
                // add a bunch of copies of the search byte to the end.
                let mut suffixed: Vec<u8> = search.into();
                suffixed.extend(std::iter::repeat(byte).take(i));
                result.push((suffixed, byte, Some((fwd_pos, rev_pos))));

                // add a bunch of copies of the search byte to the start.
                let mut prefixed: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                prefixed.extend(search);
                result.push((
                    prefixed,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));

                // add a bunch of copies of the search byte to both ends.
                let mut surrounded: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                surrounded.extend(search);
                surrounded.extend(std::iter::repeat(byte).take(i));
                result.push((
                    surrounded,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));
            }
        }

        // build non-matching tests for several sizes
        for i in 0..515 {
            result.push((
                std::iter::repeat(b'\0').take(i).collect(),
                b'\0',
                None,
            ));
        }

        result
    }

    #[test]
    fn test_inv_memchr() {
        use crate::{ByteSlice, B};
        for (search, byte, matching) in build_tests() {
            assert_eq!(
                inv_memchr(byte, &search),
                matching.map(|m| m.0),
                "inv_memchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            assert_eq!(
                inv_memrchr(byte, &search),
                matching.map(|m| m.1),
                "inv_memrchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            // Test a rather large number of offsets for potential alignment issues.
            for offset in 1..130 {
                if offset >= search.len() {
                    break;
                }
                // If this would cause us to shift the results off the end, skip
                // it so that we don't have to recompute them.
                if let Some((f, r)) = matching {
                    if offset > f || offset > r {
                        break;
                    }
                }
                let realigned = &search[offset..];

                let forward_pos = matching.map(|m| m.0 - offset);
                let reverse_pos = matching.map(|m| m.1 - offset);

                assert_eq!(
                    inv_memchr(byte, &realigned),
                    forward_pos,
                    "inv_memchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
                assert_eq!(
                    inv_memrchr(byte, &realigned),
                    reverse_pos,
                    "inv_memrchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
            }
        }
    }
}