1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 //! Defines miscellaneous array kernels.
19 
20 use crate::error::Result;
21 use crate::record_batch::RecordBatch;
22 use crate::{array::*, util::bit_chunk_iterator::BitChunkIterator};
23 use std::{iter::Enumerate, sync::Arc};
24 
25 /// Function that can filter arbitrary arrays
26 pub type Filter<'a> = Box<Fn(&ArrayData) -> ArrayData + 'a>;
27 
/// Internal state of [SlicesIterator].
///
/// The iterator starts in `Chunks`, switches to `Bits` whenever it meets a chunk whose
/// mask is neither all zeros nor all ones, moves to `Remainder` once the chunks are
/// exhausted, and ends in `Finish`.
#[derive(Debug, PartialEq)]
enum State {
    /// Iterating over the individual bits of one chunk mask (`u64`), one slot per step.
    /// The payload is the mask being scanned.
    Bits(u64),
    /// Iterating over whole chunks, 64 slots per step.
    Chunks,
    /// Iterating over the remaining bits that do not fill a whole chunk, one slot per step.
    Remainder,
    /// Nothing more to iterate.
    Finish,
}
40 
/// An iterator of `(usize, usize)` pairs, each representing a half-open interval
/// `[start, end)` of slots of a [BooleanArray] that are all true. Each interval corresponds
/// to a contiguous region of memory to be "taken" from an array to be filtered.
#[derive(Debug)]
pub(crate) struct SlicesIterator<'a> {
    /// Enumerated iterator over the bit chunks of the filter, yielding `(chunk index, mask)`.
    iter: Enumerate<BitChunkIterator<'a>>,
    /// Which phase of the scan the iterator is currently in.
    state: State,
    /// Number of set bits in the filter, counted once in `new`; callers use it to size
    /// the output.
    filter_count: usize,
    /// Mask holding the trailing bits that do not fill a whole chunk.
    remainder_mask: u64,
    /// Number of valid bits in `remainder_mask`.
    remainder_len: usize,
    /// Number of whole chunks; also used as the chunk index of the remainder.
    chunk_len: usize,
    /// Length of the run of set bits accumulated so far (0 when not inside a run).
    len: usize,
    /// Slot at which the run being accumulated started.
    start: usize,
    /// Whether the scan is currently inside a run of set bits.
    on_region: bool,
    /// Index of the chunk being processed.
    current_chunk: usize,
    /// Bit position inside the chunk (or remainder) being processed.
    current_bit: usize,
}
58 
59 impl<'a> SlicesIterator<'a> {
new(filter: &'a BooleanArray) -> Self60     pub(crate) fn new(filter: &'a BooleanArray) -> Self {
61         let values = &filter.data_ref().buffers()[0];
62 
63         // this operation is performed before iteration
64         // because it is fast and allows reserving all the needed memory
65         let filter_count = values.count_set_bits_offset(filter.offset(), filter.len());
66 
67         let chunks = values.bit_chunks(filter.offset(), filter.len());
68 
69         Self {
70             iter: chunks.iter().enumerate(),
71             state: State::Chunks,
72             filter_count,
73             remainder_len: chunks.remainder_len(),
74             chunk_len: chunks.chunk_len(),
75             remainder_mask: chunks.remainder_bits(),
76             len: 0,
77             start: 0,
78             on_region: false,
79             current_chunk: 0,
80             current_bit: 0,
81         }
82     }
83 
84     #[inline]
current_start(&self) -> usize85     fn current_start(&self) -> usize {
86         self.current_chunk * 64 + self.current_bit
87     }
88 
89     #[inline]
iterate_bits(&mut self, mask: u64, max: usize) -> Option<(usize, usize)>90     fn iterate_bits(&mut self, mask: u64, max: usize) -> Option<(usize, usize)> {
91         while self.current_bit < max {
92             if (mask & (1 << self.current_bit)) != 0 {
93                 if !self.on_region {
94                     self.start = self.current_start();
95                     self.on_region = true;
96                 }
97                 self.len += 1;
98             } else if self.on_region {
99                 let result = (self.start, self.start + self.len);
100                 self.len = 0;
101                 self.on_region = false;
102                 self.current_bit += 1;
103                 return Some(result);
104             }
105             self.current_bit += 1;
106         }
107         self.current_bit = 0;
108         None
109     }
110 
111     /// iterates over chunks.
112     #[inline]
iterate_chunks(&mut self) -> Option<(usize, usize)>113     fn iterate_chunks(&mut self) -> Option<(usize, usize)> {
114         while let Some((i, mask)) = self.iter.next() {
115             self.current_chunk = i;
116             if mask == 0 {
117                 if self.on_region {
118                     let result = (self.start, self.start + self.len);
119                     self.len = 0;
120                     self.on_region = false;
121                     return Some(result);
122                 }
123             } else if mask == 18446744073709551615u64 {
124                 // = !0u64
125                 if !self.on_region {
126                     self.start = self.current_start();
127                     self.on_region = true;
128                 }
129                 self.len += 64;
130             } else {
131                 // there is a chunk that has a non-trivial mask => iterate over bits.
132                 self.state = State::Bits(mask);
133                 return None;
134             }
135         }
136         // no more chunks => start iterating over the remainder
137         self.current_chunk = self.chunk_len;
138         self.state = State::Remainder;
139         None
140     }
141 }
142 
143 impl<'a> Iterator for SlicesIterator<'a> {
144     type Item = (usize, usize);
145 
next(&mut self) -> Option<Self::Item>146     fn next(&mut self) -> Option<Self::Item> {
147         match self.state {
148             State::Chunks => {
149                 match self.iterate_chunks() {
150                     None => {
151                         // iterating over chunks does not yield any new slice => continue to the next
152                         self.current_bit = 0;
153                         self.next()
154                     }
155                     other => other,
156                 }
157             }
158             State::Bits(mask) => {
159                 match self.iterate_bits(mask, 64) {
160                     None => {
161                         // iterating over bits does not yield any new slice => change back
162                         // to chunks and continue to the next
163                         self.state = State::Chunks;
164                         self.next()
165                     }
166                     other => other,
167                 }
168             }
169             State::Remainder => {
170                 match self.iterate_bits(self.remainder_mask, self.remainder_len) {
171                     None => {
172                         self.state = State::Finish;
173                         if self.on_region {
174                             Some((self.start, self.start + self.len))
175                         } else {
176                             None
177                         }
178                     }
179                     other => other,
180                 }
181             }
182             State::Finish => None,
183         }
184     }
185 }
186 
187 /// Returns a prepared function optimized to filter multiple arrays.
188 /// Creating this function requires time, but using it is faster than [filter] when the
189 /// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`).
190 /// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
191 /// Therefore, it is considered undefined behavior to pass `filter` with null values.
build_filter(filter: &BooleanArray) -> Result<Filter>192 pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
193     let iter = SlicesIterator::new(filter);
194     let filter_count = iter.filter_count;
195     let chunks = iter.collect::<Vec<_>>();
196 
197     Ok(Box::new(move |array: &ArrayData| {
198         let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
199         chunks
200             .iter()
201             .for_each(|(start, end)| mutable.extend(0, *start, *end));
202         mutable.freeze()
203     }))
204 }
205 
206 /// Filters an [Array], returning elements matching the filter (i.e. where the values are true).
207 /// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
208 /// Therefore, it is considered undefined behavior to pass `filter` with null values.
209 /// # Example
210 /// ```rust
211 /// # use arrow::array::{Int32Array, BooleanArray};
212 /// # use arrow::error::Result;
213 /// # use arrow::compute::kernels::filter::filter;
214 /// # fn main() -> Result<()> {
215 /// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
216 /// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
217 /// let c = filter(&array, &filter_array)?;
218 /// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
219 /// assert_eq!(c, &Int32Array::from(vec![5, 8]));
220 /// # Ok(())
221 /// # }
222 /// ```
filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef>223 pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
224     let iter = SlicesIterator::new(filter);
225 
226     let mut mutable =
227         MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
228     iter.for_each(|(start, end)| mutable.extend(0, start, end));
229     let data = mutable.freeze();
230     Ok(make_array(Arc::new(data)))
231 }
232 
233 /// Returns a new [RecordBatch] with arrays containing only values matching the filter.
234 /// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
235 /// Therefore, it is considered undefined behavior to pass `filter` with null values.
filter_record_batch( record_batch: &RecordBatch, filter: &BooleanArray, ) -> Result<RecordBatch>236 pub fn filter_record_batch(
237     record_batch: &RecordBatch,
238     filter: &BooleanArray,
239 ) -> Result<RecordBatch> {
240     let filter = build_filter(filter)?;
241     let filtered_arrays = record_batch
242         .columns()
243         .iter()
244         .map(|a| make_array(Arc::new(filter(&a.data()))))
245         .collect();
246     RecordBatch::try_new(record_batch.schema(), filtered_arrays)
247 }
248 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        buffer::Buffer,
        datatypes::{DataType, Field},
    };

    // Generates a test that filters a 4-element array of `$array_type` holding the
    // values 1..=4 with the mask [true, false, true, false] and expects [1, 3].
    macro_rules! def_temporal_test {
        ($test:ident, $array_type: ident, $data: expr) => {
            #[test]
            fn $test() {
                let a = $data;
                let b = BooleanArray::from(vec![true, false, true, false]);
                let c = filter(&a, &b).unwrap();
                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
                assert_eq!(2, d.len());
                assert_eq!(1, d.value(0));
                assert_eq!(3, d.value(1));
            }
        };
    }

    // one test per temporal array type
    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from_vec(vec![1, 2, 3, 4], None)
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from_vec(vec![1, 2, 3, 4], None)
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from_vec(vec![1, 2, 3, 4], None)
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from_vec(vec![1, 2, 3, 4], None)
    );

    // Filtering a sliced data array must honour the slice offset.
    #[test]
    fn test_filter_array_slice() {
        let a_slice = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
        let a = a_slice.as_ref();
        let b = BooleanArray::from(vec![true, false, false, true]);
        // filtering with sliced filter array is not currently supported
        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
        // let b = b_slice.as_any().downcast_ref().unwrap();
        let c = filter(a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(6, d.value(0));
        assert_eq!(9, d.value(1));
    }

    #[test]
    fn test_filter_array_low_density() {
        // this test exercises the all 0's branch of the filter algorithm
        let mut data_values = (1..=65).collect::<Vec<i32>>();
        let mut filter_values =
            (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
        // set up two more values after the batch
        data_values.extend_from_slice(&[66, 67]);
        filter_values.extend_from_slice(&[false, true]);
        let a = Int32Array::from(data_values);
        let b = BooleanArray::from(filter_values);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(65, d.value(0));
        assert_eq!(67, d.value(1));
    }

    #[test]
    fn test_filter_array_high_density() {
        // this test exercises the all 1's branch of the filter algorithm
        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
        let mut filter_values = (1..=65)
            .map(|i| !matches!(i % 65, 0))
            .collect::<Vec<bool>>();
        // set second data value to null
        data_values[1] = None;
        // set up two more values after the batch
        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
        filter_values.extend_from_slice(&[false, true, true, true]);
        let a = Int32Array::from(data_values);
        let b = BooleanArray::from(filter_values);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(67, d.len());
        assert_eq!(3, d.null_count());
        assert_eq!(1, d.value(0));
        assert!(d.is_null(1));
        assert_eq!(64, d.value(63));
        assert!(d.is_null(64));
        assert_eq!(67, d.value(65));
    }

    // Variable-length (string) values are filtered like primitive ones.
    #[test]
    fn test_filter_string_array_simple() {
        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
        let b = BooleanArray::from(vec![true, false, true, false]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert_eq!("world", d.value(1));
    }

    // A selected null slot stays null in the output.
    #[test]
    fn test_filter_primative_array_with_null() {
        let a = Int32Array::from(vec![Some(5), None]);
        let b = BooleanArray::from(vec![false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(1, d.len());
        assert!(d.is_null(0));
    }

    // Validity is carried over for string arrays.
    #[test]
    fn test_filter_string_array_with_null() {
        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert!(!d.is_null(0));
        assert!(d.is_null(1));
    }

    // Validity is carried over for binary arrays.
    #[test]
    fn test_filter_binary_array_with_null() {
        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
        let a = BinaryArray::from(data);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(b"hello", d.value(0));
        assert!(!d.is_null(0));
        assert!(d.is_null(1));
    }

    // Filtering a sliced data array must honour the slice offset for validity too.
    #[test]
    fn test_filter_array_slice_with_null() {
        let a_slice =
            Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
        let a = a_slice.as_ref();
        let b = BooleanArray::from(vec![true, false, false, true]);
        // filtering with sliced filter array is not currently supported
        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
        // let b = b_slice.as_any().downcast_ref().unwrap();
        let c = filter(a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert!(d.is_null(0));
        assert!(!d.is_null(1));
        assert_eq!(9, d.value(1));
    }

    // For dictionary arrays only the keys are filtered.
    #[test]
    fn test_filter_dictionary_array() {
        let values = vec![Some("hello"), None, Some("world"), Some("!")];
        let a: Int8DictionaryArray = values.iter().copied().collect();
        let b = BooleanArray::from(vec![false, true, true, false]);
        let c = filter(&a, &b).unwrap();
        let d = c
            .as_ref()
            .as_any()
            .downcast_ref::<Int8DictionaryArray>()
            .unwrap();
        let value_array = d.values();
        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
        // values are cloned in the filtered dictionary array
        assert_eq!(3, values.len());
        // but keys are filtered
        assert_eq!(2, d.len());
        assert!(d.is_null(0));
        assert_eq!("world", values.value(d.keys().value(1) as usize));
    }

    // The filter may be the output of another kernel (here: `not`).
    #[test]
    fn test_filter_string_array_with_negated_boolean_array() {
        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
        let mut bb = BooleanBuilder::new(2);
        bb.append_value(false).unwrap();
        bb.append_value(true).unwrap();
        bb.append_value(false).unwrap();
        bb.append_value(true).unwrap();
        let b = bb.finish();
        let b = crate::compute::not(&b).unwrap();

        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert_eq!("world", d.value(1));
    }

    // Nested (list) arrays: offsets, child values and validity are all filtered.
    #[test]
    fn test_filter_list_array() {
        let value_data = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7]))
            .build();

        let value_offsets = Buffer::from_slice_ref(&[0i64, 3, 6, 8, 8]);

        let list_data_type =
            DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false)));
        let list_data = ArrayData::builder(list_data_type)
            .len(4)
            .add_buffer(value_offsets)
            .add_child_data(value_data)
            .null_bit_buffer(Buffer::from([0b00000111]))
            .build();

        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
        let a = LargeListArray::from(list_data);
        let b = BooleanArray::from(vec![false, true, false, true]);
        let result = filter(&a, &b).unwrap();

        // expected: [[3, 4, 5], null]
        let value_data = ArrayData::builder(DataType::Int32)
            .len(3)
            .add_buffer(Buffer::from_slice_ref(&[3, 4, 5]))
            .build();

        let value_offsets = Buffer::from_slice_ref(&[0i64, 3, 3]);

        let list_data_type =
            DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false)));
        let expected = ArrayData::builder(list_data_type)
            .len(2)
            .add_buffer(value_offsets)
            .add_child_data(value_data)
            .null_bit_buffer(Buffer::from([0b00000001]))
            .build();

        assert_eq!(&make_array(expected), &result);
    }

    // A single set bit inside one 64-slot filter yields a single one-slot slice.
    #[test]
    fn test_slice_iterator_bits() {
        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
        let filter = BooleanArray::from(filter_values);

        let iter = SlicesIterator::new(&filter);
        let filter_count = iter.filter_count;
        let chunks = iter.collect::<Vec<_>>();

        assert_eq!(chunks, vec![(1, 2)]);
        assert_eq!(filter_count, 1);
    }

    // A single unset bit splits a 64-slot filter into two slices.
    #[test]
    fn test_slice_iterator_bits1() {
        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
        let filter = BooleanArray::from(filter_values);

        let iter = SlicesIterator::new(&filter);
        let filter_count = iter.filter_count;
        let chunks = iter.collect::<Vec<_>>();

        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
        assert_eq!(filter_count, 64 - 1);
    }

    // Runs that cross chunk boundaries and reach into the remainder are reported whole.
    #[test]
    fn test_slice_iterator_chunk_and_bits() {
        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
        let filter = BooleanArray::from(filter_values);

        let iter = SlicesIterator::new(&filter);
        let filter_count = iter.filter_count;
        let chunks = iter.collect::<Vec<_>>();

        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
        assert_eq!(filter_count, 61 + 61 + 5);
    }
}
585