1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 //! InList expression
19 
20 use std::any::Any;
21 use std::sync::Arc;
22 
23 use arrow::array::GenericStringArray;
24 use arrow::array::{
25     ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
26     Int64Array, Int8Array, StringOffsetSizeTrait, UInt16Array, UInt32Array, UInt64Array,
27     UInt8Array,
28 };
29 use arrow::{
30     datatypes::{DataType, Schema},
31     record_batch::RecordBatch,
32 };
33 
34 use crate::error::Result;
35 use crate::physical_plan::{ColumnarValue, PhysicalExpr};
36 use crate::scalar::ScalarValue;
37 
/// InList
///
/// Physical expression that tests whether the value produced by `expr` is
/// (or, when `negated` is set, is not) equal to one of the values produced
/// by the expressions in `list`.
#[derive(Debug)]
pub struct InListExpr {
    /// Expression whose value is searched for in `list`.
    expr: Arc<dyn PhysicalExpr>,
    /// Candidate expressions the value of `expr` is compared against.
    list: Vec<Arc<dyn PhysicalExpr>>,
    /// `true` for `NOT IN`, `false` for `IN`.
    negated: bool,
}
45 
// Builds the boolean result of `IN` / `NOT IN` for arrays whose elements are
// compared by value.
//
// * `$ARRAY`        - the `ArrayRef` being probed; downcast to `$ARRAY_TYPE`
// * `$LIST_VALUES`  - evaluated list items (`ColumnarValue`s)
// * `$NEGATED`      - `true` for `NOT IN`
// * `$SCALAR_VALUE` - the `ScalarValue` variant expected for list items
// * `$ARRAY_TYPE`   - the concrete array type matching `$SCALAR_VALUE`
//
// A null list item (either a null `$SCALAR_VALUE` or an untyped
// `ScalarValue::Utf8(None)`) turns every "not found" answer into NULL.
// Panics (via `unimplemented!`) on list items of any other scalar type and
// on array-valued list items.
macro_rules! make_contains {
    ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr, $SCALAR_VALUE:ident, $ARRAY_TYPE:ident) => {{
        let typed_array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
        let negated: bool = $NEGATED;

        // Split the list into its non-null candidates and a "saw a NULL" flag.
        let mut list_has_null = false;
        let mut candidates = Vec::new();
        for item in $LIST_VALUES.iter() {
            match item {
                ColumnarValue::Scalar(ScalarValue::$SCALAR_VALUE(Some(v))) => {
                    candidates.push(*v)
                }
                ColumnarValue::Scalar(ScalarValue::$SCALAR_VALUE(None))
                | ColumnarValue::Scalar(ScalarValue::Utf8(None)) => list_has_null = true,
                ColumnarValue::Scalar(other) => {
                    unimplemented!("Unexpected type {} for InList", other)
                }
                ColumnarValue::Array(_) => {
                    unimplemented!("InList does not yet support nested columns.")
                }
            }
        }

        let result: BooleanArray = typed_array
            .iter()
            .map(|cell| {
                // A null input row stays null.
                cell.and_then(|v| {
                    if candidates.contains(&v) {
                        // Definite match: IN -> true, NOT IN -> false.
                        Some(!negated)
                    } else if list_has_null {
                        // No match, but a NULL in the list makes the answer unknown.
                        None
                    } else {
                        // Definite miss: IN -> false, NOT IN -> true.
                        Some(negated)
                    }
                })
            })
            .collect();

        Ok(ColumnarValue::Array(Arc::new(result)))
    }};
}
101 
102 impl InListExpr {
103     /// Create a new InList expression
new( expr: Arc<dyn PhysicalExpr>, list: Vec<Arc<dyn PhysicalExpr>>, negated: bool, ) -> Self104     pub fn new(
105         expr: Arc<dyn PhysicalExpr>,
106         list: Vec<Arc<dyn PhysicalExpr>>,
107         negated: bool,
108     ) -> Self {
109         Self {
110             expr,
111             list,
112             negated,
113         }
114     }
115 
116     /// Input expression
expr(&self) -> &Arc<dyn PhysicalExpr>117     pub fn expr(&self) -> &Arc<dyn PhysicalExpr> {
118         &self.expr
119     }
120 
121     /// List to search in
list(&self) -> &[Arc<dyn PhysicalExpr>]122     pub fn list(&self) -> &[Arc<dyn PhysicalExpr>] {
123         &self.list
124     }
125 
126     /// Is this negated e.g. NOT IN LIST
negated(&self) -> bool127     pub fn negated(&self) -> bool {
128         self.negated
129     }
130 
131     /// Compare for specific utf8 types
132     #[allow(clippy::unnecessary_wraps)]
compare_utf8<T: StringOffsetSizeTrait>( &self, array: ArrayRef, list_values: Vec<ColumnarValue>, negated: bool, ) -> Result<ColumnarValue>133     fn compare_utf8<T: StringOffsetSizeTrait>(
134         &self,
135         array: ArrayRef,
136         list_values: Vec<ColumnarValue>,
137         negated: bool,
138     ) -> Result<ColumnarValue> {
139         let array = array
140             .as_any()
141             .downcast_ref::<GenericStringArray<T>>()
142             .unwrap();
143 
144         let mut contains_null = false;
145         let values = list_values
146             .iter()
147             .flat_map(|expr| match expr {
148                 ColumnarValue::Scalar(s) => match s {
149                     ScalarValue::Utf8(Some(v)) => Some(v.as_str()),
150                     ScalarValue::Utf8(None) => {
151                         contains_null = true;
152                         None
153                     }
154                     ScalarValue::LargeUtf8(Some(v)) => Some(v.as_str()),
155                     ScalarValue::LargeUtf8(None) => {
156                         contains_null = true;
157                         None
158                     }
159                     datatype => unimplemented!("Unexpected type {} for InList", datatype),
160                 },
161                 ColumnarValue::Array(_) => {
162                     unimplemented!("InList does not yet support nested columns.")
163                 }
164             })
165             .collect::<Vec<&str>>();
166 
167         Ok(ColumnarValue::Array(Arc::new(
168             array
169                 .iter()
170                 .map(|x| {
171                     let contains = x.map(|x| values.contains(&x));
172                     match contains {
173                         Some(true) => {
174                             if negated {
175                                 Some(false)
176                             } else {
177                                 Some(true)
178                             }
179                         }
180                         Some(false) => {
181                             if contains_null {
182                                 None
183                             } else if negated {
184                                 Some(true)
185                             } else {
186                                 Some(false)
187                             }
188                         }
189                         None => None,
190                     }
191                 })
192                 .collect::<BooleanArray>(),
193         )))
194     }
195 }
196 
197 impl std::fmt::Display for InListExpr {
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result198     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
199         if self.negated {
200             write!(f, "{} NOT IN ({:?})", self.expr, self.list)
201         } else {
202             write!(f, "{} IN ({:?})", self.expr, self.list)
203         }
204     }
205 }
206 
207 impl PhysicalExpr for InListExpr {
208     /// Return a reference to Any that can be used for downcasting
as_any(&self) -> &dyn Any209     fn as_any(&self) -> &dyn Any {
210         self
211     }
212 
data_type(&self, _input_schema: &Schema) -> Result<DataType>213     fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
214         Ok(DataType::Boolean)
215     }
216 
nullable(&self, input_schema: &Schema) -> Result<bool>217     fn nullable(&self, input_schema: &Schema) -> Result<bool> {
218         self.expr.nullable(input_schema)
219     }
220 
evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>221     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
222         let value = self.expr.evaluate(batch)?;
223         let value_data_type = value.data_type();
224         let list_values = self
225             .list
226             .iter()
227             .map(|expr| expr.evaluate(batch))
228             .collect::<Result<Vec<_>>>()?;
229 
230         let array = match value {
231             ColumnarValue::Array(array) => array,
232             ColumnarValue::Scalar(scalar) => scalar.to_array(),
233         };
234 
235         match value_data_type {
236             DataType::Float32 => {
237                 make_contains!(array, list_values, self.negated, Float32, Float32Array)
238             }
239             DataType::Float64 => {
240                 make_contains!(array, list_values, self.negated, Float64, Float64Array)
241             }
242             DataType::Int16 => {
243                 make_contains!(array, list_values, self.negated, Int16, Int16Array)
244             }
245             DataType::Int32 => {
246                 make_contains!(array, list_values, self.negated, Int32, Int32Array)
247             }
248             DataType::Int64 => {
249                 make_contains!(array, list_values, self.negated, Int64, Int64Array)
250             }
251             DataType::Int8 => {
252                 make_contains!(array, list_values, self.negated, Int8, Int8Array)
253             }
254             DataType::UInt16 => {
255                 make_contains!(array, list_values, self.negated, UInt16, UInt16Array)
256             }
257             DataType::UInt32 => {
258                 make_contains!(array, list_values, self.negated, UInt32, UInt32Array)
259             }
260             DataType::UInt64 => {
261                 make_contains!(array, list_values, self.negated, UInt64, UInt64Array)
262             }
263             DataType::UInt8 => {
264                 make_contains!(array, list_values, self.negated, UInt8, UInt8Array)
265             }
266             DataType::Boolean => {
267                 make_contains!(array, list_values, self.negated, Boolean, BooleanArray)
268             }
269             DataType::Utf8 => self.compare_utf8::<i32>(array, list_values, self.negated),
270             DataType::LargeUtf8 => {
271                 self.compare_utf8::<i64>(array, list_values, self.negated)
272             }
273             datatype => {
274                 unimplemented!("InList does not support datatype {:?}.", datatype)
275             }
276         }
277     }
278 }
279 
280 /// Creates a unary expression InList
in_list( expr: Arc<dyn PhysicalExpr>, list: Vec<Arc<dyn PhysicalExpr>>, negated: &bool, ) -> Result<Arc<dyn PhysicalExpr>>281 pub fn in_list(
282     expr: Arc<dyn PhysicalExpr>,
283     list: Vec<Arc<dyn PhysicalExpr>>,
284     negated: &bool,
285 ) -> Result<Arc<dyn PhysicalExpr>> {
286     Ok(Arc::new(InListExpr::new(expr, list, *negated)))
287 }
288 
#[cfg(test)]
mod tests {
    use arrow::{array::StringArray, datatypes::Field};

    use super::*;
    use crate::error::Result;
    use crate::physical_plan::expressions::{col, lit};

    // Applies the in_list expr to column "a" of an input batch and asserts
    // that the resulting BooleanArray equals the expected values.
    // Uses `?`, so it must be invoked inside a function returning `Result`.
    macro_rules! in_list {
        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr) => {{
            let expr = in_list(col("a"), $LIST, $NEGATED).unwrap();
            let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
            let result = result
                .as_any()
                .downcast_ref::<BooleanArray>()
                .expect("failed to downcast to BooleanArray");
            let expected = &BooleanArray::from($EXPECTED);
            assert_eq!(expected, result);
        }};
    }

    // IN / NOT IN over a Utf8 column, with and without a NULL in the list.
    #[test]
    fn in_list_utf8() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let a = StringArray::from(vec![Some("a"), Some("d"), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

        // expression: "a in ("a", "b")"
        let list = vec![
            lit(ScalarValue::Utf8(Some("a".to_string()))),
            lit(ScalarValue::Utf8(Some("b".to_string()))),
        ];
        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);

        // expression: "a not in ("a", "b")"
        let list = vec![
            lit(ScalarValue::Utf8(Some("a".to_string()))),
            lit(ScalarValue::Utf8(Some("b".to_string()))),
        ];
        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);

        // expression: "a in ("a", "b", NULL)"
        let list = vec![
            lit(ScalarValue::Utf8(Some("a".to_string()))),
            lit(ScalarValue::Utf8(Some("b".to_string()))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &false, vec![Some(true), None, None]);

        // expression: "a not in ("a", "b", NULL)"
        let list = vec![
            lit(ScalarValue::Utf8(Some("a".to_string()))),
            lit(ScalarValue::Utf8(Some("b".to_string()))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &true, vec![Some(false), None, None]);

        Ok(())
    }

    // IN / NOT IN over an Int64 column; the NULL list item is an untyped
    // `ScalarValue::Utf8(None)`.
    #[test]
    fn in_list_int64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let a = Int64Array::from(vec![Some(0), Some(2), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

        // expression: "a in (0, 1)"
        let list = vec![
            lit(ScalarValue::Int64(Some(0))),
            lit(ScalarValue::Int64(Some(1))),
        ];
        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);

        // expression: "a not in (0, 1)"
        let list = vec![
            lit(ScalarValue::Int64(Some(0))),
            lit(ScalarValue::Int64(Some(1))),
        ];
        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);

        // expression: "a in (0, 1, NULL)"
        let list = vec![
            lit(ScalarValue::Int64(Some(0))),
            lit(ScalarValue::Int64(Some(1))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &false, vec![Some(true), None, None]);

        // expression: "a not in (0, 1, NULL)"
        let list = vec![
            lit(ScalarValue::Int64(Some(0))),
            lit(ScalarValue::Int64(Some(1))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &true, vec![Some(false), None, None]);

        Ok(())
    }

    // IN / NOT IN over a Float64 column; the NULL list item is an untyped
    // `ScalarValue::Utf8(None)`.
    #[test]
    fn in_list_float64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
        let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

        // expression: "a in (0.0, 0.1)"
        let list = vec![
            lit(ScalarValue::Float64(Some(0.0))),
            lit(ScalarValue::Float64(Some(0.1))),
        ];
        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);

        // expression: "a not in (0.0, 0.1)"
        let list = vec![
            lit(ScalarValue::Float64(Some(0.0))),
            lit(ScalarValue::Float64(Some(0.1))),
        ];
        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);

        // expression: "a in (0.0, 0.1, NULL)"
        let list = vec![
            lit(ScalarValue::Float64(Some(0.0))),
            lit(ScalarValue::Float64(Some(0.1))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &false, vec![Some(true), None, None]);

        // expression: "a not in (0.0, 0.1, NULL)"
        let list = vec![
            lit(ScalarValue::Float64(Some(0.0))),
            lit(ScalarValue::Float64(Some(0.1))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &true, vec![Some(false), None, None]);

        Ok(())
    }

    // IN / NOT IN over a Boolean column; the NULL list item is an untyped
    // `ScalarValue::Utf8(None)`.
    #[test]
    fn in_list_bool() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
        let a = BooleanArray::from(vec![Some(true), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

        // expression: "a in (true)"
        let list = vec![lit(ScalarValue::Boolean(Some(true)))];
        in_list!(batch, list, &false, vec![Some(true), None]);

        // expression: "a not in (true)"
        let list = vec![lit(ScalarValue::Boolean(Some(true)))];
        in_list!(batch, list, &true, vec![Some(false), None]);

        // expression: "a in (true, NULL)"
        let list = vec![
            lit(ScalarValue::Boolean(Some(true))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &false, vec![Some(true), None]);

        // expression: "a not in (true, NULL)"
        let list = vec![
            lit(ScalarValue::Boolean(Some(true))),
            lit(ScalarValue::Utf8(None)),
        ];
        in_list!(batch, list, &true, vec![Some(false), None]);

        Ok(())
    }
}
459