// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions that can be evaluated at runtime during query execution

use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;

use crate::error::{ExecutionError, Result};
use crate::execution::physical_plan::common::get_scalar_value;
use crate::execution::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::logicalplan::{Operator, ScalarValue};
use arrow::array::{
    ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
    Int64Array, Int8Array, StringArray, TimestampNanosecondArray, UInt16Array,
    UInt32Array, UInt64Array, UInt8Array,
};
use arrow::array::{
    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder,
    Int8Builder, StringBuilder, UInt16Builder, UInt32Builder, UInt64Builder,
    UInt8Builder,
};
use arrow::compute;
use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract};
use arrow::compute::kernels::boolean::{and, or};
use arrow::compute::kernels::cast::cast;
use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
use arrow::compute::kernels::comparison::{
    eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8,
};
use arrow::datatypes::{DataType, Schema, TimeUnit};
use arrow::record_batch::RecordBatch;

/// Represents an aliased expression
pub struct Alias {
    expr: Arc<dyn PhysicalExpr>,
    alias: String,
}

impl Alias {
    /// Create a new aliased expression
    pub fn new(expr: Arc<dyn PhysicalExpr>, alias: &str) -> Self {
        Self {
            expr: expr.clone(),
            alias: alias.to_owned(),
        }
    }
}

impl PhysicalExpr for Alias {
    fn name(&self) -> String {
        self.alias.clone()
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.expr.data_type(input_schema)
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.expr.evaluate(batch)
    }
}

/// Represents the column at a given index in a RecordBatch
pub struct Column {
    index: usize,
    name: String,
}

impl Column {
    /// Create a new column expression
    pub fn new(index: usize, name: &str) -> Self {
        Self {
            index,
            name: name.to_owned(),
        }
    }
}

impl PhysicalExpr for Column {
    /// Get the name to use in a schema to represent the result of this expression
    fn name(&self) -> String {
        self.name.clone()
    }

    /// Get the data type of this expression, given the schema of the input
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        Ok(input_schema.field(self.index).data_type().clone())
    }

    /// Evaluate the expression
    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        Ok(batch.column(self.index).clone())
    }
}

/// Create a column expression
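///
/// A minimal usage sketch (illustrative only; `schema` is assumed to be the
/// schema of the batches this expression will be evaluated against):
///
/// ```ignore
/// use arrow::datatypes::{DataType, Field, Schema};
///
/// let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
/// // refer to column "a" by its index in the schema
/// let a = col(0, &schema);
/// assert_eq!(a.name(), "a");
/// ```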
pub fn col(i: usize, schema: &Schema) -> Arc<dyn PhysicalExpr> {
    Arc::new(Column::new(i, &schema.field(i).name()))
}

/// SUM aggregate expression
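///
/// A rough sketch of how an aggregate expression is typically driven
/// (illustrative only; `schema` and `batch` are assumed to come from the
/// surrounding query execution code):
///
/// ```ignore
/// // SUM over the first column of the input
/// let agg = Sum::new(col(0, &schema));
///
/// // one accumulator per aggregation bucket
/// let acc = agg.create_accumulator();
///
/// // feed it the evaluated input column of each batch ...
/// let array = agg.evaluate_input(&batch)?;
/// acc.borrow_mut().accumulate_batch(&array)?;
///
/// // ... then read the final scalar result
/// let total: Option<ScalarValue> = acc.borrow().get_value()?;
/// ```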
pub struct Sum {
    expr: Arc<dyn PhysicalExpr>,
}

impl Sum {
    /// Create a new SUM aggregate function
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self { expr }
    }
}

impl AggregateExpr for Sum {
    fn name(&self) -> String {
        "SUM".to_string()
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        match self.expr.data_type(input_schema)? {
            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
                Ok(DataType::Int64)
            }
            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
                Ok(DataType::UInt64)
            }
            DataType::Float32 => Ok(DataType::Float32),
            DataType::Float64 => Ok(DataType::Float64),
            other => Err(ExecutionError::General(format!(
                "SUM does not support {:?}",
                other
            ))),
        }
    }

    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.expr.evaluate(batch)
    }

    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
        Rc::new(RefCell::new(SumAccumulator { sum: None }))
    }

    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
        Arc::new(Sum::new(Arc::new(Column::new(column_index, &self.name()))))
    }
}

macro_rules! sum_accumulate {
    ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{
        $SELF.sum = match $SELF.sum {
            Some(ScalarValue::$SCALAR_VARIANT(n)) => {
                Some(ScalarValue::$SCALAR_VARIANT(n + $VALUE as $TY))
            }
            Some(_) => {
                return Err(ExecutionError::InternalError(
                    "Unexpected ScalarValue variant".to_string(),
                ))
            }
            None => Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY)),
        };
    }};
}

struct SumAccumulator {
    sum: Option<ScalarValue>,
}

impl Accumulator for SumAccumulator {
    fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
        if let Some(value) = value {
            match value {
                ScalarValue::Int8(value) => {
                    sum_accumulate!(self, value, Int8Array, Int64, i64);
                }
                ScalarValue::Int16(value) => {
                    sum_accumulate!(self, value, Int16Array, Int64, i64);
                }
                ScalarValue::Int32(value) => {
                    sum_accumulate!(self, value, Int32Array, Int64, i64);
                }
                ScalarValue::Int64(value) => {
                    sum_accumulate!(self, value, Int64Array, Int64, i64);
                }
                ScalarValue::UInt8(value) => {
                    sum_accumulate!(self, value, UInt8Array, UInt64, u64);
                }
                ScalarValue::UInt16(value) => {
                    sum_accumulate!(self, value, UInt16Array, UInt64, u64);
                }
                ScalarValue::UInt32(value) => {
                    sum_accumulate!(self, value, UInt32Array, UInt64, u64);
                }
                ScalarValue::UInt64(value) => {
                    sum_accumulate!(self, value, UInt64Array, UInt64, u64);
                }
                ScalarValue::Float32(value) => {
                    sum_accumulate!(self, value, Float32Array, Float32, f32);
                }
                ScalarValue::Float64(value) => {
                    sum_accumulate!(self, value, Float64Array, Float64, f64);
                }
                other => {
                    return Err(ExecutionError::General(format!(
                        "SUM does not support {:?}",
                        other
                    )))
                }
            }
        }
        Ok(())
    }

    fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
        let sum = match array.data_type() {
            DataType::UInt8 => {
                match compute::sum(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::UInt8(n))),
                    None => Ok(None),
                }
            }
            DataType::UInt16 => {
                match compute::sum(array.as_any().downcast_ref::<UInt16Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::UInt16(n))),
                    None => Ok(None),
                }
            }
            DataType::UInt32 => {
                match compute::sum(array.as_any().downcast_ref::<UInt32Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::UInt32(n))),
                    None => Ok(None),
                }
            }
            DataType::UInt64 => {
                match compute::sum(array.as_any().downcast_ref::<UInt64Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::UInt64(n))),
                    None => Ok(None),
                }
            }
            DataType::Int8 => {
                match compute::sum(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int8(n))),
                    None => Ok(None),
                }
            }
            DataType::Int16 => {
                match compute::sum(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int16(n))),
                    None => Ok(None),
                }
            }
            DataType::Int32 => {
                match compute::sum(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int32(n))),
                    None => Ok(None),
                }
            }
            DataType::Int64 => {
                match compute::sum(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int64(n))),
                    None => Ok(None),
                }
            }
            DataType::Float32 => {
                match compute::sum(array.as_any().downcast_ref::<Float32Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::Float32(n))),
                    None => Ok(None),
                }
            }
            DataType::Float64 => {
                match compute::sum(array.as_any().downcast_ref::<Float64Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::Float64(n))),
                    None => Ok(None),
                }
            }
            _ => Err(ExecutionError::ExecutionError(
                "Unsupported data type for SUM".to_string(),
            )),
        }?;
        self.accumulate_scalar(sum)
    }

    fn get_value(&self) -> Result<Option<ScalarValue>> {
        Ok(self.sum.clone())
    }
}

/// Create a sum expression
pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Sum::new(expr))
}

/// AVG aggregate expression
pub struct Avg {
    expr: Arc<dyn PhysicalExpr>,
}

impl Avg {
    /// Create a new AVG aggregate function
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self { expr }
    }
}

impl AggregateExpr for Avg {
    fn name(&self) -> String {
        "AVG".to_string()
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        match self.expr.data_type(input_schema)? {
            DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Float32
            | DataType::Float64 => Ok(DataType::Float64),
            other => Err(ExecutionError::General(format!(
                "AVG does not support {:?}",
                other
            ))),
        }
    }

    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.expr.evaluate(batch)
    }

    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
        Rc::new(RefCell::new(AvgAccumulator {
            sum: None,
            count: None,
        }))
    }

    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
        Arc::new(Avg::new(Arc::new(Column::new(column_index, &self.name()))))
    }
}

macro_rules! avg_accumulate {
    ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident) => {{
        match ($SELF.sum, $SELF.count) {
            (Some(sum), Some(count)) => {
                $SELF.sum = Some(sum + $VALUE as f64);
                $SELF.count = Some(count + 1);
            }
            _ => {
                $SELF.sum = Some($VALUE as f64);
                $SELF.count = Some(1);
            }
        };
    }};
}
struct AvgAccumulator {
    sum: Option<f64>,
    count: Option<i64>,
}

impl Accumulator for AvgAccumulator {
    fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
        if let Some(value) = value {
            match value {
                ScalarValue::Int8(value) => avg_accumulate!(self, value, Int8Array),
                ScalarValue::Int16(value) => avg_accumulate!(self, value, Int16Array),
                ScalarValue::Int32(value) => avg_accumulate!(self, value, Int32Array),
                ScalarValue::Int64(value) => avg_accumulate!(self, value, Int64Array),
                ScalarValue::UInt8(value) => avg_accumulate!(self, value, UInt8Array),
                ScalarValue::UInt16(value) => avg_accumulate!(self, value, UInt16Array),
                ScalarValue::UInt32(value) => avg_accumulate!(self, value, UInt32Array),
                ScalarValue::UInt64(value) => avg_accumulate!(self, value, UInt64Array),
                ScalarValue::Float32(value) => avg_accumulate!(self, value, Float32Array),
                ScalarValue::Float64(value) => avg_accumulate!(self, value, Float64Array),
                other => {
                    return Err(ExecutionError::General(format!(
                        "AVG does not support {:?}",
                        other
                    )))
                }
            }
        }
        Ok(())
    }

    fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
        for row in 0..array.len() {
            self.accumulate_scalar(get_scalar_value(array, row)?)?;
        }
        Ok(())
    }

    fn get_value(&self) -> Result<Option<ScalarValue>> {
        match (self.sum, self.count) {
            (Some(sum), Some(count)) => {
                Ok(Some(ScalarValue::Float64(sum / count as f64)))
            }
            _ => Ok(None),
        }
    }
}

/// Create an avg expression
pub fn avg(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Avg::new(expr))
}

/// MAX aggregate expression
pub struct Max {
    expr: Arc<dyn PhysicalExpr>,
}

impl Max {
    /// Create a new MAX aggregate function
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self { expr }
    }
}

impl AggregateExpr for Max {
    fn name(&self) -> String {
        "MAX".to_string()
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        match self.expr.data_type(input_schema)? {
            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
                Ok(DataType::Int64)
            }
            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
                Ok(DataType::UInt64)
            }
            DataType::Float32 => Ok(DataType::Float32),
            DataType::Float64 => Ok(DataType::Float64),
            other => Err(ExecutionError::General(format!(
                "MAX does not support {:?}",
                other
            ))),
        }
    }

    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.expr.evaluate(batch)
    }

    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
        Rc::new(RefCell::new(MaxAccumulator { max: None }))
    }

    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
        Arc::new(Max::new(Arc::new(Column::new(column_index, &self.name()))))
    }
}

macro_rules! max_accumulate {
    ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{
        $SELF.max = match $SELF.max {
            Some(ScalarValue::$SCALAR_VARIANT(n)) => {
                if n > ($VALUE as $TY) {
                    Some(ScalarValue::$SCALAR_VARIANT(n))
                } else {
                    Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY))
                }
            }
            Some(_) => {
                return Err(ExecutionError::InternalError(
                    "Unexpected ScalarValue variant".to_string(),
                ))
            }
            None => Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY)),
        };
    }};
}
struct MaxAccumulator {
    max: Option<ScalarValue>,
}

impl Accumulator for MaxAccumulator {
    fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
        if let Some(value) = value {
            match value {
                ScalarValue::Int8(value) => {
                    max_accumulate!(self, value, Int8Array, Int64, i64);
                }
                ScalarValue::Int16(value) => {
                    max_accumulate!(self, value, Int16Array, Int64, i64)
                }
                ScalarValue::Int32(value) => {
                    max_accumulate!(self, value, Int32Array, Int64, i64)
                }
                ScalarValue::Int64(value) => {
                    max_accumulate!(self, value, Int64Array, Int64, i64)
                }
                ScalarValue::UInt8(value) => {
                    max_accumulate!(self, value, UInt8Array, UInt64, u64)
                }
                ScalarValue::UInt16(value) => {
                    max_accumulate!(self, value, UInt16Array, UInt64, u64)
                }
                ScalarValue::UInt32(value) => {
                    max_accumulate!(self, value, UInt32Array, UInt64, u64)
                }
                ScalarValue::UInt64(value) => {
                    max_accumulate!(self, value, UInt64Array, UInt64, u64)
                }
                ScalarValue::Float32(value) => {
                    max_accumulate!(self, value, Float32Array, Float32, f32)
                }
                ScalarValue::Float64(value) => {
                    max_accumulate!(self, value, Float64Array, Float64, f64)
                }
                other => {
                    return Err(ExecutionError::General(format!(
                        "MAX does not support {:?}",
                        other
                    )))
                }
            }
        }
        Ok(())
    }

    fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
        let max = match array.data_type() {
            DataType::UInt8 => {
                match compute::max(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::UInt8(n))),
                    None => Ok(None),
                }
            }
            DataType::UInt16 => {
                match compute::max(array.as_any().downcast_ref::<UInt16Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::UInt16(n))),
                    None => Ok(None),
                }
            }
            DataType::UInt32 => {
                match compute::max(array.as_any().downcast_ref::<UInt32Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::UInt32(n))),
                    None => Ok(None),
                }
            }
            DataType::UInt64 => {
                match compute::max(array.as_any().downcast_ref::<UInt64Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::UInt64(n))),
                    None => Ok(None),
                }
            }
            DataType::Int8 => {
                match compute::max(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int8(n))),
                    None => Ok(None),
                }
            }
            DataType::Int16 => {
                match compute::max(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int16(n))),
                    None => Ok(None),
                }
            }
            DataType::Int32 => {
                match compute::max(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int32(n))),
                    None => Ok(None),
                }
            }
            DataType::Int64 => {
                match compute::max(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int64(n))),
                    None => Ok(None),
                }
            }
            DataType::Float32 => {
                match compute::max(array.as_any().downcast_ref::<Float32Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::Float32(n))),
                    None => Ok(None),
                }
            }
            DataType::Float64 => {
                match compute::max(array.as_any().downcast_ref::<Float64Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::Float64(n))),
                    None => Ok(None),
                }
            }
            _ => Err(ExecutionError::ExecutionError(
                "Unsupported data type for MAX".to_string(),
            )),
        }?;
        self.accumulate_scalar(max)
    }

    fn get_value(&self) -> Result<Option<ScalarValue>> {
        Ok(self.max.clone())
    }
}

/// Create a max expression
pub fn max(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Max::new(expr))
}

/// MIN aggregate expression
pub struct Min {
    expr: Arc<dyn PhysicalExpr>,
}

impl Min {
    /// Create a new MIN aggregate function
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self { expr }
    }
}

impl AggregateExpr for Min {
    fn name(&self) -> String {
        "MIN".to_string()
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        match self.expr.data_type(input_schema)? {
            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
                Ok(DataType::Int64)
            }
            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
                Ok(DataType::UInt64)
            }
            DataType::Float32 => Ok(DataType::Float32),
            DataType::Float64 => Ok(DataType::Float64),
            other => Err(ExecutionError::General(format!(
                "MIN does not support {:?}",
                other
            ))),
        }
    }

    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.expr.evaluate(batch)
    }

    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
        Rc::new(RefCell::new(MinAccumulator { min: None }))
    }

    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
        Arc::new(Min::new(Arc::new(Column::new(column_index, &self.name()))))
    }
}

macro_rules! min_accumulate {
    ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{
        $SELF.min = match $SELF.min {
            Some(ScalarValue::$SCALAR_VARIANT(n)) => {
                if n < ($VALUE as $TY) {
                    Some(ScalarValue::$SCALAR_VARIANT(n))
                } else {
                    Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY))
                }
            }
            Some(_) => {
                return Err(ExecutionError::InternalError(
                    "Unexpected ScalarValue variant".to_string(),
                ))
            }
            None => Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY)),
        };
    }};
}
struct MinAccumulator {
    min: Option<ScalarValue>,
}

impl Accumulator for MinAccumulator {
    fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
        if let Some(value) = value {
            match value {
                ScalarValue::Int8(value) => {
                    min_accumulate!(self, value, Int8Array, Int64, i64);
                }
                ScalarValue::Int16(value) => {
                    min_accumulate!(self, value, Int16Array, Int64, i64)
                }
                ScalarValue::Int32(value) => {
                    min_accumulate!(self, value, Int32Array, Int64, i64)
                }
                ScalarValue::Int64(value) => {
                    min_accumulate!(self, value, Int64Array, Int64, i64)
                }
                ScalarValue::UInt8(value) => {
                    min_accumulate!(self, value, UInt8Array, UInt64, u64)
                }
                ScalarValue::UInt16(value) => {
                    min_accumulate!(self, value, UInt16Array, UInt64, u64)
                }
                ScalarValue::UInt32(value) => {
                    min_accumulate!(self, value, UInt32Array, UInt64, u64)
                }
                ScalarValue::UInt64(value) => {
                    min_accumulate!(self, value, UInt64Array, UInt64, u64)
                }
                ScalarValue::Float32(value) => {
                    min_accumulate!(self, value, Float32Array, Float32, f32)
                }
                ScalarValue::Float64(value) => {
                    min_accumulate!(self, value, Float64Array, Float64, f64)
                }
                other => {
                    return Err(ExecutionError::General(format!(
                        "MIN does not support {:?}",
                        other
                    )))
                }
            }
        }
        Ok(())
    }

    fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
        let min = match array.data_type() {
            DataType::UInt8 => {
                match compute::min(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::UInt8(n))),
                    None => Ok(None),
                }
            }
            DataType::UInt16 => {
                match compute::min(array.as_any().downcast_ref::<UInt16Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::UInt16(n))),
                    None => Ok(None),
                }
            }
            DataType::UInt32 => {
                match compute::min(array.as_any().downcast_ref::<UInt32Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::UInt32(n))),
                    None => Ok(None),
                }
            }
            DataType::UInt64 => {
                match compute::min(array.as_any().downcast_ref::<UInt64Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::UInt64(n))),
                    None => Ok(None),
                }
            }
            DataType::Int8 => {
                match compute::min(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int8(n))),
                    None => Ok(None),
                }
            }
            DataType::Int16 => {
                match compute::min(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int16(n))),
                    None => Ok(None),
                }
            }
            DataType::Int32 => {
                match compute::min(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int32(n))),
                    None => Ok(None),
                }
            }
            DataType::Int64 => {
                match compute::min(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
                    Some(n) => Ok(Some(ScalarValue::Int64(n))),
                    None => Ok(None),
                }
            }
            DataType::Float32 => {
                match compute::min(array.as_any().downcast_ref::<Float32Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::Float32(n))),
                    None => Ok(None),
                }
            }
            DataType::Float64 => {
                match compute::min(array.as_any().downcast_ref::<Float64Array>().unwrap())
                {
                    Some(n) => Ok(Some(ScalarValue::Float64(n))),
                    None => Ok(None),
                }
            }
            _ => Err(ExecutionError::ExecutionError(
                "Unsupported data type for MIN".to_string(),
            )),
        }?;
        self.accumulate_scalar(min)
    }

    fn get_value(&self) -> Result<Option<ScalarValue>> {
        Ok(self.min.clone())
    }
}

/// Create a min expression
pub fn min(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Min::new(expr))
}

/// COUNT aggregate expression
/// Returns the number of non-null values of the given expression.
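///
/// A sketch of the null handling (illustrative only; `expr` is assumed to be
/// the counted physical expression, e.g. `col(0, &schema)`):
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow::array::{ArrayRef, Int32Array};
///
/// // two non-null values and one null
/// let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
///
/// let acc = Count::new(expr).create_accumulator();
/// acc.borrow_mut().accumulate_batch(&array)?;
///
/// // nulls are not counted, so this yields Some(ScalarValue::UInt64(2))
/// let value = acc.borrow().get_value()?;
/// ```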
pub struct Count {
    expr: Arc<dyn PhysicalExpr>,
}

impl Count {
    /// Create a new COUNT aggregate function.
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self { expr }
    }
}

impl AggregateExpr for Count {
    fn name(&self) -> String {
        "COUNT".to_string()
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::UInt64)
    }

    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.expr.evaluate(batch)
    }

    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
        Rc::new(RefCell::new(CountAccumulator { count: 0 }))
    }

    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
        Arc::new(Sum::new(Arc::new(Column::new(column_index, &self.name()))))
    }
}

struct CountAccumulator {
    count: u64,
}

impl Accumulator for CountAccumulator {
    fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
        if value.is_some() {
            self.count += 1;
        }
        Ok(())
    }

    fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
        self.count += array.len() as u64 - array.null_count() as u64;
        Ok(())
    }

    fn get_value(&self) -> Result<Option<ScalarValue>> {
        Ok(Some(ScalarValue::UInt64(self.count)))
    }
}

/// Create a count expression
pub fn count(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Count::new(expr))
}

/// Invoke a compute kernel on a pair of binary data arrays
macro_rules! compute_utf8_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        let rr = $RIGHT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
    }};
}

/// Invoke a compute kernel on a pair of arrays
macro_rules! compute_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        let rr = $RIGHT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        Ok(Arc::new($OP(&ll, &rr)?))
    }};
}

macro_rules! binary_string_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
            other => Err(ExecutionError::General(format!(
                "Unsupported data type {:?}",
                other
            ))),
        }
    }};
}

/// Invoke a compute kernel on a pair of arrays
/// The binary_primitive_array_op macro only evaluates for primitive types
/// like integers and floats.
macro_rules! binary_primitive_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
            other => Err(ExecutionError::General(format!(
                "Unsupported data type {:?}",
                other
            ))),
        }
    }};
}

/// The binary_array_op macro includes types that extend beyond the primitive,
/// such as Utf8 strings.
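///
/// A rough sketch of how an invocation expands (illustrative only; `left` and
/// `right` are assumed to be `ArrayRef` values of the same data type):
///
/// ```ignore
/// // binary_array_op!(left, right, lt) dispatches on the input type; for
/// // Int32 it downcasts both sides to Int32Array and calls the `lt` kernel:
/// let ll = left.as_any().downcast_ref::<Int32Array>().unwrap();
/// let rr = right.as_any().downcast_ref::<Int32Array>().unwrap();
/// let result: ArrayRef = Arc::new(lt(&ll, &rr)?);
/// ```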
macro_rules! binary_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
            DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
            }
            other => Err(ExecutionError::General(format!(
                "Unsupported data type {:?}",
                other
            ))),
        }
    }};
}

/// Invoke a boolean kernel on a pair of arrays
macro_rules! boolean_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean_op failed to downcast array");
        let rr = $RIGHT
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean_op failed to downcast array");
        Ok(Arc::new($OP(&ll, &rr)?))
    }};
}
/// Binary expression
pub struct BinaryExpr {
    left: Arc<dyn PhysicalExpr>,
    op: Operator,
    right: Arc<dyn PhysicalExpr>,
}

impl BinaryExpr {
    /// Create new binary expression
    pub fn new(
        left: Arc<dyn PhysicalExpr>,
        op: Operator,
        right: Arc<dyn PhysicalExpr>,
    ) -> Self {
        Self { left, op, right }
    }
}

impl PhysicalExpr for BinaryExpr {
    fn name(&self) -> String {
        format!("{:?}", self.op)
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.left.data_type(input_schema)
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        let left = self.left.evaluate(batch)?;
        let right = self.right.evaluate(batch)?;
        if left.data_type() != right.data_type() {
            return Err(ExecutionError::General(format!(
                "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
                self.op,
                left.data_type(),
                right.data_type()
            )));
        }
        match &self.op {
            Operator::Like => binary_string_array_op!(left, right, like),
            Operator::NotLike => binary_string_array_op!(left, right, nlike),
            Operator::Lt => binary_array_op!(left, right, lt),
            Operator::LtEq => binary_array_op!(left, right, lt_eq),
            Operator::Gt => binary_array_op!(left, right, gt),
            Operator::GtEq => binary_array_op!(left, right, gt_eq),
            Operator::Eq => binary_array_op!(left, right, eq),
            Operator::NotEq => binary_array_op!(left, right, neq),
            Operator::Plus => binary_primitive_array_op!(left, right, add),
            Operator::Minus => binary_primitive_array_op!(left, right, subtract),
            Operator::Multiply => binary_primitive_array_op!(left, right, multiply),
            Operator::Divide => binary_primitive_array_op!(left, right, divide),
            Operator::And => {
                if left.data_type() == &DataType::Boolean {
                    boolean_op!(left, right, and)
                } else {
                    return Err(ExecutionError::General(format!(
                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
                        self.op,
                        left.data_type(),
                        right.data_type()
                    )));
                }
            }
            Operator::Or => {
                if left.data_type() == &DataType::Boolean {
                    boolean_op!(left, right, or)
                } else {
                    return Err(ExecutionError::General(format!(
                        "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
                        self.op,
                        left.data_type(),
                        right.data_type()
                    )));
                }
            }
            _ => Err(ExecutionError::General("Unsupported operator".to_string())),
        }
    }
}

/// Create a binary expression
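///
/// A minimal usage sketch, following the tests below (illustrative only;
/// `schema` and `batch` are assumed to have two columns of the same numeric type):
///
/// ```ignore
/// // expression: "a < b"
/// let lt = binary(col(0, &schema), Operator::Lt, col(1, &schema));
/// let result = lt.evaluate(&batch)?; // BooleanArray with one value per row
/// ```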
pub fn binary(
    l: Arc<dyn PhysicalExpr>,
    op: Operator,
    r: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    Arc::new(BinaryExpr::new(l, op, r))
}

/// Not expression
pub struct NotExpr {
    arg: Arc<dyn PhysicalExpr>,
}

impl NotExpr {
    /// Create new not expression
    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
        Self { arg }
    }
}

impl PhysicalExpr for NotExpr {
    fn name(&self) -> String {
        "NOT".to_string()
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        let arg = self.arg.evaluate(batch)?;
        if arg.data_type() != &DataType::Boolean {
            return Err(ExecutionError::General(format!(
                "Cannot evaluate \"not\" expression with type {:?}",
                arg.data_type(),
            )));
        }
        let arg = arg
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean_op failed to downcast array");
        Ok(Arc::new(arrow::compute::kernels::boolean::not(arg)?))
    }
}

/// Create a unary expression
pub fn not(arg: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
    Arc::new(NotExpr::new(arg))
}

/// CAST expression casts an expression to a specific data type
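///
/// A minimal usage sketch, following the tests below (illustrative only;
/// `schema` and `batch` are assumed to contain an Int32 column at index 0):
///
/// ```ignore
/// let cast = CastExpr::try_new(col(0, &schema), &schema, DataType::UInt32)?;
/// let result = cast.evaluate(&batch)?; // UInt32Array
/// ```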
pub struct CastExpr {
    /// The expression to cast
    expr: Arc<dyn PhysicalExpr>,
    /// The data type to cast to
    cast_type: DataType,
}

/// Determine if a DataType is numeric or not
fn is_numeric(dt: &DataType) -> bool {
    match dt {
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true,
        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true,
        DataType::Float16 | DataType::Float32 | DataType::Float64 => true,
        _ => false,
    }
}

impl CastExpr {
    /// Create a CAST expression
    pub fn try_new(
        expr: Arc<dyn PhysicalExpr>,
        input_schema: &Schema,
        cast_type: DataType,
    ) -> Result<Self> {
        let expr_type = expr.data_type(input_schema)?;
        // numbers can be cast to numbers and strings
        if is_numeric(&expr_type)
            && (is_numeric(&cast_type) || cast_type == DataType::Utf8)
        {
            Ok(Self { expr, cast_type })
        } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 {
            Ok(Self { expr, cast_type })
        } else if is_numeric(&expr_type)
            && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None)
        {
            Ok(Self { expr, cast_type })
        } else {
            Err(ExecutionError::General(format!(
                "Invalid CAST from {:?} to {:?}",
                expr_type, cast_type
            )))
        }
    }
}

impl PhysicalExpr for CastExpr {
    fn name(&self) -> String {
        "CAST".to_string()
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(self.cast_type.clone())
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        let value = self.expr.evaluate(batch)?;
        Ok(cast(&value, &self.cast_type)?)
    }
}

/// Represents a non-null literal value
pub struct Literal {
    value: ScalarValue,
}

impl Literal {
    /// Create a literal value expression
    pub fn new(value: ScalarValue) -> Self {
        Self { value }
    }
}

/// Build array containing the same literal value repeated. This is necessary because the Arrow
/// memory model does not have the concept of a scalar value currently.
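///
/// A sketch of what an invocation produces (illustrative only; `batch` is
/// assumed to be the `RecordBatch` currently being evaluated):
///
/// ```ignore
/// // Inside a function returning Result<ArrayRef>: builds an Int32Array with
/// // `batch.num_rows()` copies of the value 42.
/// build_literal_array!(batch, Int32Builder, 42)
/// ```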
macro_rules! build_literal_array {
    ($BATCH:ident, $BUILDER:ident, $VALUE:expr) => {{
        let mut builder = $BUILDER::new($BATCH.num_rows());
        for _ in 0..$BATCH.num_rows() {
            builder.append_value($VALUE)?;
        }
        Ok(Arc::new(builder.finish()))
    }};
}

impl PhysicalExpr for Literal {
    fn name(&self) -> String {
        "lit".to_string()
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(self.value.get_datatype())
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        match &self.value {
            ScalarValue::Int8(value) => build_literal_array!(batch, Int8Builder, *value),
            ScalarValue::Int16(value) => {
                build_literal_array!(batch, Int16Builder, *value)
            }
            ScalarValue::Int32(value) => {
                build_literal_array!(batch, Int32Builder, *value)
            }
            ScalarValue::Int64(value) => {
                build_literal_array!(batch, Int64Builder, *value)
            }
            ScalarValue::UInt8(value) => {
                build_literal_array!(batch, UInt8Builder, *value)
            }
            ScalarValue::UInt16(value) => {
                build_literal_array!(batch, UInt16Builder, *value)
            }
            ScalarValue::UInt32(value) => {
                build_literal_array!(batch, UInt32Builder, *value)
            }
            ScalarValue::UInt64(value) => {
                build_literal_array!(batch, UInt64Builder, *value)
            }
            ScalarValue::Float32(value) => {
                build_literal_array!(batch, Float32Builder, *value)
            }
            ScalarValue::Float64(value) => {
                build_literal_array!(batch, Float64Builder, *value)
            }
            ScalarValue::Utf8(value) => build_literal_array!(batch, StringBuilder, value),
            other => Err(ExecutionError::General(format!(
                "Unsupported literal type {:?}",
                other
            ))),
        }
    }
}

/// Create a literal expression
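///
/// A minimal usage sketch, following the `literal_i32` test below (illustrative
/// only; `batch` is assumed to be the batch the expression is evaluated against):
///
/// ```ignore
/// let literal_expr = lit(ScalarValue::Int32(42));
/// // evaluating against a batch yields an Int32Array with one `42` per row
/// let literal_array = literal_expr.evaluate(&batch)?;
/// ```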
pub fn lit(value: ScalarValue) -> Arc<dyn PhysicalExpr> {
    Arc::new(Literal::new(value))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;
    use crate::execution::physical_plan::common::get_scalar_value;
    use arrow::array::{PrimitiveArray, StringArray, Time64NanosecondArray};
    use arrow::datatypes::*;

    #[test]
    fn binary_comparison() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(a), Arc::new(b)],
        )?;

        // expression: "a < b"
        let lt = binary(col(0, &schema), Operator::Lt, col(1, &schema));
        let result = lt.evaluate(&batch)?;
        assert_eq!(result.len(), 5);

        let expected = vec![false, false, true, true, true];
        let result = result
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("failed to downcast to BooleanArray");
        for i in 0..5 {
            assert_eq!(result.value(i), expected[i]);
        }

        Ok(())
    }

    #[test]
    fn binary_nested() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
        let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(a), Arc::new(b)],
        )?;

        // expression: "a < b OR a == b"
        let expr = binary(
            binary(col(0, &schema), Operator::Lt, col(1, &schema)),
            Operator::Or,
            binary(col(0, &schema), Operator::Eq, col(1, &schema)),
        );
        let result = expr.evaluate(&batch)?;
        assert_eq!(result.len(), 5);

        let expected = vec![true, true, false, true, false];
        let result = result
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("failed to downcast to BooleanArray");
        for i in 0..5 {
            print!("{}", i);
            assert_eq!(result.value(i), expected[i]);
        }

        Ok(())
    }

    #[test]
    fn literal_i32() -> Result<()> {
        // create an arbitrary record batch
1343         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1344         let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
1345         let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1346 
1347         // create and evaluate a literal expression
1348         let literal_expr = lit(ScalarValue::Int32(42));
1349         let literal_array = literal_expr.evaluate(&batch)?;
1350         let literal_array = literal_array.as_any().downcast_ref::<Int32Array>().unwrap();
1351 
1352         // note that the contents of the literal array are unrelated to the batch contents except for the length of the array
1353         assert_eq!(literal_array.len(), 5); // 5 rows in the batch
1354         for i in 0..literal_array.len() {
1355             assert_eq!(literal_array.value(i), 42);
1356         }
1357 
1358         Ok(())
1359     }
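
    // A small additional check (a sketch, not part of the original suite)
    // illustrating the comment above: the literal array's length always tracks
    // the row count of the batch it is evaluated against.
    #[test]
    fn literal_i32_matches_batch_length() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a = Int32Array::from(vec![7, 8, 9]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        let literal_array = lit(ScalarValue::Int32(42)).evaluate(&batch)?;
        assert_eq!(literal_array.len(), 3); // 3 rows in this batch

        Ok(())
    }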

    #[test]
    fn cast_i32_to_u32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        let cast = CastExpr::try_new(col(0, &schema), &schema, DataType::UInt32)?;
        let result = cast.evaluate(&batch)?;
        assert_eq!(result.len(), 5);

        let result = result
            .as_any()
            .downcast_ref::<UInt32Array>()
            .expect("failed to downcast to UInt32Array");
        assert_eq!(result.value(0), 1_u32);

        Ok(())
    }

    #[test]
    fn cast_i32_to_utf8() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        let cast = CastExpr::try_new(col(0, &schema), &schema, DataType::Utf8)?;
        let result = cast.evaluate(&batch)?;
        assert_eq!(result.len(), 5);

        let result = result
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("failed to downcast to StringArray");
        assert_eq!(result.value(0), "1");

        Ok(())
    }
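
    // A sketch (not in the original suite) applying the same pattern to a
    // numeric-to-numeric widening cast, assuming CastExpr::try_new accepts
    // Int32 -> Float64 like the numeric casts above.
    #[test]
    fn cast_i32_to_f64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        let cast = CastExpr::try_new(col(0, &schema), &schema, DataType::Float64)?;
        let result = cast.evaluate(&batch)?;
        assert_eq!(result.len(), 5);

        let result = result
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("failed to downcast to Float64Array");
        assert_eq!(result.value(0), 1.0_f64);

        Ok(())
    }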

    #[test]
    fn cast_i64_to_timestamp_nanoseconds() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        let cast = CastExpr::try_new(
            col(0, &schema),
            &schema,
            DataType::Timestamp(TimeUnit::Nanosecond, None),
        )?;
        let result = cast.evaluate(&batch)?;
        assert_eq!(result.len(), 5);
        // expected values; only the first element is compared below
        let expected_result = Time64NanosecondArray::from(vec![1, 2, 3, 4, 5]);
        let result = result
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .expect("failed to downcast to TimestampNanosecondArray");
        assert_eq!(result.value(0), expected_result.value(0));

        Ok(())
    }

    #[test]
    fn invalid_cast() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
        match CastExpr::try_new(col(0, &schema), &schema, DataType::Int32) {
            Err(ExecutionError::General(ref str)) => {
                assert_eq!(str, "Invalid CAST from Utf8 to Int32");
                Ok(())
            }
            _ => panic!(),
        }
    }

    #[test]
    fn sum_contract() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let sum = sum(col(0, &schema));
        assert_eq!("SUM".to_string(), sum.name());
        assert_eq!(DataType::Int64, sum.data_type(&schema)?);

        let combiner = sum.create_reducer(0);
        assert_eq!("SUM".to_string(), combiner.name());
        assert_eq!(DataType::Int64, combiner.data_type(&schema)?);

        Ok(())
    }

    #[test]
    fn max_contract() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let max = max(col(0, &schema));
        assert_eq!("MAX".to_string(), max.name());
        assert_eq!(DataType::Int64, max.data_type(&schema)?);

        let combiner = max.create_reducer(0);
        assert_eq!("MAX".to_string(), combiner.name());
        assert_eq!(DataType::Int64, combiner.data_type(&schema)?);

        Ok(())
    }

    #[test]
    fn min_contract() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let min = min(col(0, &schema));
        assert_eq!("MIN".to_string(), min.name());
        assert_eq!(DataType::Int64, min.data_type(&schema)?);

        let combiner = min.create_reducer(0);
        assert_eq!("MIN".to_string(), combiner.name());
        assert_eq!(DataType::Int64, combiner.data_type(&schema)?);

        Ok(())
    }

    #[test]
    fn avg_contract() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let avg = avg(col(0, &schema));
        assert_eq!("AVG".to_string(), avg.name());
        assert_eq!(DataType::Float64, avg.data_type(&schema)?);

        let combiner = avg.create_reducer(0);
        assert_eq!("AVG".to_string(), combiner.name());
        assert_eq!(DataType::Float64, combiner.data_type(&schema)?);

        Ok(())
    }

    #[test]
    fn sum_i32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_sum(&batch)?, Some(ScalarValue::Int64(15)));

        Ok(())
    }

    #[test]
    fn avg_i32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64)));

        Ok(())
    }

    #[test]
    fn max_i32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_max(&batch)?, Some(ScalarValue::Int64(5)));

        Ok(())
    }

    #[test]
    fn min_i32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_min(&batch)?, Some(ScalarValue::Int64(1)));

        Ok(())
    }

    #[test]
    fn sum_i32_with_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_sum(&batch)?, Some(ScalarValue::Int64(13)));

        Ok(())
    }

    #[test]
    fn avg_i32_with_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3.25)));

        Ok(())
    }

    #[test]
    fn max_i32_with_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_max(&batch)?, Some(ScalarValue::Int64(5)));

        Ok(())
    }

    #[test]
    fn min_i32_with_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_min(&batch)?, Some(ScalarValue::Int64(1)));

        Ok(())
    }

    #[test]
    fn sum_i32_all_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![None, None]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_sum(&batch)?, None);

        Ok(())
    }

    #[test]
    fn max_i32_all_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![None, None]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_max(&batch)?, None);

        Ok(())
    }

    #[test]
    fn min_i32_all_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![None, None]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_min(&batch)?, None);

        Ok(())
    }

    #[test]
    fn avg_i32_all_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![None, None]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_avg(&batch)?, None);

        Ok(())
    }

    #[test]
    fn sum_u32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);

        let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_sum(&batch)?, Some(ScalarValue::UInt64(15_u64)));

        Ok(())
    }

    #[test]
    fn avg_u32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);

        let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64)));

        Ok(())
    }

    #[test]
    fn max_u32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);

        let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_max(&batch)?, Some(ScalarValue::UInt64(5_u64)));

        Ok(())
    }

    #[test]
    fn min_u32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);

        let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_min(&batch)?, Some(ScalarValue::UInt64(1_u64)));

        Ok(())
    }

    #[test]
    fn sum_f32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);

        let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_sum(&batch)?, Some(ScalarValue::Float32(15_f32)));

        Ok(())
    }

    #[test]
    fn avg_f32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);

        let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64)));

        Ok(())
    }

    #[test]
    fn max_f32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);

        let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_max(&batch)?, Some(ScalarValue::Float32(5_f32)));

        Ok(())
    }

    #[test]
    fn min_f32() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);

        let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_min(&batch)?, Some(ScalarValue::Float32(1_f32)));

        Ok(())
    }

    #[test]
    fn sum_f64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_sum(&batch)?, Some(ScalarValue::Float64(15_f64)));

        Ok(())
    }

    #[test]
    fn avg_f64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64)));

        Ok(())
    }

    #[test]
    fn max_f64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_max(&batch)?, Some(ScalarValue::Float64(5_f64)));

        Ok(())
    }

    #[test]
    fn min_f64() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);

        let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        assert_eq!(do_min(&batch)?, Some(ScalarValue::Float64(1_f64)));

        Ok(())
    }

    #[test]
    fn count_elements() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(5)));
        Ok(())
    }

    #[test]
    fn count_with_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), None]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(3)));
        Ok(())
    }

    #[test]
    fn count_all_nulls() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
        let a = BooleanArray::from(vec![None, None, None, None, None, None, None, None]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0)));
        Ok(())
    }

    #[test]
    fn count_empty() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
        let a = BooleanArray::from(Vec::<bool>::new());
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
        assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0)));
        Ok(())
    }

    fn do_sum(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
        let sum = sum(col(0, &batch.schema()));
        let accum = sum.create_accumulator();
        let input = sum.evaluate_input(batch)?;
        let mut accum = accum.borrow_mut();
        for i in 0..batch.num_rows() {
            accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
        }
        accum.get_value()
    }

    fn do_max(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
        let max = max(col(0, &batch.schema()));
        let accum = max.create_accumulator();
        let input = max.evaluate_input(batch)?;
        let mut accum = accum.borrow_mut();
        for i in 0..batch.num_rows() {
            accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
        }
        accum.get_value()
    }

    fn do_min(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
        let min = min(col(0, &batch.schema()));
        let accum = min.create_accumulator();
        let input = min.evaluate_input(batch)?;
        let mut accum = accum.borrow_mut();
        for i in 0..batch.num_rows() {
            accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
        }
        accum.get_value()
    }

    fn do_count(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
        let count = count(col(0, &batch.schema()));
        let accum = count.create_accumulator();
        let input = count.evaluate_input(batch)?;
        let mut accum = accum.borrow_mut();
        for i in 0..batch.num_rows() {
            accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
        }
        accum.get_value()
    }

    fn do_avg(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
        let avg = avg(col(0, &batch.schema()));
        let accum = avg.create_accumulator();
        let input = avg.evaluate_input(batch)?;
        let mut accum = accum.borrow_mut();
        for i in 0..batch.num_rows() {
            accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
        }
        accum.get_value()
    }
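
    // The five helpers above differ only in which aggregate expression they
    // build. As a sketch (not part of the original code), the shared accumulator
    // loop could be factored out; `do_aggregate` is a hypothetical name and
    // assumes the sum/min/max/avg/count constructors all return an
    // `Arc<dyn AggregateExpr>` with the methods used above.
    #[allow(dead_code)]
    fn do_aggregate(
        agg: Arc<dyn AggregateExpr>,
        batch: &RecordBatch,
    ) -> Result<Option<ScalarValue>> {
        let accum = agg.create_accumulator();
        let input = agg.evaluate_input(batch)?;
        let mut accum = accum.borrow_mut();
        for i in 0..batch.num_rows() {
            accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
        }
        accum.get_value()
    }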

    #[test]
    fn plus_op() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![1, 2, 4, 8, 16]);

        apply_arithmetic::<Int32Type>(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Plus,
            Int32Array::from(vec![2, 4, 7, 12, 21]),
        )?;

        Ok(())
    }

    #[test]
    fn minus_op() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let a = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
        let b = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

        apply_arithmetic::<Int32Type>(
            schema.clone(),
            vec![a.clone(), b.clone()],
            Operator::Minus,
            Int32Array::from(vec![0, 0, 1, 4, 11]),
        )?;

        // should handle negative values in the result (for signed types)
        apply_arithmetic::<Int32Type>(
            schema.clone(),
            vec![b.clone(), a.clone()],
            Operator::Minus,
            Int32Array::from(vec![0, 0, -1, -4, -11]),
        )?;

        Ok(())
    }

    #[test]
    fn multiply_op() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let a = Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64]));
        let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));

        apply_arithmetic::<Int32Type>(
            schema,
            vec![a, b],
            Operator::Multiply,
            Int32Array::from(vec![8, 32, 128, 512, 2048]),
        )?;

        Ok(())
    }

    #[test]
    fn divide_op() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
        let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));

        apply_arithmetic::<Int32Type>(
            schema,
            vec![a, b],
            Operator::Divide,
            Int32Array::from(vec![4, 8, 16, 32, 64]),
        )?;

        Ok(())
    }

    fn apply_arithmetic<T: ArrowNumericType>(
        schema: Arc<Schema>,
        data: Vec<ArrayRef>,
        op: Operator,
        expected: PrimitiveArray<T>,
    ) -> Result<()> {
        let arithmetic_op = binary(col(0, schema.as_ref()), op, col(1, schema.as_ref()));
        let batch = RecordBatch::try_new(schema, data)?;
        let result = arithmetic_op.evaluate(&batch)?;

        assert_array_eq::<T>(expected, result);

        Ok(())
    }

    fn assert_array_eq<T: ArrowNumericType>(
        expected: PrimitiveArray<T>,
        actual: ArrayRef,
    ) {
        let actual = actual
            .as_any()
            .downcast_ref::<PrimitiveArray<T>>()
            .expect("Actual array should unwrap to type of expected array");

        for i in 0..expected.len() {
            assert_eq!(expected.value(i), actual.value(i));
        }
    }
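
    // A sketch (not in the original suite) showing that `apply_arithmetic` is
    // generic over any ArrowNumericType, using Float64Type as an example and
    // assuming the binary Plus expression supports Float64 inputs. The values
    // are exact binary fractions so the equality assertions are exact.
    #[test]
    fn plus_op_f64() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Float64, false),
            Field::new("b", DataType::Float64, false),
        ]);
        let a = Float64Array::from(vec![1.0, 2.0, 3.0]);
        let b = Float64Array::from(vec![0.5, 0.25, 0.125]);

        apply_arithmetic::<Float64Type>(
            Arc::new(schema),
            vec![Arc::new(a), Arc::new(b)],
            Operator::Plus,
            Float64Array::from(vec![1.5, 2.25, 3.125]),
        )?;

        Ok(())
    }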

    #[test]
    fn neg_op() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
        let a = BooleanArray::from(vec![true, false]);
        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

        // expression: "!a"
        let not_expr = not(col(0, &schema));
        let result = not_expr.evaluate(&batch)?;
        assert_eq!(result.len(), 2);

        let expected = vec![false, true];
        let result = result
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("failed to downcast to BooleanArray");
        for i in 0..2 {
            assert_eq!(result.value(i), expected[i]);
        }

        Ok(())
    }
}