1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 //! InList expression
19
20 use std::any::Any;
21 use std::sync::Arc;
22
23 use arrow::array::GenericStringArray;
24 use arrow::array::{
25 ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
26 Int64Array, Int8Array, StringOffsetSizeTrait, UInt16Array, UInt32Array, UInt64Array,
27 UInt8Array,
28 };
29 use arrow::{
30 datatypes::{DataType, Schema},
31 record_batch::RecordBatch,
32 };
33
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{ColumnarValue, PhysicalExpr};
use crate::scalar::ScalarValue;
37
38 /// InList
39 #[derive(Debug)]
40 pub struct InListExpr {
41 expr: Arc<dyn PhysicalExpr>,
42 list: Vec<Arc<dyn PhysicalExpr>>,
43 negated: bool,
44 }
45
// Evaluates `IN` / `NOT IN` for a primitive or boolean array and expands to a
// `Result<ColumnarValue>` holding a `BooleanArray`.
//
// * `$ARRAY`        - the `ArrayRef` being tested; it is downcast to
//                     `$ARRAY_TYPE` (panics if the downcast fails).
// * `$LIST_VALUES`  - the evaluated list entries (`Vec<ColumnarValue>`).
// * `$NEGATED`      - `true` for `NOT IN`.
// * `$SCALAR_VALUE` - the `ScalarValue` variant list literals must use.
// * `$ARRAY_TYPE`   - the concrete Arrow array type of `$ARRAY`.
//
// Per row: a NULL input yields NULL; a match yields `!negated`; a miss yields
// NULL if the list contained a NULL literal, otherwise `negated`.
// Panics on array-valued list entries or literals of any other type.
macro_rules! make_contains {
    ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr, $SCALAR_VALUE:ident, $ARRAY_TYPE:ident) => {{
        let typed = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
        let negated: bool = $NEGATED;

        let mut list_has_null = false;
        let mut needles = Vec::new();
        for candidate in $LIST_VALUES.iter() {
            let scalar = match candidate {
                ColumnarValue::Scalar(s) => s,
                ColumnarValue::Array(_) => {
                    unimplemented!("InList does not yet support nested columns.")
                }
            };
            match scalar {
                ScalarValue::$SCALAR_VALUE(Some(v)) => needles.push(*v),
                // A NULL literal may be typed like the column or arrive as an
                // untyped `Utf8(None)`.
                ScalarValue::$SCALAR_VALUE(None) | ScalarValue::Utf8(None) => {
                    list_has_null = true;
                }
                other => unimplemented!("Unexpected type {} for InList", other),
            }
        }

        let result: BooleanArray = typed
            .iter()
            .map(|item| -> Option<bool> {
                let needle = item?;
                if needles.contains(&needle) {
                    Some(!negated)
                } else if list_has_null {
                    None
                } else {
                    Some(negated)
                }
            })
            .collect();

        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
101
102 impl InListExpr {
103 /// Create a new InList expression
new( expr: Arc<dyn PhysicalExpr>, list: Vec<Arc<dyn PhysicalExpr>>, negated: bool, ) -> Self104 pub fn new(
105 expr: Arc<dyn PhysicalExpr>,
106 list: Vec<Arc<dyn PhysicalExpr>>,
107 negated: bool,
108 ) -> Self {
109 Self {
110 expr,
111 list,
112 negated,
113 }
114 }
115
116 /// Input expression
expr(&self) -> &Arc<dyn PhysicalExpr>117 pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
118 &self.expr
119 }
120
121 /// List to search in
list(&self) -> &[Arc<dyn PhysicalExpr>]122 pub fn list(&self) -> &[Arc<dyn PhysicalExpr>] {
123 &self.list
124 }
125
126 /// Is this negated e.g. NOT IN LIST
negated(&self) -> bool127 pub fn negated(&self) -> bool {
128 self.negated
129 }
130
131 /// Compare for specific utf8 types
132 #[allow(clippy::unnecessary_wraps)]
compare_utf8<T: StringOffsetSizeTrait>( &self, array: ArrayRef, list_values: Vec<ColumnarValue>, negated: bool, ) -> Result<ColumnarValue>133 fn compare_utf8<T: StringOffsetSizeTrait>(
134 &self,
135 array: ArrayRef,
136 list_values: Vec<ColumnarValue>,
137 negated: bool,
138 ) -> Result<ColumnarValue> {
139 let array = array
140 .as_any()
141 .downcast_ref::<GenericStringArray<T>>()
142 .unwrap();
143
144 let mut contains_null = false;
145 let values = list_values
146 .iter()
147 .flat_map(|expr| match expr {
148 ColumnarValue::Scalar(s) => match s {
149 ScalarValue::Utf8(Some(v)) => Some(v.as_str()),
150 ScalarValue::Utf8(None) => {
151 contains_null = true;
152 None
153 }
154 ScalarValue::LargeUtf8(Some(v)) => Some(v.as_str()),
155 ScalarValue::LargeUtf8(None) => {
156 contains_null = true;
157 None
158 }
159 datatype => unimplemented!("Unexpected type {} for InList", datatype),
160 },
161 ColumnarValue::Array(_) => {
162 unimplemented!("InList does not yet support nested columns.")
163 }
164 })
165 .collect::<Vec<&str>>();
166
167 Ok(ColumnarValue::Array(Arc::new(
168 array
169 .iter()
170 .map(|x| {
171 let contains = x.map(|x| values.contains(&x));
172 match contains {
173 Some(true) => {
174 if negated {
175 Some(false)
176 } else {
177 Some(true)
178 }
179 }
180 Some(false) => {
181 if contains_null {
182 None
183 } else if negated {
184 Some(true)
185 } else {
186 Some(false)
187 }
188 }
189 None => None,
190 }
191 })
192 .collect::<BooleanArray>(),
193 )))
194 }
195 }
196
197 impl std::fmt::Display for InListExpr {
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result198 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
199 if self.negated {
200 write!(f, "{} NOT IN ({:?})", self.expr, self.list)
201 } else {
202 write!(f, "{} IN ({:?})", self.expr, self.list)
203 }
204 }
205 }
206
207 impl PhysicalExpr for InListExpr {
208 /// Return a reference to Any that can be used for downcasting
as_any(&self) -> &dyn Any209 fn as_any(&self) -> &dyn Any {
210 self
211 }
212
data_type(&self, _input_schema: &Schema) -> Result<DataType>213 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
214 Ok(DataType::Boolean)
215 }
216
nullable(&self, input_schema: &Schema) -> Result<bool>217 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
218 self.expr.nullable(input_schema)
219 }
220
evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>221 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
222 let value = self.expr.evaluate(batch)?;
223 let value_data_type = value.data_type();
224 let list_values = self
225 .list
226 .iter()
227 .map(|expr| expr.evaluate(batch))
228 .collect::<Result<Vec<_>>>()?;
229
230 let array = match value {
231 ColumnarValue::Array(array) => array,
232 ColumnarValue::Scalar(scalar) => scalar.to_array(),
233 };
234
235 match value_data_type {
236 DataType::Float32 => {
237 make_contains!(array, list_values, self.negated, Float32, Float32Array)
238 }
239 DataType::Float64 => {
240 make_contains!(array, list_values, self.negated, Float64, Float64Array)
241 }
242 DataType::Int16 => {
243 make_contains!(array, list_values, self.negated, Int16, Int16Array)
244 }
245 DataType::Int32 => {
246 make_contains!(array, list_values, self.negated, Int32, Int32Array)
247 }
248 DataType::Int64 => {
249 make_contains!(array, list_values, self.negated, Int64, Int64Array)
250 }
251 DataType::Int8 => {
252 make_contains!(array, list_values, self.negated, Int8, Int8Array)
253 }
254 DataType::UInt16 => {
255 make_contains!(array, list_values, self.negated, UInt16, UInt16Array)
256 }
257 DataType::UInt32 => {
258 make_contains!(array, list_values, self.negated, UInt32, UInt32Array)
259 }
260 DataType::UInt64 => {
261 make_contains!(array, list_values, self.negated, UInt64, UInt64Array)
262 }
263 DataType::UInt8 => {
264 make_contains!(array, list_values, self.negated, UInt8, UInt8Array)
265 }
266 DataType::Boolean => {
267 make_contains!(array, list_values, self.negated, Boolean, BooleanArray)
268 }
269 DataType::Utf8 => self.compare_utf8::<i32>(array, list_values, self.negated),
270 DataType::LargeUtf8 => {
271 self.compare_utf8::<i64>(array, list_values, self.negated)
272 }
273 datatype => {
274 unimplemented!("InList does not support datatype {:?}.", datatype)
275 }
276 }
277 }
278 }
279
280 /// Creates a unary expression InList
in_list( expr: Arc<dyn PhysicalExpr>, list: Vec<Arc<dyn PhysicalExpr>>, negated: &bool, ) -> Result<Arc<dyn PhysicalExpr>>281 pub fn in_list(
282 expr: Arc<dyn PhysicalExpr>,
283 list: Vec<Arc<dyn PhysicalExpr>>,
284 negated: &bool,
285 ) -> Result<Arc<dyn PhysicalExpr>> {
286 Ok(Arc::new(InListExpr::new(expr, list, *negated)))
287 }
288
#[cfg(test)]
mod tests {
    use arrow::{array::StringArray, datatypes::Field};

    use super::*;
    use crate::error::Result;
    use crate::physical_plan::expressions::{col, lit};

    // Applies `col("a") [NOT] IN $LIST` to `$BATCH` and compares the result
    // with `$EXPECTED` (a `Vec<Option<bool>>`).
    macro_rules! in_list {
        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr) => {{
            let expr = in_list(col("a"), $LIST, $NEGATED).unwrap();
            let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
            let result = result
                .as_any()
                .downcast_ref::<BooleanArray>()
                .expect("failed to downcast to BooleanArray");
            let expected = &BooleanArray::from($EXPECTED);
            assert_eq!(expected, result);
        }};
    }

    #[test]
    fn in_list_utf8() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let a = StringArray::from(vec![Some("a"), Some("d"), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

        // expression: "a in ("a", "b")"
        let list = vec![
            lit(ScalarValue::Utf8(Some("a".to_string()))),
            lit(ScalarValue::Utf8(Some("b".to_string()))),
        ];
        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);

        // expression: "a not in ("a", "b")"
        let list = vec![
            lit(ScalarValue::Utf8(Some("a".to_string()))),
            lit(ScalarValue::Utf8(Some("b".to_string()))),
        ];
        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);

        // expression: "a in ("a", "b", NULL)"
        let list = vec![
            lit(ScalarValue::Utf8(Some("a".to_string()))),
            lit(ScalarValue::Utf8(Some("b".to_string()))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &false, vec![Some(true), None, None]);

        // expression: "a not in ("a", "b", NULL)"
        let list = vec![
            lit(ScalarValue::Utf8(Some("a".to_string()))),
            lit(ScalarValue::Utf8(Some("b".to_string()))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &true, vec![Some(false), None, None]);

        Ok(())
    }

    #[test]
    fn in_list_int64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let a = Int64Array::from(vec![Some(0), Some(2), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

        // expression: "a in (0, 1)"
        let list = vec![
            lit(ScalarValue::Int64(Some(0))),
            lit(ScalarValue::Int64(Some(1))),
        ];
        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);

        // expression: "a not in (0, 1)"
        let list = vec![
            lit(ScalarValue::Int64(Some(0))),
            lit(ScalarValue::Int64(Some(1))),
        ];
        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);

        // expression: "a in (0, 1, NULL)"
        let list = vec![
            lit(ScalarValue::Int64(Some(0))),
            lit(ScalarValue::Int64(Some(1))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &false, vec![Some(true), None, None]);

        // expression: "a not in (0, 1, NULL)"
        let list = vec![
            lit(ScalarValue::Int64(Some(0))),
            lit(ScalarValue::Int64(Some(1))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &true, vec![Some(false), None, None]);

        Ok(())
    }

    #[test]
    fn in_list_float64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

        // expression: "a in (0.0, 0.1)"
        let list = vec![
            lit(ScalarValue::Float64(Some(0.0))),
            lit(ScalarValue::Float64(Some(0.1))),
        ];
        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);

        // expression: "a not in (0.0, 0.1)"
        let list = vec![
            lit(ScalarValue::Float64(Some(0.0))),
            lit(ScalarValue::Float64(Some(0.1))),
        ];
        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);

        // expression: "a in (0.0, 0.1, NULL)"
        let list = vec![
            lit(ScalarValue::Float64(Some(0.0))),
            lit(ScalarValue::Float64(Some(0.1))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &false, vec![Some(true), None, None]);

        // expression: "a not in (0.0, 0.1, NULL)"
        let list = vec![
            lit(ScalarValue::Float64(Some(0.0))),
            lit(ScalarValue::Float64(Some(0.1))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &true, vec![Some(false), None, None]);

        Ok(())
    }

    #[test]
    fn in_list_bool() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
        let a = BooleanArray::from(vec![Some(true), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

        // expression: "a in (true)"
        let list = vec![lit(ScalarValue::Boolean(Some(true)))];
        in_list!(batch, list, &false, vec![Some(true), None]);

        // expression: "a not in (true)"
        let list = vec![lit(ScalarValue::Boolean(Some(true)))];
        in_list!(batch, list, &true, vec![Some(false), None]);

        // expression: "a in (true, NULL)"
        let list = vec![
            lit(ScalarValue::Boolean(Some(true))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &false, vec![Some(true), None]);

        // expression: "a not in (true, NULL)"
        let list = vec![
            lit(ScalarValue::Boolean(Some(true))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &true, vec![Some(false), None]);

        Ok(())
    }
}
459