1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 use std::{any::Any, sync::Arc};
19 
20 use arrow::array::*;
21 use arrow::compute::kernels::arithmetic::{
22     add, divide, divide_scalar, multiply, subtract,
23 };
24 use arrow::compute::kernels::boolean::{and, or};
25 use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
26 use arrow::compute::kernels::comparison::{
27     eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
28 };
29 use arrow::compute::kernels::comparison::{
30     eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, like_utf8_scalar, lt_eq_utf8, lt_utf8,
31     neq_utf8, nlike_utf8, nlike_utf8_scalar,
32 };
33 use arrow::compute::kernels::comparison::{
34     eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, lt_eq_utf8_scalar, lt_utf8_scalar,
35     neq_utf8_scalar,
36 };
37 use arrow::datatypes::{DataType, Schema, TimeUnit};
38 use arrow::record_batch::RecordBatch;
39 
40 use crate::error::{DataFusionError, Result};
41 use crate::logical_plan::Operator;
42 use crate::physical_plan::expressions::cast;
43 use crate::physical_plan::{ColumnarValue, PhysicalExpr};
44 use crate::scalar::ScalarValue;
45 
46 use super::coercion::{eq_coercion, numerical_coercion, order_coercion, string_coercion};
47 
48 /// Binary expression
49 #[derive(Debug)]
50 pub struct BinaryExpr {
51     left: Arc<dyn PhysicalExpr>,
52     op: Operator,
53     right: Arc<dyn PhysicalExpr>,
54 }
55 
56 impl BinaryExpr {
57     /// Create new binary expression
new( left: Arc<dyn PhysicalExpr>, op: Operator, right: Arc<dyn PhysicalExpr>, ) -> Self58     pub fn new(
59         left: Arc<dyn PhysicalExpr>,
60         op: Operator,
61         right: Arc<dyn PhysicalExpr>,
62     ) -> Self {
63         Self { left, op, right }
64     }
65 
66     /// Get the left side of the binary expression
left(&self) -> &Arc<dyn PhysicalExpr>67     pub fn left(&self) -> &Arc<dyn PhysicalExpr> {
68         &self.left
69     }
70 
71     /// Get the right side of the binary expression
right(&self) -> &Arc<dyn PhysicalExpr>72     pub fn right(&self) -> &Arc<dyn PhysicalExpr> {
73         &self.right
74     }
75 
76     /// Get the operator for this binary expression
op(&self) -> &Operator77     pub fn op(&self) -> &Operator {
78         &self.op
79     }
80 }
81 
82 impl std::fmt::Display for BinaryExpr {
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result83     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
84         write!(f, "{} {} {}", self.left, self.op, self.right)
85     }
86 }
87 
/// Invoke a `<$OP>_utf8` kernel (e.g. `lt` -> `lt_utf8`) on two string arrays.
///
/// Both operands are downcast to `$DT`; the downcast panics if either operand
/// is not actually a `$DT`. Kernel errors are propagated with `?`.
macro_rules! compute_utf8_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let downcast_msg = "compute_op failed to downcast array";
        let lhs_typed = $LEFT.as_any().downcast_ref::<$DT>().expect(downcast_msg);
        let rhs_typed = $RIGHT.as_any().downcast_ref::<$DT>().expect(downcast_msg);
        let kernel_output = paste::expr! {[<$OP _utf8>]}(&lhs_typed, &rhs_typed)?;
        Ok(Arc::new(kernel_output))
    }};
}
102 
/// Invoke a `<$OP>_utf8_scalar` kernel (e.g. `like` -> `like_utf8_scalar`) on a
/// string array and a string literal.
///
/// `$RIGHT` must be a non-null `ScalarValue::Utf8`; any other scalar yields a
/// `DataFusionError::Internal`. The downcast of `$LEFT` panics if it is not a `$DT`.
macro_rules! compute_utf8_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let lhs_typed = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        match $RIGHT {
            ScalarValue::Utf8(Some(string_value)) => {
                let kernel_output =
                    paste::expr! {[<$OP _utf8_scalar>]}(&lhs_typed, &string_value)?;
                Ok(Arc::new(kernel_output))
            }
            _ => Err(DataFusionError::Internal(format!(
                "compute_utf8_op_scalar failed to cast literal value {}",
                $RIGHT
            ))),
        }
    }};
}
123 
/// Invoke a `<$OP>_scalar` kernel (e.g. `lt` -> `lt_scalar`) on a primitive
/// array and a scalar value.
///
/// The scalar is converted to the array's native type with `TryInto`; a
/// failed conversion or a kernel error is propagated with `?`. The downcast of
/// `$LEFT` panics if it is not a `$DT`.
macro_rules! compute_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        use std::convert::TryInto;
        let lhs_typed: &$DT = $LEFT
            .as_any()
            .downcast_ref()
            .expect("compute_op failed to downcast array");
        // `paste` glues the operator name and the `_scalar` suffix together;
        // the conversion target type is inferred from the kernel's signature
        let kernel_output =
            paste::expr! {[<$OP _scalar>]}(&lhs_typed, $RIGHT.try_into()?)?;
        Ok(Arc::new(kernel_output))
    }};
}
140 
/// Invoke the compute kernel `$OP` on one or two arrays downcast to `$DT`.
///
/// The downcasts panic if an operand is not actually a `$DT`; kernel errors
/// are propagated with `?`.
macro_rules! compute_op {
    // binary form: `$OP(&left, &right)`
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let lhs_typed: &$DT = $LEFT
            .as_any()
            .downcast_ref()
            .expect("compute_op failed to downcast array");
        let rhs_typed: &$DT = $RIGHT
            .as_any()
            .downcast_ref()
            .expect("compute_op failed to downcast array");
        let kernel_output = $OP(&lhs_typed, &rhs_typed)?;
        Ok(Arc::new(kernel_output))
    }};
    // unary form: `$OP(&operand)`
    ($OPERAND:expr, $OP:ident, $DT:ident) => {{
        let operand_typed: &$DT = $OPERAND
            .as_any()
            .downcast_ref()
            .expect("compute_op failed to downcast array");
        let kernel_output = $OP(&operand_typed)?;
        Ok(Arc::new(kernel_output))
    }};
}
164 
/// Dispatch a string array-vs-scalar kernel on `$LEFT`.
///
/// Always evaluates to `Some(..)`, so callers can use `None` to mean "fall
/// back to the array implementation". Non-Utf8 arrays give `Some(Err(..))`.
macro_rules! binary_string_array_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let lhs_type = $LEFT.data_type();
        let outcome: Result<Arc<dyn Array>> = match lhs_type {
            DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray),
            unsupported => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for scalar operation on string array",
                unsupported
            ))),
        };
        Some(outcome)
    }};
}
177 
/// Dispatch a string array-vs-array kernel on `$LEFT` and `$RIGHT`.
///
/// Only Utf8 is supported; any other data type of `$LEFT` yields an internal error.
macro_rules! binary_string_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let lhs_type = $LEFT.data_type();
        match lhs_type {
            DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
            unsupported => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for binary operation on string arrays",
                unsupported
            ))),
        }
    }};
}
189 
/// Invoke the compute kernel `$OP` on a pair of primitive arrays.
///
/// Only integer and floating point types are dispatched; any other data type
/// of `$LEFT` yields an internal error. `$RIGHT` is downcast to the same array
/// type as `$LEFT` (see `compute_op!`).
macro_rules! binary_primitive_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let lhs_type = $LEFT.data_type();
        match lhs_type {
            // signed integers
            DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
            // unsigned integers
            DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
            // floating point
            DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
            unsupported => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for binary operation on primitive arrays",
                unsupported
            ))),
        }
    }};
}
213 
/// Invoke a `<$OP>_scalar` kernel on a primitive array and a scalar.
///
/// Only integer and floating point arrays are dispatched. Always evaluates to
/// `Some(..)`, so callers can use `None` to mean "fall back to the array
/// implementation"; unsupported types give `Some(Err(..))`.
macro_rules! binary_primitive_array_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let lhs_type = $LEFT.data_type();
        let outcome: Result<Arc<dyn Array>> = match lhs_type {
            // signed integers
            DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array),
            // unsigned integers
            DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array),
            // floating point
            DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
            unsupported => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for scalar operation on primitive array",
                unsupported
            ))),
        };
        Some(outcome)
    }};
}
238 
/// Invoke a `<$OP>_scalar` comparison kernel on an array and a scalar.
///
/// Unlike `binary_primitive_array_op_scalar!`, this also covers non-primitive
/// types: Utf8 strings, nanosecond timestamps without time zone, Date32 and
/// Date64. The Date64 arm mirrors `binary_array_op!`, which already handles
/// Date64 for array-array evaluation.
///
/// Always evaluates to `Some(..)`, so callers can use `None` to mean "fall back
/// to the array implementation"; unsupported types give `Some(Err(..))`.
#[macro_export]
macro_rules! binary_array_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
            DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
            DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray),
            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
            }
            DataType::Date32 => {
                compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array)
            }
            DataType::Date64 => {
                // imported locally so that call sites of this exported macro do
                // not need `Date64Array` in scope themselves
                use arrow::array::Date64Array;
                compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array)
            }
            other => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for scalar operation on dyn array",
                other
            ))),
        };
        Some(result)
    }};
}
270 
/// Invoke the array-array kernel `$OP` on two arrays of the same type.
///
/// Beyond the primitive numeric types, this also dispatches Utf8 strings (via
/// the `<$OP>_utf8` kernels), nanosecond timestamps without time zone, Date32
/// and Date64. Any other data type of `$LEFT` yields an internal error.
#[macro_export]
macro_rules! binary_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let lhs_type = $LEFT.data_type();
        match lhs_type {
            // signed integers
            DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
            // unsigned integers
            DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
            // floating point
            DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
            // strings use the dedicated `_utf8` kernels
            DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
            // temporal types
            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
            }
            DataType::Date32 => compute_op!($LEFT, $RIGHT, $OP, Date32Array),
            DataType::Date64 => compute_op!($LEFT, $RIGHT, $OP, Date64Array),
            unsupported => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for binary operation on dyn arrays",
                unsupported
            ))),
        }
    }};
}
304 
/// Invoke the boolean kernel `$OP` (e.g. `and`, `or`) on two boolean arrays.
///
/// Both operands are downcast to `BooleanArray`; the downcast panics if either
/// is not boolean. Kernel errors are propagated with `?`.
macro_rules! boolean_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let downcast_msg = "boolean_op failed to downcast array";
        let lhs_bools: &BooleanArray =
            $LEFT.as_any().downcast_ref().expect(downcast_msg);
        let rhs_bools: &BooleanArray =
            $RIGHT.as_any().downcast_ref().expect(downcast_msg);
        let kernel_output = $OP(&lhs_bools, &rhs_bools)?;
        Ok(Arc::new(kernel_output))
    }};
}
319 
320 /// Coercion rules for all binary operators. Returns the output type
321 /// of applying `op` to an argument of `lhs_type` and `rhs_type`.
common_binary_type( lhs_type: &DataType, op: &Operator, rhs_type: &DataType, ) -> Result<DataType>322 fn common_binary_type(
323     lhs_type: &DataType,
324     op: &Operator,
325     rhs_type: &DataType,
326 ) -> Result<DataType> {
327     // This result MUST be compatible with `binary_coerce`
328     let result = match op {
329         Operator::And | Operator::Or => match (lhs_type, rhs_type) {
330             // logical binary boolean operators can only be evaluated in bools
331             (DataType::Boolean, DataType::Boolean) => Some(DataType::Boolean),
332             _ => None,
333         },
334         // logical equality operators have their own rules, and always return a boolean
335         Operator::Eq | Operator::NotEq => eq_coercion(lhs_type, rhs_type),
336         // "like" operators operate on strings and always return a boolean
337         Operator::Like | Operator::NotLike => string_coercion(lhs_type, rhs_type),
338         // order-comparison operators have their own rules
339         Operator::Lt | Operator::Gt | Operator::GtEq | Operator::LtEq => {
340             order_coercion(lhs_type, rhs_type)
341         }
342         // for math expressions, the final value of the coercion is also the return type
343         // because coercion favours higher information types
344         Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => {
345             numerical_coercion(lhs_type, rhs_type)
346         }
347         Operator::Modulus => {
348             return Err(DataFusionError::NotImplemented(
349                 "Modulus operator is still not supported".to_string(),
350             ))
351         }
352     };
353 
354     // re-write the error message of failed coercions to include the operator's information
355     match result {
356         None => Err(DataFusionError::Plan(
357             format!(
358                 "'{:?} {} {:?}' can't be evaluated because there isn't a common type to coerce the types to",
359                 lhs_type, op, rhs_type
360             ),
361         )),
362         Some(t) => Ok(t)
363     }
364 }
365 
366 /// Returns the return type of a binary operator or an error when the binary operator cannot
367 /// perform the computation between the argument's types, even after type coercion.
368 ///
369 /// This function makes some assumptions about the underlying available computations.
binary_operator_data_type( lhs_type: &DataType, op: &Operator, rhs_type: &DataType, ) -> Result<DataType>370 pub fn binary_operator_data_type(
371     lhs_type: &DataType,
372     op: &Operator,
373     rhs_type: &DataType,
374 ) -> Result<DataType> {
375     // validate that it is possible to perform the operation on incoming types.
376     // (or the return datatype cannot be infered)
377     let common_type = common_binary_type(lhs_type, op, rhs_type)?;
378 
379     match op {
380         // operators that return a boolean
381         Operator::Eq
382         | Operator::NotEq
383         | Operator::And
384         | Operator::Or
385         | Operator::Like
386         | Operator::NotLike
387         | Operator::Lt
388         | Operator::Gt
389         | Operator::GtEq
390         | Operator::LtEq => Ok(DataType::Boolean),
391         // math operations return the same value as the common coerced type
392         Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => {
393             Ok(common_type)
394         }
395         Operator::Modulus => Err(DataFusionError::NotImplemented(
396             "Modulus operator is still not supported".to_string(),
397         )),
398     }
399 }
400 
401 impl PhysicalExpr for BinaryExpr {
402     /// Return a reference to Any that can be used for downcasting
as_any(&self) -> &dyn Any403     fn as_any(&self) -> &dyn Any {
404         self
405     }
406 
data_type(&self, input_schema: &Schema) -> Result<DataType>407     fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
408         binary_operator_data_type(
409             &self.left.data_type(input_schema)?,
410             &self.op,
411             &self.right.data_type(input_schema)?,
412         )
413     }
414 
nullable(&self, input_schema: &Schema) -> Result<bool>415     fn nullable(&self, input_schema: &Schema) -> Result<bool> {
416         Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
417     }
418 
evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>419     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
420         let left_value = self.left.evaluate(batch)?;
421         let right_value = self.right.evaluate(batch)?;
422         let left_data_type = left_value.data_type();
423         let right_data_type = right_value.data_type();
424 
425         if left_data_type != right_data_type {
426             return Err(DataFusionError::Internal(format!(
427                 "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
428                 self.op, left_data_type, right_data_type
429             )));
430         }
431 
432         let scalar_result = match (&left_value, &right_value) {
433             (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
434                 // if left is array and right is literal - use scalar operations
435                 match &self.op {
436                     Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), lt),
437                     Operator::LtEq => {
438                         binary_array_op_scalar!(array, scalar.clone(), lt_eq)
439                     }
440                     Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), gt),
441                     Operator::GtEq => {
442                         binary_array_op_scalar!(array, scalar.clone(), gt_eq)
443                     }
444                     Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq),
445                     Operator::NotEq => {
446                         binary_array_op_scalar!(array, scalar.clone(), neq)
447                     }
448                     Operator::Like => {
449                         binary_string_array_op_scalar!(array, scalar.clone(), like)
450                     }
451                     Operator::NotLike => {
452                         binary_string_array_op_scalar!(array, scalar.clone(), nlike)
453                     }
454                     Operator::Divide => {
455                         binary_primitive_array_op_scalar!(array, scalar.clone(), divide)
456                     }
457                     // if scalar operation is not supported - fallback to array implementation
458                     _ => None,
459                 }
460             }
461             (ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => {
462                 // if right is literal and left is array - reverse operator and parameters
463                 match &self.op {
464                     Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), gt),
465                     Operator::LtEq => {
466                         binary_array_op_scalar!(array, scalar.clone(), gt_eq)
467                     }
468                     Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), lt),
469                     Operator::GtEq => {
470                         binary_array_op_scalar!(array, scalar.clone(), lt_eq)
471                     }
472                     Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq),
473                     Operator::NotEq => {
474                         binary_array_op_scalar!(array, scalar.clone(), neq)
475                     }
476                     // if scalar operation is not supported - fallback to array implementation
477                     _ => None,
478                 }
479             }
480             (_, _) => None,
481         };
482 
483         if let Some(result) = scalar_result {
484             return result.map(|a| ColumnarValue::Array(a));
485         }
486 
487         // if both arrays or both literals - extract arrays and continue execution
488         let (left, right) = (
489             left_value.into_array(batch.num_rows()),
490             right_value.into_array(batch.num_rows()),
491         );
492 
493         let result: Result<ArrayRef> = match &self.op {
494             Operator::Like => binary_string_array_op!(left, right, like),
495             Operator::NotLike => binary_string_array_op!(left, right, nlike),
496             Operator::Lt => binary_array_op!(left, right, lt),
497             Operator::LtEq => binary_array_op!(left, right, lt_eq),
498             Operator::Gt => binary_array_op!(left, right, gt),
499             Operator::GtEq => binary_array_op!(left, right, gt_eq),
500             Operator::Eq => binary_array_op!(left, right, eq),
501             Operator::NotEq => binary_array_op!(left, right, neq),
502             Operator::Plus => binary_primitive_array_op!(left, right, add),
503             Operator::Minus => binary_primitive_array_op!(left, right, subtract),
504             Operator::Multiply => binary_primitive_array_op!(left, right, multiply),
505             Operator::Divide => binary_primitive_array_op!(left, right, divide),
506             Operator::And => {
507                 if left_data_type == DataType::Boolean {
508                     boolean_op!(left, right, and)
509                 } else {
510                     return Err(DataFusionError::Internal(format!(
511                         "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
512                         self.op,
513                         left.data_type(),
514                         right.data_type()
515                     )));
516                 }
517             }
518             Operator::Or => {
519                 if left_data_type == DataType::Boolean {
520                     boolean_op!(left, right, or)
521                 } else {
522                     return Err(DataFusionError::Internal(format!(
523                         "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
524                         self.op, left_data_type, right_data_type
525                     )));
526                 }
527             }
528             Operator::Modulus => Err(DataFusionError::NotImplemented(
529                 "Modulus operator is still not supported".to_string(),
530             )),
531         };
532         result.map(|a| ColumnarValue::Array(a))
533     }
534 }
535 
536 /// return two physical expressions that are optionally coerced to a
537 /// common type that the binary operator supports.
binary_cast( lhs: Arc<dyn PhysicalExpr>, op: &Operator, rhs: Arc<dyn PhysicalExpr>, input_schema: &Schema, ) -> Result<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>538 fn binary_cast(
539     lhs: Arc<dyn PhysicalExpr>,
540     op: &Operator,
541     rhs: Arc<dyn PhysicalExpr>,
542     input_schema: &Schema,
543 ) -> Result<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> {
544     let lhs_type = &lhs.data_type(input_schema)?;
545     let rhs_type = &rhs.data_type(input_schema)?;
546 
547     let cast_type = common_binary_type(lhs_type, op, rhs_type)?;
548 
549     Ok((
550         cast(lhs, input_schema, cast_type.clone())?,
551         cast(rhs, input_schema, cast_type)?,
552     ))
553 }
554 
555 /// Create a binary expression whose arguments are correctly coerced.
556 /// This function errors if it is not possible to coerce the arguments
557 /// to computational types supported by the operator.
binary( lhs: Arc<dyn PhysicalExpr>, op: Operator, rhs: Arc<dyn PhysicalExpr>, input_schema: &Schema, ) -> Result<Arc<dyn PhysicalExpr>>558 pub fn binary(
559     lhs: Arc<dyn PhysicalExpr>,
560     op: Operator,
561     rhs: Arc<dyn PhysicalExpr>,
562     input_schema: &Schema,
563 ) -> Result<Arc<dyn PhysicalExpr>> {
564     let (l, r) = binary_cast(lhs, &op, rhs, input_schema)?;
565     Ok(Arc::new(BinaryExpr::new(l, op, r)))
566 }
567 
568 #[cfg(test)]
569 mod tests {
570     use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef};
571     use arrow::util::display::array_value_to_string;
572 
573     use super::*;
574     use crate::error::Result;
575     use crate::physical_plan::expressions::col;
576 
577     // Create a binary expression without coercion. Used here when we do not want to coerce the expressions
578     // to valid types. Usage can result in an execution (after plan) error.
binary_simple( l: Arc<dyn PhysicalExpr>, op: Operator, r: Arc<dyn PhysicalExpr>, ) -> Arc<dyn PhysicalExpr>579     fn binary_simple(
580         l: Arc<dyn PhysicalExpr>,
581         op: Operator,
582         r: Arc<dyn PhysicalExpr>,
583     ) -> Arc<dyn PhysicalExpr> {
584         Arc::new(BinaryExpr::new(l, op, r))
585     }
586 
587     #[test]
binary_comparison() -> Result<()>588     fn binary_comparison() -> Result<()> {
589         let schema = Schema::new(vec![
590             Field::new("a", DataType::Int32, false),
591             Field::new("b", DataType::Int32, false),
592         ]);
593         let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
594         let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
595         let batch =
596             RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
597 
598         // expression: "a < b"
599         let lt = binary_simple(col("a"), Operator::Lt, col("b"));
600         let result = lt.evaluate(&batch)?.into_array(batch.num_rows());
601         assert_eq!(result.len(), 5);
602 
603         let expected = vec![false, false, true, true, true];
604         let result = result
605             .as_any()
606             .downcast_ref::<BooleanArray>()
607             .expect("failed to downcast to BooleanArray");
608         for (i, &expected_item) in expected.iter().enumerate().take(5) {
609             assert_eq!(result.value(i), expected_item);
610         }
611 
612         Ok(())
613     }
614 
615     #[test]
binary_nested() -> Result<()>616     fn binary_nested() -> Result<()> {
617         let schema = Schema::new(vec![
618             Field::new("a", DataType::Int32, false),
619             Field::new("b", DataType::Int32, false),
620         ]);
621         let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
622         let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
623         let batch =
624             RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
625 
626         // expression: "a < b OR a == b"
627         let expr = binary_simple(
628             binary_simple(col("a"), Operator::Lt, col("b")),
629             Operator::Or,
630             binary_simple(col("a"), Operator::Eq, col("b")),
631         );
632         assert_eq!("a < b OR a = b", format!("{}", expr));
633 
634         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
635         assert_eq!(result.len(), 5);
636 
637         let expected = vec![true, true, false, true, false];
638         let result = result
639             .as_any()
640             .downcast_ref::<BooleanArray>()
641             .expect("failed to downcast to BooleanArray");
642         for (i, &expected_item) in expected.iter().enumerate().take(5) {
643             assert_eq!(result.value(i), expected_item);
644         }
645 
646         Ok(())
647     }
648 
649     // runs an end-to-end test of physical type coercion:
650     // 1. construct a record batch with two columns of type A and B
651     //  (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements)
652     // 2. construct a physical expression of A OP B
653     // 3. evaluate the expression
654     // 4. verify that the resulting expression is of type C
655     // 5. verify that the results of evaluation are $VEC
656     macro_rules! test_coercion {
657         ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr) => {{
658             let schema = Schema::new(vec![
659                 Field::new("a", $A_TYPE, false),
660                 Field::new("b", $B_TYPE, false),
661             ]);
662             let a = $A_ARRAY::from($A_VEC);
663             let b = $B_ARRAY::from($B_VEC);
664             let batch = RecordBatch::try_new(
665                 Arc::new(schema.clone()),
666                 vec![Arc::new(a), Arc::new(b)],
667             )?;
668 
669             // verify that we can construct the expression
670             let expression = binary(col("a"), $OP, col("b"), &schema)?;
671 
672             // verify that the expression's type is correct
673             assert_eq!(expression.data_type(&schema)?, $C_TYPE);
674 
675             // compute
676             let result = expression.evaluate(&batch)?.into_array(batch.num_rows());
677 
678             // verify that the array's data_type is correct
679             assert_eq!(*result.data_type(), $C_TYPE);
680 
681             // verify that the data itself is downcastable
682             let result = result
683                 .as_any()
684                 .downcast_ref::<$C_ARRAY>()
685                 .expect("failed to downcast");
686             // verify that the result itself is correct
687             for (i, x) in $VEC.iter().enumerate() {
688                 assert_eq!(result.value(i), *x);
689             }
690         }};
691     }
692 
    #[test]
    fn test_type_coersion() -> Result<()> {
        // Int32 + UInt32 -> Int32
        test_coercion!(
            Int32Array,
            DataType::Int32,
            vec![1i32, 2i32],
            UInt32Array,
            DataType::UInt32,
            vec![1u32, 2u32],
            Operator::Plus,
            Int32Array,
            DataType::Int32,
            vec![2i32, 4i32]
        );
        // Int32 + UInt16 -> Int32
        test_coercion!(
            Int32Array,
            DataType::Int32,
            vec![1i32],
            UInt16Array,
            DataType::UInt16,
            vec![1u16],
            Operator::Plus,
            Int32Array,
            DataType::Int32,
            vec![2i32]
        );
        // Float32 + UInt16 -> Float32
        test_coercion!(
            Float32Array,
            DataType::Float32,
            vec![1f32],
            UInt16Array,
            DataType::UInt16,
            vec![1u16],
            Operator::Plus,
            Float32Array,
            DataType::Float32,
            vec![2f32]
        );
        // Float32 * UInt16 -> Float32
        test_coercion!(
            Float32Array,
            DataType::Float32,
            vec![2f32],
            UInt16Array,
            DataType::UInt16,
            vec![1u16],
            Operator::Multiply,
            Float32Array,
            DataType::Float32,
            vec![2f32]
        );
        // Utf8 LIKE Utf8 -> Boolean
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["hello world", "world"],
            StringArray,
            DataType::Utf8,
            vec!["%hello%", "%hello%"],
            Operator::Like,
            BooleanArray,
            DataType::Boolean,
            vec![true, false]
        );
        // Utf8 = Date32: 9112 and 9156 are the day counts since the Unix
        // epoch for 1994-12-13 and 1995-01-26
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["1994-12-13", "1995-01-26"],
            Date32Array,
            DataType::Date32,
            vec![9112, 9156],
            Operator::Eq,
            BooleanArray,
            DataType::Boolean,
            vec![true, true]
        );
        // Utf8 < Date32: one day after the first date, two days before the second
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["1994-12-13", "1995-01-26"],
            Date32Array,
            DataType::Date32,
            vec![9113, 9154],
            Operator::Lt,
            BooleanArray,
            DataType::Boolean,
            vec![true, false]
        );
        // Utf8 = Date64: the values are milliseconds since the Unix epoch for
        // the same timestamps as the strings
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
            Date64Array,
            DataType::Date64,
            vec![787322096000, 791083425000],
            Operator::Eq,
            BooleanArray,
            DataType::Boolean,
            vec![true, true]
        );
        // Utf8 < Date64: one millisecond after the first timestamp, one before the second
        test_coercion!(
            StringArray,
            DataType::Utf8,
            vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
            Date64Array,
            DataType::Date64,
            vec![787322096001, 791083424999],
            Operator::Lt,
            BooleanArray,
            DataType::Boolean,
            vec![true, false]
        );
        Ok(())
    }
805 
806     // Note it would be nice to use the same test_coercion macro as
807     // above, but sadly the type of the values of the dictionary are
808     // not encoded in the rust type of the DictionaryArray. Thus there
809     // is no way at the time of this writing to create a dictionary
810     // array using the `From` trait
811     #[test]
test_dictionary_type_to_array_coersion() -> Result<()>812     fn test_dictionary_type_to_array_coersion() -> Result<()> {
813         // Test string  a string dictionary
814         let dict_type =
815             DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
816         let string_type = DataType::Utf8;
817 
818         // build dictionary
819         let keys_builder = PrimitiveBuilder::<Int32Type>::new(10);
820         let values_builder = arrow::array::StringBuilder::new(10);
821         let mut dict_builder = StringDictionaryBuilder::new(keys_builder, values_builder);
822 
823         dict_builder.append("one")?;
824         dict_builder.append_null()?;
825         dict_builder.append("three")?;
826         dict_builder.append("four")?;
827         let dict_array = dict_builder.finish();
828 
829         let str_array =
830             StringArray::from(vec![Some("not one"), Some("two"), None, Some("four")]);
831 
832         let schema = Arc::new(Schema::new(vec![
833             Field::new("dict", dict_type, true),
834             Field::new("str", string_type, true),
835         ]));
836 
837         let batch = RecordBatch::try_new(
838             schema.clone(),
839             vec![Arc::new(dict_array), Arc::new(str_array)],
840         )?;
841 
842         let expected = "false\n\n\ntrue";
843 
844         // Test 1: dict = str
845 
846         // verify that we can construct the expression
847         let expression = binary(col("dict"), Operator::Eq, col("str"), &schema)?;
848         assert_eq!(expression.data_type(&schema)?, DataType::Boolean);
849 
850         // evaluate and verify the result type matched
851         let result = expression.evaluate(&batch)?.into_array(batch.num_rows());
852         assert_eq!(result.data_type(), &DataType::Boolean);
853 
854         // verify that the result itself is correct
855         assert_eq!(expected, array_to_string(&result)?);
856 
857         // Test 2: now test the other direction
858         // str = dict
859 
860         // verify that we can construct the expression
861         let expression = binary(col("str"), Operator::Eq, col("dict"), &schema)?;
862         assert_eq!(expression.data_type(&schema)?, DataType::Boolean);
863 
864         // evaluate and verify the result type matched
865         let result = expression.evaluate(&batch)?.into_array(batch.num_rows());
866         assert_eq!(result.data_type(), &DataType::Boolean);
867 
868         // verify that the result itself is correct
869         assert_eq!(expected, array_to_string(&result)?);
870 
871         Ok(())
872     }
873 
874     // Convert the array to a newline delimited string of pretty printed values
array_to_string(array: &ArrayRef) -> Result<String>875     fn array_to_string(array: &ArrayRef) -> Result<String> {
876         let s = (0..array.len())
877             .map(|i| array_value_to_string(array, i))
878             .collect::<std::result::Result<Vec<_>, arrow::error::ArrowError>>()?
879             .join("\n");
880         Ok(s)
881     }
882 
883     #[test]
plus_op() -> Result<()>884     fn plus_op() -> Result<()> {
885         let schema = Schema::new(vec![
886             Field::new("a", DataType::Int32, false),
887             Field::new("b", DataType::Int32, false),
888         ]);
889         let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
890         let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
891 
892         apply_arithmetic::<Int32Type>(
893             Arc::new(schema),
894             vec![Arc::new(a), Arc::new(b)],
895             Operator::Plus,
896             Int32Array::from(vec![2, 4, 7, 12, 21]),
897         )?;
898 
899         Ok(())
900     }
901 
902     #[test]
minus_op() -> Result<()>903     fn minus_op() -> Result<()> {
904         let schema = Arc::new(Schema::new(vec![
905             Field::new("a", DataType::Int32, false),
906             Field::new("b", DataType::Int32, false),
907         ]));
908         let a = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
909         let b = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
910 
911         apply_arithmetic::<Int32Type>(
912             schema.clone(),
913             vec![a.clone(), b.clone()],
914             Operator::Minus,
915             Int32Array::from(vec![0, 0, 1, 4, 11]),
916         )?;
917 
918         // should handle have negative values in result (for signed)
919         apply_arithmetic::<Int32Type>(
920             schema,
921             vec![b, a],
922             Operator::Minus,
923             Int32Array::from(vec![0, 0, -1, -4, -11]),
924         )?;
925 
926         Ok(())
927     }
928 
929     #[test]
multiply_op() -> Result<()>930     fn multiply_op() -> Result<()> {
931         let schema = Arc::new(Schema::new(vec![
932             Field::new("a", DataType::Int32, false),
933             Field::new("b", DataType::Int32, false),
934         ]));
935         let a = Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64]));
936         let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
937 
938         apply_arithmetic::<Int32Type>(
939             schema,
940             vec![a, b],
941             Operator::Multiply,
942             Int32Array::from(vec![8, 32, 128, 512, 2048]),
943         )?;
944 
945         Ok(())
946     }
947 
948     #[test]
divide_op() -> Result<()>949     fn divide_op() -> Result<()> {
950         let schema = Arc::new(Schema::new(vec![
951             Field::new("a", DataType::Int32, false),
952             Field::new("b", DataType::Int32, false),
953         ]));
954         let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
955         let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
956 
957         apply_arithmetic::<Int32Type>(
958             schema,
959             vec![a, b],
960             Operator::Divide,
961             Int32Array::from(vec![4, 8, 16, 32, 64]),
962         )?;
963 
964         Ok(())
965     }
966 
apply_arithmetic<T: ArrowNumericType>( schema: SchemaRef, data: Vec<ArrayRef>, op: Operator, expected: PrimitiveArray<T>, ) -> Result<()>967     fn apply_arithmetic<T: ArrowNumericType>(
968         schema: SchemaRef,
969         data: Vec<ArrayRef>,
970         op: Operator,
971         expected: PrimitiveArray<T>,
972     ) -> Result<()> {
973         let arithmetic_op = binary_simple(col("a"), op, col("b"));
974         let batch = RecordBatch::try_new(schema, data)?;
975         let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
976 
977         assert_eq!(result.as_ref(), &expected);
978         Ok(())
979     }
980 
981     #[test]
test_coersion_error() -> Result<()>982     fn test_coersion_error() -> Result<()> {
983         let expr =
984             common_binary_type(&DataType::Float32, &Operator::Plus, &DataType::Utf8);
985 
986         if let Err(DataFusionError::Plan(e)) = expr {
987             assert_eq!(e, "'Float32 + Utf8' can't be evaluated because there isn't a common type to coerce the types to");
988             Ok(())
989         } else {
990             Err(DataFusionError::Internal(
991                 "Coercion should have returned an DataFusionError::Internal".to_string(),
992             ))
993         }
994     }
995 }
996