1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 use std::{any::Any, sync::Arc};
19
20 use arrow::array::*;
21 use arrow::compute::kernels::arithmetic::{
22 add, divide, divide_scalar, multiply, subtract,
23 };
24 use arrow::compute::kernels::boolean::{and, or};
25 use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
26 use arrow::compute::kernels::comparison::{
27 eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
28 };
29 use arrow::compute::kernels::comparison::{
30 eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, like_utf8_scalar, lt_eq_utf8, lt_utf8,
31 neq_utf8, nlike_utf8, nlike_utf8_scalar,
32 };
33 use arrow::compute::kernels::comparison::{
34 eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, lt_eq_utf8_scalar, lt_utf8_scalar,
35 neq_utf8_scalar,
36 };
37 use arrow::datatypes::{DataType, Schema, TimeUnit};
38 use arrow::record_batch::RecordBatch;
39
40 use crate::error::{DataFusionError, Result};
41 use crate::logical_plan::Operator;
42 use crate::physical_plan::expressions::cast;
43 use crate::physical_plan::{ColumnarValue, PhysicalExpr};
44 use crate::scalar::ScalarValue;
45
46 use super::coercion::{eq_coercion, numerical_coercion, order_coercion, string_coercion};
47
/// Binary expression
///
/// Holds two child physical expressions and the [`Operator`] that combines
/// their evaluated results.
#[derive(Debug)]
pub struct BinaryExpr {
    /// Expression producing the left-hand operand
    left: Arc<dyn PhysicalExpr>,
    /// Operator applied to the two operands
    op: Operator,
    /// Expression producing the right-hand operand
    right: Arc<dyn PhysicalExpr>,
}
55
56 impl BinaryExpr {
57 /// Create new binary expression
new( left: Arc<dyn PhysicalExpr>, op: Operator, right: Arc<dyn PhysicalExpr>, ) -> Self58 pub fn new(
59 left: Arc<dyn PhysicalExpr>,
60 op: Operator,
61 right: Arc<dyn PhysicalExpr>,
62 ) -> Self {
63 Self { left, op, right }
64 }
65
66 /// Get the left side of the binary expression
left(&self) -> &Arc<dyn PhysicalExpr>67 pub fn left(&self) -> &Arc<dyn PhysicalExpr> {
68 &self.left
69 }
70
71 /// Get the right side of the binary expression
right(&self) -> &Arc<dyn PhysicalExpr>72 pub fn right(&self) -> &Arc<dyn PhysicalExpr> {
73 &self.right
74 }
75
76 /// Get the operator for this binary expression
op(&self) -> &Operator77 pub fn op(&self) -> &Operator {
78 &self.op
79 }
80 }
81
82 impl std::fmt::Display for BinaryExpr {
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result83 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
84 write!(f, "{} {} {}", self.left, self.op, self.right)
85 }
86 }
87
/// Apply a `<$OP>_utf8` compute kernel (e.g. `eq_utf8` for `eq`) to a pair
/// of arrays, after downcasting both of them to the concrete type `$DT`.
///
/// Panics when either operand is not a `$DT`.
macro_rules! compute_utf8_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let lhs = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        let rhs = $RIGHT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        // the kernel name is the operator name with the `_utf8` suffix appended
        Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&lhs, &rhs)?))
    }};
}
102
/// Invoke a `<$OP>_utf8_scalar` compute kernel (e.g. `like_utf8_scalar` for
/// `like`) on a data array and a scalar value.
///
/// `$LEFT` is downcast to `$DT` (panicking if it is a different array type).
/// `$RIGHT` must evaluate to `ScalarValue::Utf8(Some(_))`; any other value
/// produces a `DataFusionError::Internal`.
macro_rules! compute_utf8_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        // `$RIGHT` is expanded exactly once here, so an argument such as
        // `scalar.clone()` is not evaluated a second time on the error path.
        match $RIGHT {
            ScalarValue::Utf8(Some(string_value)) => Ok(Arc::new(
                paste::expr! {[<$OP _utf8_scalar>]}(&ll, &string_value)?,
            )),
            other => Err(DataFusionError::Internal(format!(
                "compute_utf8_op_scalar failed to cast literal value {}",
                other
            ))),
        }
    }};
}
123
/// Apply a `<$OP>_scalar` compute kernel to an array and a scalar value.
///
/// The array is downcast to `$DT` (panicking on a mismatch) and the scalar is
/// converted with `TryInto` into whatever native type the kernel expects; a
/// failed conversion is propagated with `?`.
macro_rules! compute_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        use std::convert::TryInto;
        let typed_array = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        // the kernel name is built from $OP plus the `_scalar` suffix,
        // e.g. `lt` becomes `lt_scalar`
        Ok(Arc::new(paste::expr! {[<$OP _scalar>]}(
            &typed_array,
            $RIGHT.try_into()?,
        )?))
    }};
}
140
/// Apply the compute kernel `$OP` to one or two arrays of concrete type `$DT`.
///
/// Every operand is downcast to `$DT` first; a mismatching array type panics.
macro_rules! compute_op {
    // two operands: binary kernel
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let lhs = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        let rhs = $RIGHT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        Ok(Arc::new($OP(&lhs, &rhs)?))
    }};
    // one operand: unary kernel
    ($OPERAND:expr, $OP:ident, $DT:ident) => {{
        let input = $OPERAND
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        Ok(Arc::new($OP(&input)?))
    }};
}
164
/// Apply a string kernel to an array and a scalar value.
///
/// Only `Utf8` arrays are handled; every other data type yields a
/// `DataFusionError::Internal`. The outcome is always wrapped in `Some`.
macro_rules! binary_string_array_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let outcome: Result<Arc<dyn Array>> = match $LEFT.data_type() {
            DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray),
            unsupported => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for scalar operation on string array",
                unsupported
            ))),
        };
        Some(outcome)
    }};
}
177
/// Apply a string kernel to a pair of arrays.
///
/// Dispatches on the data type of `$LEFT`: only `Utf8` is handled, any other
/// type yields a `DataFusionError::Internal`.
macro_rules! binary_string_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
            unsupported => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for binary operation on string arrays",
                unsupported
            ))),
        }
    }};
}
189
/// Invoke a compute kernel on a pair of arrays
/// The binary_primitive_array_op macro only evaluates for primitive types
/// like integers and floats.
///
/// The concrete array type is chosen from the data type of `$LEFT` alone;
/// `$RIGHT` is downcast to that same array type by `compute_op!`, so both
/// operands must already share a data type. Any data type that is not one of
/// the listed integer or float types yields a `DataFusionError::Internal`.
macro_rules! binary_primitive_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
            // everything else (strings, dates, timestamps, ...) is rejected
            other => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for binary operation on primitive arrays",
                other
            ))),
        }
    }};
}
213
/// Invoke a compute kernel on an array and a scalar
/// The binary_primitive_array_op_scalar macro only evaluates for primitive
/// types like integers and floats.
///
/// The array type is chosen from the data type of `$LEFT`; `$RIGHT` is the
/// scalar handed to `compute_op_scalar!`. The macro evaluates to
/// `Some(Result<Arc<dyn Array>>)`: the kernel result for the listed integer
/// and float types, or a `DataFusionError::Internal` for any other type.
macro_rules! binary_primitive_array_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
            DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
            // everything else is rejected rather than silently skipped
            other => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for scalar operation on primitive array",
                other
            ))),
        };
        // `BinaryExpr::evaluate` uses `None` for "no scalar kernel attempted",
        // so an attempted kernel (even a failed one) is reported as `Some`
        Some(result)
    }};
}
238
/// The binary_array_op_scalar macro includes types that extend beyond the primitive,
/// such as Utf8 strings.
///
/// Dispatches on the data type of `$LEFT` and applies the scalar variant of
/// `$OP` to the array and the scalar `$RIGHT`. Besides the integer and float
/// types it handles `Utf8`, `Date32` and nanosecond timestamps without a
/// timezone. Note that, unlike `binary_array_op!`, there is no `Date64` arm
/// here, so `Date64` falls through to the error arm.
///
/// Evaluates to `Some(Result<Arc<dyn Array>>)`.
#[macro_export]
macro_rules! binary_array_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
            DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
            // strings go through the dedicated `_utf8_scalar` kernels
            DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray),
            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
            }
            DataType::Date32 => {
                compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array)
            }
            other => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for scalar operation on dyn array",
                other
            ))),
        };
        // `BinaryExpr::evaluate` uses `None` for "no scalar kernel attempted",
        // so an attempted kernel (even a failed one) is reported as `Some`
        Some(result)
    }};
}
270
/// The binary_array_op macro includes types that extend beyond the primitive,
/// such as Utf8 strings.
///
/// Dispatches on the data type of `$LEFT` and applies `$OP` to the two
/// arrays; `$RIGHT` is downcast to the same concrete array type, so both
/// operands must already share a data type. Besides the integer and float
/// types it handles `Utf8`, `Date32`, `Date64` and nanosecond timestamps
/// without a timezone. Any other type yields a `DataFusionError::Internal`.
#[macro_export]
macro_rules! binary_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
            // strings go through the dedicated `_utf8` kernels
            DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
            }
            DataType::Date32 => {
                compute_op!($LEFT, $RIGHT, $OP, Date32Array)
            }
            DataType::Date64 => {
                compute_op!($LEFT, $RIGHT, $OP, Date64Array)
            }
            other => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for binary operation on dyn arrays",
                other
            ))),
        }
    }};
}
304
/// Apply a boolean kernel (such as `and` / `or`) to a pair of arrays.
///
/// Both operands are downcast to `BooleanArray`; a different array type
/// panics.
macro_rules! boolean_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let lhs = $LEFT
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean_op failed to downcast array");
        let rhs = $RIGHT
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean_op failed to downcast array");
        Ok(Arc::new($OP(&lhs, &rhs)?))
    }};
}
319
320 /// Coercion rules for all binary operators. Returns the output type
321 /// of applying `op` to an argument of `lhs_type` and `rhs_type`.
common_binary_type( lhs_type: &DataType, op: &Operator, rhs_type: &DataType, ) -> Result<DataType>322 fn common_binary_type(
323 lhs_type: &DataType,
324 op: &Operator,
325 rhs_type: &DataType,
326 ) -> Result<DataType> {
327 // This result MUST be compatible with `binary_coerce`
328 let result = match op {
329 Operator::And | Operator::Or => match (lhs_type, rhs_type) {
330 // logical binary boolean operators can only be evaluated in bools
331 (DataType::Boolean, DataType::Boolean) => Some(DataType::Boolean),
332 _ => None,
333 },
334 // logical equality operators have their own rules, and always return a boolean
335 Operator::Eq | Operator::NotEq => eq_coercion(lhs_type, rhs_type),
336 // "like" operators operate on strings and always return a boolean
337 Operator::Like | Operator::NotLike => string_coercion(lhs_type, rhs_type),
338 // order-comparison operators have their own rules
339 Operator::Lt | Operator::Gt | Operator::GtEq | Operator::LtEq => {
340 order_coercion(lhs_type, rhs_type)
341 }
342 // for math expressions, the final value of the coercion is also the return type
343 // because coercion favours higher information types
344 Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => {
345 numerical_coercion(lhs_type, rhs_type)
346 }
347 Operator::Modulus => {
348 return Err(DataFusionError::NotImplemented(
349 "Modulus operator is still not supported".to_string(),
350 ))
351 }
352 };
353
354 // re-write the error message of failed coercions to include the operator's information
355 match result {
356 None => Err(DataFusionError::Plan(
357 format!(
358 "'{:?} {} {:?}' can't be evaluated because there isn't a common type to coerce the types to",
359 lhs_type, op, rhs_type
360 ),
361 )),
362 Some(t) => Ok(t)
363 }
364 }
365
366 /// Returns the return type of a binary operator or an error when the binary operator cannot
367 /// perform the computation between the argument's types, even after type coercion.
368 ///
369 /// This function makes some assumptions about the underlying available computations.
binary_operator_data_type( lhs_type: &DataType, op: &Operator, rhs_type: &DataType, ) -> Result<DataType>370 pub fn binary_operator_data_type(
371 lhs_type: &DataType,
372 op: &Operator,
373 rhs_type: &DataType,
374 ) -> Result<DataType> {
375 // validate that it is possible to perform the operation on incoming types.
376 // (or the return datatype cannot be infered)
377 let common_type = common_binary_type(lhs_type, op, rhs_type)?;
378
379 match op {
380 // operators that return a boolean
381 Operator::Eq
382 | Operator::NotEq
383 | Operator::And
384 | Operator::Or
385 | Operator::Like
386 | Operator::NotLike
387 | Operator::Lt
388 | Operator::Gt
389 | Operator::GtEq
390 | Operator::LtEq => Ok(DataType::Boolean),
391 // math operations return the same value as the common coerced type
392 Operator::Plus | Operator::Minus | Operator::Divide | Operator::Multiply => {
393 Ok(common_type)
394 }
395 Operator::Modulus => Err(DataFusionError::NotImplemented(
396 "Modulus operator is still not supported".to_string(),
397 )),
398 }
399 }
400
401 impl PhysicalExpr for BinaryExpr {
402 /// Return a reference to Any that can be used for downcasting
as_any(&self) -> &dyn Any403 fn as_any(&self) -> &dyn Any {
404 self
405 }
406
data_type(&self, input_schema: &Schema) -> Result<DataType>407 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
408 binary_operator_data_type(
409 &self.left.data_type(input_schema)?,
410 &self.op,
411 &self.right.data_type(input_schema)?,
412 )
413 }
414
nullable(&self, input_schema: &Schema) -> Result<bool>415 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
416 Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
417 }
418
evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>419 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
420 let left_value = self.left.evaluate(batch)?;
421 let right_value = self.right.evaluate(batch)?;
422 let left_data_type = left_value.data_type();
423 let right_data_type = right_value.data_type();
424
425 if left_data_type != right_data_type {
426 return Err(DataFusionError::Internal(format!(
427 "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
428 self.op, left_data_type, right_data_type
429 )));
430 }
431
432 let scalar_result = match (&left_value, &right_value) {
433 (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
434 // if left is array and right is literal - use scalar operations
435 match &self.op {
436 Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), lt),
437 Operator::LtEq => {
438 binary_array_op_scalar!(array, scalar.clone(), lt_eq)
439 }
440 Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), gt),
441 Operator::GtEq => {
442 binary_array_op_scalar!(array, scalar.clone(), gt_eq)
443 }
444 Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq),
445 Operator::NotEq => {
446 binary_array_op_scalar!(array, scalar.clone(), neq)
447 }
448 Operator::Like => {
449 binary_string_array_op_scalar!(array, scalar.clone(), like)
450 }
451 Operator::NotLike => {
452 binary_string_array_op_scalar!(array, scalar.clone(), nlike)
453 }
454 Operator::Divide => {
455 binary_primitive_array_op_scalar!(array, scalar.clone(), divide)
456 }
457 // if scalar operation is not supported - fallback to array implementation
458 _ => None,
459 }
460 }
461 (ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => {
462 // if right is literal and left is array - reverse operator and parameters
463 match &self.op {
464 Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), gt),
465 Operator::LtEq => {
466 binary_array_op_scalar!(array, scalar.clone(), gt_eq)
467 }
468 Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), lt),
469 Operator::GtEq => {
470 binary_array_op_scalar!(array, scalar.clone(), lt_eq)
471 }
472 Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq),
473 Operator::NotEq => {
474 binary_array_op_scalar!(array, scalar.clone(), neq)
475 }
476 // if scalar operation is not supported - fallback to array implementation
477 _ => None,
478 }
479 }
480 (_, _) => None,
481 };
482
483 if let Some(result) = scalar_result {
484 return result.map(|a| ColumnarValue::Array(a));
485 }
486
487 // if both arrays or both literals - extract arrays and continue execution
488 let (left, right) = (
489 left_value.into_array(batch.num_rows()),
490 right_value.into_array(batch.num_rows()),
491 );
492
493 let result: Result<ArrayRef> = match &self.op {
494 Operator::Like => binary_string_array_op!(left, right, like),
495 Operator::NotLike => binary_string_array_op!(left, right, nlike),
496 Operator::Lt => binary_array_op!(left, right, lt),
497 Operator::LtEq => binary_array_op!(left, right, lt_eq),
498 Operator::Gt => binary_array_op!(left, right, gt),
499 Operator::GtEq => binary_array_op!(left, right, gt_eq),
500 Operator::Eq => binary_array_op!(left, right, eq),
501 Operator::NotEq => binary_array_op!(left, right, neq),
502 Operator::Plus => binary_primitive_array_op!(left, right, add),
503 Operator::Minus => binary_primitive_array_op!(left, right, subtract),
504 Operator::Multiply => binary_primitive_array_op!(left, right, multiply),
505 Operator::Divide => binary_primitive_array_op!(left, right, divide),
506 Operator::And => {
507 if left_data_type == DataType::Boolean {
508 boolean_op!(left, right, and)
509 } else {
510 return Err(DataFusionError::Internal(format!(
511 "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
512 self.op,
513 left.data_type(),
514 right.data_type()
515 )));
516 }
517 }
518 Operator::Or => {
519 if left_data_type == DataType::Boolean {
520 boolean_op!(left, right, or)
521 } else {
522 return Err(DataFusionError::Internal(format!(
523 "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
524 self.op, left_data_type, right_data_type
525 )));
526 }
527 }
528 Operator::Modulus => Err(DataFusionError::NotImplemented(
529 "Modulus operator is still not supported".to_string(),
530 )),
531 };
532 result.map(|a| ColumnarValue::Array(a))
533 }
534 }
535
536 /// return two physical expressions that are optionally coerced to a
537 /// common type that the binary operator supports.
binary_cast( lhs: Arc<dyn PhysicalExpr>, op: &Operator, rhs: Arc<dyn PhysicalExpr>, input_schema: &Schema, ) -> Result<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>538 fn binary_cast(
539 lhs: Arc<dyn PhysicalExpr>,
540 op: &Operator,
541 rhs: Arc<dyn PhysicalExpr>,
542 input_schema: &Schema,
543 ) -> Result<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> {
544 let lhs_type = &lhs.data_type(input_schema)?;
545 let rhs_type = &rhs.data_type(input_schema)?;
546
547 let cast_type = common_binary_type(lhs_type, op, rhs_type)?;
548
549 Ok((
550 cast(lhs, input_schema, cast_type.clone())?,
551 cast(rhs, input_schema, cast_type)?,
552 ))
553 }
554
555 /// Create a binary expression whose arguments are correctly coerced.
556 /// This function errors if it is not possible to coerce the arguments
557 /// to computational types supported by the operator.
binary( lhs: Arc<dyn PhysicalExpr>, op: Operator, rhs: Arc<dyn PhysicalExpr>, input_schema: &Schema, ) -> Result<Arc<dyn PhysicalExpr>>558 pub fn binary(
559 lhs: Arc<dyn PhysicalExpr>,
560 op: Operator,
561 rhs: Arc<dyn PhysicalExpr>,
562 input_schema: &Schema,
563 ) -> Result<Arc<dyn PhysicalExpr>> {
564 let (l, r) = binary_cast(lhs, &op, rhs, input_schema)?;
565 Ok(Arc::new(BinaryExpr::new(l, op, r)))
566 }
567
568 #[cfg(test)]
569 mod tests {
570 use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef};
571 use arrow::util::display::array_value_to_string;
572
573 use super::*;
574 use crate::error::Result;
575 use crate::physical_plan::expressions::col;
576
577 // Create a binary expression without coercion. Used here when we do not want to coerce the expressions
578 // to valid types. Usage can result in an execution (after plan) error.
binary_simple( l: Arc<dyn PhysicalExpr>, op: Operator, r: Arc<dyn PhysicalExpr>, ) -> Arc<dyn PhysicalExpr>579 fn binary_simple(
580 l: Arc<dyn PhysicalExpr>,
581 op: Operator,
582 r: Arc<dyn PhysicalExpr>,
583 ) -> Arc<dyn PhysicalExpr> {
584 Arc::new(BinaryExpr::new(l, op, r))
585 }
586
#[test]
fn binary_comparison() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ];
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
        Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16])),
    ];
    let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)?;

    // expression: "a < b"
    let less_than = binary_simple(col("a"), Operator::Lt, col("b"));
    let evaluated = less_than.evaluate(&batch)?.into_array(batch.num_rows());
    assert_eq!(evaluated.len(), 5);

    let booleans = evaluated
        .as_any()
        .downcast_ref::<BooleanArray>()
        .expect("failed to downcast to BooleanArray");

    // row-by-row comparison against the expected outcome
    let expected = [false, false, true, true, true];
    for (row, want) in expected.iter().enumerate() {
        assert_eq!(booleans.value(row), *want);
    }

    Ok(())
}
614
#[test]
fn binary_nested() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ];
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![2, 4, 6, 8, 10])),
        Arc::new(Int32Array::from(vec![2, 5, 4, 8, 8])),
    ];
    let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)?;

    // expression: "a < b OR a == b"
    let less_than = binary_simple(col("a"), Operator::Lt, col("b"));
    let equal = binary_simple(col("a"), Operator::Eq, col("b"));
    let expr = binary_simple(less_than, Operator::Or, equal);
    assert_eq!("a < b OR a = b", format!("{}", expr));

    let evaluated = expr.evaluate(&batch)?.into_array(batch.num_rows());
    assert_eq!(evaluated.len(), 5);

    let booleans = evaluated
        .as_any()
        .downcast_ref::<BooleanArray>()
        .expect("failed to downcast to BooleanArray");

    // row-by-row comparison against the expected outcome
    let expected = [true, true, false, true, false];
    for (row, want) in expected.iter().enumerate() {
        assert_eq!(booleans.value(row), *want);
    }

    Ok(())
}
648
// End-to-end check of physical type coercion for `A OP B`:
// 1. build a record batch with columns "a" (type A) and "b" (type B)
//    (*_ARRAY is the Rust Arrow array type, *_TYPE the DataType of its elements)
// 2. build the coerced physical expression with `binary`
// 3. check that the expression reports type C
// 4. evaluate it and check that the produced array has type C
// 5. check that the produced values equal $VEC
macro_rules! test_coercion {
    ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr) => {{
        let schema = Schema::new(vec![
            Field::new("a", $A_TYPE, false),
            Field::new("b", $B_TYPE, false),
        ]);
        let column_a = $A_ARRAY::from($A_VEC);
        let column_b = $B_ARRAY::from($B_VEC);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(column_a), Arc::new(column_b)],
        )?;

        // construction must succeed, i.e. the operands must be coercible
        let expression = binary(col("a"), $OP, col("b"), &schema)?;

        // the declared output type must match
        assert_eq!(expression.data_type(&schema)?, $C_TYPE);

        let evaluated = expression.evaluate(&batch)?.into_array(batch.num_rows());

        // the type of the produced array must match as well
        assert_eq!(*evaluated.data_type(), $C_TYPE);

        // the produced array must be of the expected concrete array type ...
        let typed = evaluated
            .as_any()
            .downcast_ref::<$C_ARRAY>()
            .expect("failed to downcast");
        // ... and hold the expected values
        for (row, want) in $VEC.iter().enumerate() {
            assert_eq!(typed.value(row), *want);
        }
    }};
}
692
// Exercises `binary` (coercion + evaluation) over a set of operand type
// pairs. Each `test_coercion!` call lists, in order: array type, data type
// and values of column "a"; the same for column "b"; the operator; and the
// expected array type, data type and values of the result.
#[test]
fn test_type_coersion() -> Result<()> {
    // Int32 + UInt32 -> Int32
    test_coercion!(
        Int32Array,
        DataType::Int32,
        vec![1i32, 2i32],
        UInt32Array,
        DataType::UInt32,
        vec![1u32, 2u32],
        Operator::Plus,
        Int32Array,
        DataType::Int32,
        vec![2i32, 4i32]
    );
    // Int32 + UInt16 -> Int32
    test_coercion!(
        Int32Array,
        DataType::Int32,
        vec![1i32],
        UInt16Array,
        DataType::UInt16,
        vec![1u16],
        Operator::Plus,
        Int32Array,
        DataType::Int32,
        vec![2i32]
    );
    // Float32 + UInt16 -> Float32
    test_coercion!(
        Float32Array,
        DataType::Float32,
        vec![1f32],
        UInt16Array,
        DataType::UInt16,
        vec![1u16],
        Operator::Plus,
        Float32Array,
        DataType::Float32,
        vec![2f32]
    );
    // Float32 * UInt16 -> Float32
    test_coercion!(
        Float32Array,
        DataType::Float32,
        vec![2f32],
        UInt16Array,
        DataType::UInt16,
        vec![1u16],
        Operator::Multiply,
        Float32Array,
        DataType::Float32,
        vec![2f32]
    );
    // Utf8 LIKE Utf8 -> Boolean
    test_coercion!(
        StringArray,
        DataType::Utf8,
        vec!["hello world", "world"],
        StringArray,
        DataType::Utf8,
        vec!["%hello%", "%hello%"],
        Operator::Like,
        BooleanArray,
        DataType::Boolean,
        vec![true, false]
    );
    // Utf8 = Date32 -> Boolean (both rows expected equal)
    test_coercion!(
        StringArray,
        DataType::Utf8,
        vec!["1994-12-13", "1995-01-26"],
        Date32Array,
        DataType::Date32,
        vec![9112, 9156],
        Operator::Eq,
        BooleanArray,
        DataType::Boolean,
        vec![true, true]
    );
    // Utf8 < Date32 -> Boolean (date values one above / two below the Eq case)
    test_coercion!(
        StringArray,
        DataType::Utf8,
        vec!["1994-12-13", "1995-01-26"],
        Date32Array,
        DataType::Date32,
        vec![9113, 9154],
        Operator::Lt,
        BooleanArray,
        DataType::Boolean,
        vec![true, false]
    );
    // Utf8 = Date64 -> Boolean (both rows expected equal)
    test_coercion!(
        StringArray,
        DataType::Utf8,
        vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
        Date64Array,
        DataType::Date64,
        vec![787322096000, 791083425000],
        Operator::Eq,
        BooleanArray,
        DataType::Boolean,
        vec![true, true]
    );
    // Utf8 < Date64 -> Boolean (date values one above / one below the Eq case)
    test_coercion!(
        StringArray,
        DataType::Utf8,
        vec!["1994-12-13T12:34:56", "1995-01-26T01:23:45"],
        Date64Array,
        DataType::Date64,
        vec![787322096001, 791083424999],
        Operator::Lt,
        BooleanArray,
        DataType::Boolean,
        vec![true, false]
    );
    Ok(())
}
805
// Note it would be nice to reuse the test_coercion macro from above, but
// the type of a dictionary's values is not encoded in the Rust type of the
// DictionaryArray, so at the time of this writing a dictionary array cannot
// be created through the `From` trait.
#[test]
fn test_dictionary_type_to_array_coersion() -> Result<()> {
    // string dictionary column: Int32 keys, Utf8 values
    let mut dict_builder = StringDictionaryBuilder::new(
        PrimitiveBuilder::<Int32Type>::new(10),
        arrow::array::StringBuilder::new(10),
    );
    dict_builder.append("one")?;
    dict_builder.append_null()?;
    dict_builder.append("three")?;
    dict_builder.append("four")?;
    let dict_array = dict_builder.finish();

    // plain string column to compare against
    let str_array =
        StringArray::from(vec![Some("not one"), Some("two"), None, Some("four")]);

    let schema = Arc::new(Schema::new(vec![
        Field::new(
            "dict",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        ),
        Field::new("str", DataType::Utf8, true),
    ]));

    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(dict_array), Arc::new(str_array)],
    )?;

    // one rendered value per row; rows with a null on either side render empty
    let expected = "false\n\n\ntrue";

    // the comparison is checked in both directions: `dict = str`, then `str = dict`
    for (lhs_name, rhs_name) in vec![("dict", "str"), ("str", "dict")] {
        // the expression must be constructible and report a boolean type
        let expression = binary(col(lhs_name), Operator::Eq, col(rhs_name), &schema)?;
        assert_eq!(expression.data_type(&schema)?, DataType::Boolean);

        // the evaluated array must be boolean as well
        let result = expression.evaluate(&batch)?.into_array(batch.num_rows());
        assert_eq!(result.data_type(), &DataType::Boolean);

        // and hold the expected values
        assert_eq!(expected, array_to_string(&result)?);
    }

    Ok(())
}
873
874 // Convert the array to a newline delimited string of pretty printed values
array_to_string(array: &ArrayRef) -> Result<String>875 fn array_to_string(array: &ArrayRef) -> Result<String> {
876 let s = (0..array.len())
877 .map(|i| array_value_to_string(array, i))
878 .collect::<std::result::Result<Vec<_>, arrow::error::ArrowError>>()?
879 .join("\n");
880 Ok(s)
881 }
882
#[test]
fn plus_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let b: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));

    // a + b, element-wise
    let expected = Int32Array::from(vec![2, 4, 7, 12, 21]);
    apply_arithmetic::<Int32Type>(schema, vec![a, b], Operator::Plus, expected)
}
901
#[test]
fn minus_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let larger: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
    let smaller: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

    // larger - smaller: all results are non-negative
    apply_arithmetic::<Int32Type>(
        Arc::clone(&schema),
        vec![Arc::clone(&larger), Arc::clone(&smaller)],
        Operator::Minus,
        Int32Array::from(vec![0, 0, 1, 4, 11]),
    )?;

    // smaller - larger: signed types must be able to produce negative results
    apply_arithmetic::<Int32Type>(
        schema,
        vec![smaller, larger],
        Operator::Minus,
        Int32Array::from(vec![0, 0, -1, -4, -11]),
    )
}
928
#[test]
fn multiply_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64]));
    let b: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));

    // a * b, element-wise
    let expected = Int32Array::from(vec![8, 32, 128, 512, 2048]);
    apply_arithmetic::<Int32Type>(schema, vec![a, b], Operator::Multiply, expected)
}
947
#[test]
fn divide_op() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let a: ArrayRef = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
    let b: ArrayRef = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));

    // a / b, element-wise
    let expected = Int32Array::from(vec![4, 8, 16, 32, 64]);
    apply_arithmetic::<Int32Type>(schema, vec![a, b], Operator::Divide, expected)
}
966
apply_arithmetic<T: ArrowNumericType>( schema: SchemaRef, data: Vec<ArrayRef>, op: Operator, expected: PrimitiveArray<T>, ) -> Result<()>967 fn apply_arithmetic<T: ArrowNumericType>(
968 schema: SchemaRef,
969 data: Vec<ArrayRef>,
970 op: Operator,
971 expected: PrimitiveArray<T>,
972 ) -> Result<()> {
973 let arithmetic_op = binary_simple(col("a"), op, col("b"));
974 let batch = RecordBatch::try_new(schema, data)?;
975 let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
976
977 assert_eq!(result.as_ref(), &expected);
978 Ok(())
979 }
980
// Types without a common coercion must be rejected with a planning error
// whose message names both operand types and the operator.
#[test]
fn test_coersion_error() -> Result<()> {
    let expr =
        common_binary_type(&DataType::Float32, &Operator::Plus, &DataType::Utf8);

    match expr {
        Err(DataFusionError::Plan(e)) => {
            assert_eq!(e, "'Float32 + Utf8' can't be evaluated because there isn't a common type to coerce the types to");
            Ok(())
        }
        // the failure message names the variant this test actually expects
        _ => Err(DataFusionError::Internal(
            "Coercion should have returned an DataFusionError::Plan".to_string(),
        )),
    }
}
995 }
996