1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
//! Defines the filter kernels, which keep only the slots of an array (or of every column of a record batch) where a boolean filter is true.
19
use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;
use crate::{array::*, util::bit_chunk_iterator::BitChunkIterator};
use std::{iter::Enumerate, sync::Arc};
24
25 /// Function that can filter arbitrary arrays
26 pub type Filter<'a> = Box<Fn(&ArrayData) -> ArrayData + 'a>;
27
/// Internal state of [SlicesIterator].
#[derive(PartialEq, Debug)]
enum State {
    /// Walking the individual bits of one 64-bit chunk mask, one slot per step.
    Bits(u64),
    /// Walking whole 64-bit chunks, 64 slots per step.
    Chunks,
    /// Walking the trailing bits that do not fill a whole chunk, one slot per step.
    Remainder,
    /// Iteration is complete; nothing more will be yielded.
    Finish,
}
40
/// An iterator of `(usize, usize)` pairs, each representing a half-open interval
/// `[start, end)` of slots of a [BooleanArray] whose values are all true. Each interval
/// corresponds to a contiguous range of slots to be "taken" from the array being filtered.
#[derive(Debug)]
pub(crate) struct SlicesIterator<'a> {
    // yields `(chunk_index, mask)` for every full 64-bit chunk of the filter's values
    iter: Enumerate<BitChunkIterator<'a>>,
    // which part of the bitmap is currently being walked
    state: State,
    // number of set bits in the filter's values, computed up front in `new`
    filter_count: usize,
    // bits of the trailing partial chunk
    remainder_mask: u64,
    // number of bits of `remainder_mask` that belong to the filter
    remainder_len: usize,
    // number of full 64-bit chunks; used as the chunk index of the remainder
    chunk_len: usize,
    // length of the region of set bits currently being accumulated
    len: usize,
    // slot at which the currently open region started
    start: usize,
    // whether a region of set bits is currently open
    on_region: bool,
    // index of the chunk currently being walked
    current_chunk: usize,
    // bit position inside the current chunk or the remainder
    current_bit: usize,
}
58
59 impl<'a> SlicesIterator<'a> {
new(filter: &'a BooleanArray) -> Self60 pub(crate) fn new(filter: &'a BooleanArray) -> Self {
61 let values = &filter.data_ref().buffers()[0];
62
63 // this operation is performed before iteration
64 // because it is fast and allows reserving all the needed memory
65 let filter_count = values.count_set_bits_offset(filter.offset(), filter.len());
66
67 let chunks = values.bit_chunks(filter.offset(), filter.len());
68
69 Self {
70 iter: chunks.iter().enumerate(),
71 state: State::Chunks,
72 filter_count,
73 remainder_len: chunks.remainder_len(),
74 chunk_len: chunks.chunk_len(),
75 remainder_mask: chunks.remainder_bits(),
76 len: 0,
77 start: 0,
78 on_region: false,
79 current_chunk: 0,
80 current_bit: 0,
81 }
82 }
83
84 #[inline]
current_start(&self) -> usize85 fn current_start(&self) -> usize {
86 self.current_chunk * 64 + self.current_bit
87 }
88
89 #[inline]
iterate_bits(&mut self, mask: u64, max: usize) -> Option<(usize, usize)>90 fn iterate_bits(&mut self, mask: u64, max: usize) -> Option<(usize, usize)> {
91 while self.current_bit < max {
92 if (mask & (1 << self.current_bit)) != 0 {
93 if !self.on_region {
94 self.start = self.current_start();
95 self.on_region = true;
96 }
97 self.len += 1;
98 } else if self.on_region {
99 let result = (self.start, self.start + self.len);
100 self.len = 0;
101 self.on_region = false;
102 self.current_bit += 1;
103 return Some(result);
104 }
105 self.current_bit += 1;
106 }
107 self.current_bit = 0;
108 None
109 }
110
111 /// iterates over chunks.
112 #[inline]
iterate_chunks(&mut self) -> Option<(usize, usize)>113 fn iterate_chunks(&mut self) -> Option<(usize, usize)> {
114 while let Some((i, mask)) = self.iter.next() {
115 self.current_chunk = i;
116 if mask == 0 {
117 if self.on_region {
118 let result = (self.start, self.start + self.len);
119 self.len = 0;
120 self.on_region = false;
121 return Some(result);
122 }
123 } else if mask == 18446744073709551615u64 {
124 // = !0u64
125 if !self.on_region {
126 self.start = self.current_start();
127 self.on_region = true;
128 }
129 self.len += 64;
130 } else {
131 // there is a chunk that has a non-trivial mask => iterate over bits.
132 self.state = State::Bits(mask);
133 return None;
134 }
135 }
136 // no more chunks => start iterating over the remainder
137 self.current_chunk = self.chunk_len;
138 self.state = State::Remainder;
139 None
140 }
141 }
142
143 impl<'a> Iterator for SlicesIterator<'a> {
144 type Item = (usize, usize);
145
next(&mut self) -> Option<Self::Item>146 fn next(&mut self) -> Option<Self::Item> {
147 match self.state {
148 State::Chunks => {
149 match self.iterate_chunks() {
150 None => {
151 // iterating over chunks does not yield any new slice => continue to the next
152 self.current_bit = 0;
153 self.next()
154 }
155 other => other,
156 }
157 }
158 State::Bits(mask) => {
159 match self.iterate_bits(mask, 64) {
160 None => {
161 // iterating over bits does not yield any new slice => change back
162 // to chunks and continue to the next
163 self.state = State::Chunks;
164 self.next()
165 }
166 other => other,
167 }
168 }
169 State::Remainder => {
170 match self.iterate_bits(self.remainder_mask, self.remainder_len) {
171 None => {
172 self.state = State::Finish;
173 if self.on_region {
174 Some((self.start, self.start + self.len))
175 } else {
176 None
177 }
178 }
179 other => other,
180 }
181 }
182 State::Finish => None,
183 }
184 }
185 }
186
187 /// Returns a prepared function optimized to filter multiple arrays.
188 /// Creating this function requires time, but using it is faster than [filter] when the
189 /// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`).
190 /// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
191 /// Therefore, it is considered undefined behavior to pass `filter` with null values.
build_filter(filter: &BooleanArray) -> Result<Filter>192 pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
193 let iter = SlicesIterator::new(filter);
194 let filter_count = iter.filter_count;
195 let chunks = iter.collect::<Vec<_>>();
196
197 Ok(Box::new(move |array: &ArrayData| {
198 let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
199 chunks
200 .iter()
201 .for_each(|(start, end)| mutable.extend(0, *start, *end));
202 mutable.freeze()
203 }))
204 }
205
206 /// Filters an [Array], returning elements matching the filter (i.e. where the values are true).
207 /// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
208 /// Therefore, it is considered undefined behavior to pass `filter` with null values.
209 /// # Example
210 /// ```rust
211 /// # use arrow::array::{Int32Array, BooleanArray};
212 /// # use arrow::error::Result;
213 /// # use arrow::compute::kernels::filter::filter;
214 /// # fn main() -> Result<()> {
215 /// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
216 /// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
217 /// let c = filter(&array, &filter_array)?;
218 /// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
219 /// assert_eq!(c, &Int32Array::from(vec![5, 8]));
220 /// # Ok(())
221 /// # }
222 /// ```
filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef>223 pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
224 let iter = SlicesIterator::new(filter);
225
226 let mut mutable =
227 MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
228 iter.for_each(|(start, end)| mutable.extend(0, start, end));
229 let data = mutable.freeze();
230 Ok(make_array(Arc::new(data)))
231 }
232
233 /// Returns a new [RecordBatch] with arrays containing only values matching the filter.
234 /// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
235 /// Therefore, it is considered undefined behavior to pass `filter` with null values.
filter_record_batch( record_batch: &RecordBatch, filter: &BooleanArray, ) -> Result<RecordBatch>236 pub fn filter_record_batch(
237 record_batch: &RecordBatch,
238 filter: &BooleanArray,
239 ) -> Result<RecordBatch> {
240 let filter = build_filter(filter)?;
241 let filtered_arrays = record_batch
242 .columns()
243 .iter()
244 .map(|a| make_array(Arc::new(filter(&a.data()))))
245 .collect();
246 RecordBatch::try_new(record_batch.schema(), filtered_arrays)
247 }
248
// Unit tests for `filter` and `SlicesIterator`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        buffer::Buffer,
        datatypes::{DataType, Field},
    };

    // Generates a test that filters `$data` with [true, false, true, false] and expects
    // the two values 1 and 3, i.e. `$data` is assumed to look like [1, _, 3, _].
    macro_rules! def_temporal_test {
        ($test:ident, $array_type: ident, $data: expr) => {
            #[test]
            fn $test() {
                let a = $data;
                let b = BooleanArray::from(vec![true, false, true, false]);
                let c = filter(&a, &b).unwrap();
                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
                assert_eq!(2, d.len());
                assert_eq!(1, d.value(0));
                assert_eq!(3, d.value(1));
            }
        };
    }

    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from_vec(vec![1, 2, 3, 4], None)
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from_vec(vec![1, 2, 3, 4], None)
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from_vec(vec![1, 2, 3, 4], None)
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from_vec(vec![1, 2, 3, 4], None)
    );

    #[test]
    fn test_filter_array_slice() {
        // the data array is sliced (offset 1) => [6, 7, 8, 9]
        let a_slice = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
        let a = a_slice.as_ref();
        let b = BooleanArray::from(vec![true, false, false, true]);
        // NOTE(review): this originally said filtering with a sliced filter array is not
        // supported, but `SlicesIterator::new` passes `filter.offset()` to the bit helpers,
        // so that may be stale. Also, the commented-out slice below evaluates to
        // [false, false, true, false], which would not produce the asserted [6, 9].
        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
        // let b = b_slice.as_any().downcast_ref().unwrap();
        let c = filter(a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(6, d.value(0));
        assert_eq!(9, d.value(1));
    }

    #[test]
    fn test_filter_array_low_density() {
        // this test exercises the all 0's branch of the filter algorithm
        let mut data_values = (1..=65).collect::<Vec<i32>>();
        let mut filter_values =
            (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
        // set up two more values after the batch
        data_values.extend_from_slice(&[66, 67]);
        filter_values.extend_from_slice(&[false, true]);
        let a = Int32Array::from(data_values);
        let b = BooleanArray::from(filter_values);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(65, d.value(0));
        assert_eq!(67, d.value(1));
    }

    #[test]
    fn test_filter_array_high_density() {
        // this test exercises the all 1's branch of the filter algorithm
        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
        let mut filter_values = (1..=65)
            .map(|i| !matches!(i % 65, 0))
            .collect::<Vec<bool>>();
        // set second data value to null
        data_values[1] = None;
        // set up two more values after the batch
        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
        filter_values.extend_from_slice(&[false, true, true, true]);
        let a = Int32Array::from(data_values);
        let b = BooleanArray::from(filter_values);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(67, d.len());
        assert_eq!(3, d.null_count());
        assert_eq!(1, d.value(0));
        assert_eq!(true, d.is_null(1));
        assert_eq!(64, d.value(63));
        assert_eq!(true, d.is_null(64));
        assert_eq!(67, d.value(65));
    }

    #[test]
    fn test_filter_string_array_simple() {
        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
        let b = BooleanArray::from(vec![true, false, true, false]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert_eq!("world", d.value(1));
    }

    #[test]
    fn test_filter_primative_array_with_null() {
        let a = Int32Array::from(vec![Some(5), None]);
        let b = BooleanArray::from(vec![false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(1, d.len());
        assert_eq!(true, d.is_null(0));
    }

    #[test]
    fn test_filter_string_array_with_null() {
        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert_eq!(false, d.is_null(0));
        assert_eq!(true, d.is_null(1));
    }

    #[test]
    fn test_filter_binary_array_with_null() {
        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
        let a = BinaryArray::from(data);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(b"hello", d.value(0));
        assert_eq!(false, d.is_null(0));
        assert_eq!(true, d.is_null(1));
    }

    #[test]
    fn test_filter_array_slice_with_null() {
        // the data array is sliced (offset 1) => [None, 7, 8, 9]
        let a_slice =
            Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
        let a = a_slice.as_ref();
        let b = BooleanArray::from(vec![true, false, false, true]);
        // NOTE(review): this originally said filtering with a sliced filter array is not
        // supported, but `SlicesIterator::new` passes `filter.offset()` to the bit helpers,
        // so that may be stale. Also, the commented-out slice below evaluates to
        // [false, false, true, false], which would not produce the asserted [None, 9].
        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
        // let b = b_slice.as_any().downcast_ref().unwrap();
        let c = filter(a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(true, d.is_null(0));
        assert_eq!(false, d.is_null(1));
        assert_eq!(9, d.value(1));
    }

    #[test]
    fn test_filter_dictionary_array() {
        let values = vec![Some("hello"), None, Some("world"), Some("!")];
        let a: Int8DictionaryArray = values.iter().copied().collect();
        let b = BooleanArray::from(vec![false, true, true, false]);
        let c = filter(&a, &b).unwrap();
        let d = c
            .as_ref()
            .as_any()
            .downcast_ref::<Int8DictionaryArray>()
            .unwrap();
        let value_array = d.values();
        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
        // values are cloned in the filtered dictionary array
        assert_eq!(3, values.len());
        // but keys are filtered
        assert_eq!(2, d.len());
        assert_eq!(true, d.is_null(0));
        assert_eq!("world", values.value(d.keys().value(1) as usize));
    }

    #[test]
    fn test_filter_string_array_with_negated_boolean_array() {
        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
        let mut bb = BooleanBuilder::new(2);
        bb.append_value(false).unwrap();
        bb.append_value(true).unwrap();
        bb.append_value(false).unwrap();
        bb.append_value(true).unwrap();
        let b = bb.finish();
        // negating [false, true, false, true] gives [true, false, true, false]
        let b = crate::compute::not(&b).unwrap();

        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!("hello", d.value(0));
        assert_eq!("world", d.value(1));
    }

    #[test]
    fn test_filter_list_array() {
        let value_data = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7]))
            .build();

        let value_offsets = Buffer::from_slice_ref(&[0i64, 3, 6, 8, 8]);

        let list_data_type =
            DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false)));
        let list_data = ArrayData::builder(list_data_type)
            .len(4)
            .add_buffer(value_offsets)
            .add_child_data(value_data)
            .null_bit_buffer(Buffer::from([0b00000111]))
            .build();

        // a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
        let a = LargeListArray::from(list_data);
        let b = BooleanArray::from(vec![false, true, false, true]);
        let result = filter(&a, &b).unwrap();

        // expected: [[3, 4, 5], null]
        let value_data = ArrayData::builder(DataType::Int32)
            .len(3)
            .add_buffer(Buffer::from_slice_ref(&[3, 4, 5]))
            .build();

        let value_offsets = Buffer::from_slice_ref(&[0i64, 3, 3]);

        let list_data_type =
            DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false)));
        let expected = ArrayData::builder(list_data_type)
            .len(2)
            .add_buffer(value_offsets)
            .add_child_data(value_data)
            .null_bit_buffer(Buffer::from([0b00000001]))
            .build();

        assert_eq!(&make_array(expected), &result);
    }

    #[test]
    fn test_slice_iterator_bits() {
        // a single 64-bit chunk where only slot 1 is set => one slice [1, 2)
        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
        let filter = BooleanArray::from(filter_values);

        let iter = SlicesIterator::new(&filter);
        let filter_count = iter.filter_count;
        let chunks = iter.collect::<Vec<_>>();

        assert_eq!(chunks, vec![(1, 2)]);
        assert_eq!(filter_count, 1);
    }

    #[test]
    fn test_slice_iterator_bits1() {
        // a single 64-bit chunk where every slot except 1 is set => [0, 1) and [2, 64)
        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
        let filter = BooleanArray::from(filter_values);

        let iter = SlicesIterator::new(&filter);
        let filter_count = iter.filter_count;
        let chunks = iter.collect::<Vec<_>>();

        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
        assert_eq!(filter_count, 64 - 1);
    }

    #[test]
    fn test_slice_iterator_chunk_and_bits() {
        // 130 slots = two full 64-bit chunks plus a 2-slot remainder; only slots 0, 62
        // and 124 are unset, so regions span chunk boundaries and end in the remainder
        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
        let filter = BooleanArray::from(filter_values);

        let iter = SlicesIterator::new(&filter);
        let filter_count = iter.filter_count;
        let chunks = iter.collect::<Vec<_>>();

        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
        assert_eq!(filter_count, 61 + 61 + 5);
    }
}
585