1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
//! Defines physical expressions that can be evaluated at runtime during query execution
19
20 use std::cell::RefCell;
21 use std::rc::Rc;
22 use std::sync::Arc;
23
24 use crate::error::{ExecutionError, Result};
25 use crate::execution::physical_plan::common::get_scalar_value;
26 use crate::execution::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
27 use crate::logicalplan::{Operator, ScalarValue};
28 use arrow::array::{
29 ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
30 Int64Array, Int8Array, StringArray, TimestampNanosecondArray, UInt16Array,
31 UInt32Array, UInt64Array, UInt8Array,
32 };
33 use arrow::array::{
34 Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder,
35 Int8Builder, StringBuilder, UInt16Builder, UInt32Builder, UInt64Builder,
36 UInt8Builder,
37 };
38 use arrow::compute;
39 use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract};
40 use arrow::compute::kernels::boolean::{and, or};
41 use arrow::compute::kernels::cast::cast;
42 use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
43 use arrow::compute::kernels::comparison::{
44 eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8,
45 };
46 use arrow::datatypes::{DataType, Schema, TimeUnit};
47 use arrow::record_batch::RecordBatch;
48
/// Represents an aliased expression
pub struct Alias {
    /// The wrapped expression whose value this alias exposes
    expr: Arc<dyn PhysicalExpr>,
    /// Output name reported instead of the wrapped expression's own name
    alias: String,
}
54
55 impl Alias {
56 /// Create a new aliased expression
new(expr: Arc<dyn PhysicalExpr>, alias: &str) -> Self57 pub fn new(expr: Arc<dyn PhysicalExpr>, alias: &str) -> Self {
58 Self {
59 expr: expr.clone(),
60 alias: alias.to_owned(),
61 }
62 }
63 }
64
impl PhysicalExpr for Alias {
    /// The alias becomes the expression's visible name
    fn name(&self) -> String {
        self.alias.clone()
    }

    /// Aliasing does not change the type: delegate to the wrapped expression
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.expr.data_type(input_schema)
    }

    /// Aliasing does not change the value: delegate to the wrapped expression
    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.expr.evaluate(batch)
    }
}
78
/// Represents the column at a given index in a RecordBatch
pub struct Column {
    /// Zero-based position of the column within the batch
    index: usize,
    /// Name reported for this column in output schemas
    name: String,
}
84
impl Column {
    /// Create a new column expression referencing column `index` of the
    /// input batch, reported under `name`.
    pub fn new(index: usize, name: &str) -> Self {
        Self {
            index,
            name: name.to_owned(),
        }
    }
}
94
impl PhysicalExpr for Column {
    /// Get the name to use in a schema to represent the result of this expression
    fn name(&self) -> String {
        self.name.clone()
    }

    /// Get the data type of this expression, given the schema of the input
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        // Looked up by index, not by name, so `name` is purely cosmetic here.
        Ok(input_schema.field(self.index).data_type().clone())
    }

    /// Evaluate the expression
    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        // Hands back the referenced column; cloning an ArrayRef is a cheap
        // reference-count bump, not a data copy.
        Ok(batch.column(self.index).clone())
    }
}
111
112 /// Create a column expression
col(i: usize, schema: &Schema) -> Arc<dyn PhysicalExpr>113 pub fn col(i: usize, schema: &Schema) -> Arc<dyn PhysicalExpr> {
114 Arc::new(Column::new(i, &schema.field(i).name()))
115 }
116
/// SUM aggregate expression
pub struct Sum {
    /// Expression producing the values to be summed
    expr: Arc<dyn PhysicalExpr>,
}
121
impl Sum {
    /// Create a new SUM aggregate function over the values produced by `expr`
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self { expr }
    }
}
128
129 impl AggregateExpr for Sum {
name(&self) -> String130 fn name(&self) -> String {
131 "SUM".to_string()
132 }
133
data_type(&self, input_schema: &Schema) -> Result<DataType>134 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
135 match self.expr.data_type(input_schema)? {
136 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
137 Ok(DataType::Int64)
138 }
139 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
140 Ok(DataType::UInt64)
141 }
142 DataType::Float32 => Ok(DataType::Float32),
143 DataType::Float64 => Ok(DataType::Float64),
144 other => Err(ExecutionError::General(format!(
145 "SUM does not support {:?}",
146 other
147 ))),
148 }
149 }
150
evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef>151 fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
152 self.expr.evaluate(batch)
153 }
154
create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>>155 fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
156 Rc::new(RefCell::new(SumAccumulator { sum: None }))
157 }
158
create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr>159 fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
160 Arc::new(Sum::new(Arc::new(Column::new(column_index, &self.name()))))
161 }
162 }
163
/// Fold one value of type `$TY` into `$SELF.sum`, which is kept as
/// `ScalarValue::$SCALAR_VARIANT`. Starts the sum on the first value and
/// bails out with an internal error if the stored variant ever disagrees.
/// (`$ARRAY_TYPE` is accepted for call-site symmetry but unused.)
macro_rules! sum_accumulate {
    ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{
        let incoming = $VALUE as $TY;
        $SELF.sum = match $SELF.sum {
            None => Some(ScalarValue::$SCALAR_VARIANT(incoming)),
            Some(ScalarValue::$SCALAR_VARIANT(current)) => {
                Some(ScalarValue::$SCALAR_VARIANT(current + incoming))
            }
            Some(_) => {
                return Err(ExecutionError::InternalError(
                    "Unexpected ScalarValue variant".to_string(),
                ))
            }
        };
    }};
}
179
/// Running-sum state for the SUM aggregate.
struct SumAccumulator {
    // `None` until the first non-null value arrives; the variant stored here
    // matches SUM's widened output type (Int64/UInt64/Float32/Float64).
    sum: Option<ScalarValue>,
}
183
184 impl Accumulator for SumAccumulator {
accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()>185 fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
186 if let Some(value) = value {
187 match value {
188 ScalarValue::Int8(value) => {
189 sum_accumulate!(self, value, Int8Array, Int64, i64);
190 }
191 ScalarValue::Int16(value) => {
192 sum_accumulate!(self, value, Int16Array, Int64, i64);
193 }
194 ScalarValue::Int32(value) => {
195 sum_accumulate!(self, value, Int32Array, Int64, i64);
196 }
197 ScalarValue::Int64(value) => {
198 sum_accumulate!(self, value, Int64Array, Int64, i64);
199 }
200 ScalarValue::UInt8(value) => {
201 sum_accumulate!(self, value, UInt8Array, UInt64, u64);
202 }
203 ScalarValue::UInt16(value) => {
204 sum_accumulate!(self, value, UInt16Array, UInt64, u64);
205 }
206 ScalarValue::UInt32(value) => {
207 sum_accumulate!(self, value, UInt32Array, UInt64, u64);
208 }
209 ScalarValue::UInt64(value) => {
210 sum_accumulate!(self, value, UInt64Array, UInt64, u64);
211 }
212 ScalarValue::Float32(value) => {
213 sum_accumulate!(self, value, Float32Array, Float32, f32);
214 }
215 ScalarValue::Float64(value) => {
216 sum_accumulate!(self, value, Float64Array, Float64, f64);
217 }
218 other => {
219 return Err(ExecutionError::General(format!(
220 "SUM does not support {:?}",
221 other
222 )))
223 }
224 }
225 }
226 Ok(())
227 }
228
accumulate_batch(&mut self, array: &ArrayRef) -> Result<()>229 fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
230 let sum = match array.data_type() {
231 DataType::UInt8 => {
232 match compute::sum(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
233 Some(n) => Ok(Some(ScalarValue::UInt8(n))),
234 None => Ok(None),
235 }
236 }
237 DataType::UInt16 => {
238 match compute::sum(array.as_any().downcast_ref::<UInt16Array>().unwrap())
239 {
240 Some(n) => Ok(Some(ScalarValue::UInt16(n))),
241 None => Ok(None),
242 }
243 }
244 DataType::UInt32 => {
245 match compute::sum(array.as_any().downcast_ref::<UInt32Array>().unwrap())
246 {
247 Some(n) => Ok(Some(ScalarValue::UInt32(n))),
248 None => Ok(None),
249 }
250 }
251 DataType::UInt64 => {
252 match compute::sum(array.as_any().downcast_ref::<UInt64Array>().unwrap())
253 {
254 Some(n) => Ok(Some(ScalarValue::UInt64(n))),
255 None => Ok(None),
256 }
257 }
258 DataType::Int8 => {
259 match compute::sum(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
260 Some(n) => Ok(Some(ScalarValue::Int8(n))),
261 None => Ok(None),
262 }
263 }
264 DataType::Int16 => {
265 match compute::sum(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
266 Some(n) => Ok(Some(ScalarValue::Int16(n))),
267 None => Ok(None),
268 }
269 }
270 DataType::Int32 => {
271 match compute::sum(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
272 Some(n) => Ok(Some(ScalarValue::Int32(n))),
273 None => Ok(None),
274 }
275 }
276 DataType::Int64 => {
277 match compute::sum(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
278 Some(n) => Ok(Some(ScalarValue::Int64(n))),
279 None => Ok(None),
280 }
281 }
282 DataType::Float32 => {
283 match compute::sum(array.as_any().downcast_ref::<Float32Array>().unwrap())
284 {
285 Some(n) => Ok(Some(ScalarValue::Float32(n))),
286 None => Ok(None),
287 }
288 }
289 DataType::Float64 => {
290 match compute::sum(array.as_any().downcast_ref::<Float64Array>().unwrap())
291 {
292 Some(n) => Ok(Some(ScalarValue::Float64(n))),
293 None => Ok(None),
294 }
295 }
296 _ => Err(ExecutionError::ExecutionError(
297 "Unsupported data type for SUM".to_string(),
298 )),
299 }?;
300 self.accumulate_scalar(sum)
301 }
302
get_value(&self) -> Result<Option<ScalarValue>>303 fn get_value(&self) -> Result<Option<ScalarValue>> {
304 Ok(self.sum.clone())
305 }
306 }
307
/// Create a SUM aggregate expression over the values produced by `expr`
pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Sum::new(expr))
}
312
/// AVG aggregate expression
pub struct Avg {
    /// Expression producing the values to be averaged
    expr: Arc<dyn PhysicalExpr>,
}
317
impl Avg {
    /// Create a new AVG aggregate function over the values produced by `expr`
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self { expr }
    }
}
324
325 impl AggregateExpr for Avg {
name(&self) -> String326 fn name(&self) -> String {
327 "AVG".to_string()
328 }
329
data_type(&self, input_schema: &Schema) -> Result<DataType>330 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
331 match self.expr.data_type(input_schema)? {
332 DataType::Int8
333 | DataType::Int16
334 | DataType::Int32
335 | DataType::Int64
336 | DataType::UInt8
337 | DataType::UInt16
338 | DataType::UInt32
339 | DataType::UInt64
340 | DataType::Float32
341 | DataType::Float64 => Ok(DataType::Float64),
342 other => Err(ExecutionError::General(format!(
343 "AVG does not support {:?}",
344 other
345 ))),
346 }
347 }
348
evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef>349 fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
350 self.expr.evaluate(batch)
351 }
352
create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>>353 fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
354 Rc::new(RefCell::new(AvgAccumulator {
355 sum: None,
356 count: None,
357 }))
358 }
359
create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr>360 fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
361 Arc::new(Avg::new(Arc::new(Column::new(column_index, &self.name()))))
362 }
363 }
364
/// Fold one numeric value into the running `sum`/`count` pair of an
/// `AvgAccumulator`, converting to f64. The first value initializes both
/// fields. (`$ARRAY_TYPE` is accepted for call-site symmetry but unused.)
macro_rules! avg_accumulate {
    ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident) => {{
        if let (Some(sum), Some(count)) = ($SELF.sum, $SELF.count) {
            $SELF.sum = Some(sum + $VALUE as f64);
            $SELF.count = Some(count + 1);
        } else {
            $SELF.sum = Some($VALUE as f64);
            $SELF.count = Some(1);
        }
    }};
}
/// Running sum/count state for the AVG aggregate.
struct AvgAccumulator {
    // Running total of accumulated values, kept as f64
    sum: Option<f64>,
    // Number of (non-null) values accumulated so far
    count: Option<i64>,
}
383
384 impl Accumulator for AvgAccumulator {
accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()>385 fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
386 if let Some(value) = value {
387 match value {
388 ScalarValue::Int8(value) => avg_accumulate!(self, value, Int8Array),
389 ScalarValue::Int16(value) => avg_accumulate!(self, value, Int16Array),
390 ScalarValue::Int32(value) => avg_accumulate!(self, value, Int32Array),
391 ScalarValue::Int64(value) => avg_accumulate!(self, value, Int64Array),
392 ScalarValue::UInt8(value) => avg_accumulate!(self, value, UInt8Array),
393 ScalarValue::UInt16(value) => avg_accumulate!(self, value, UInt16Array),
394 ScalarValue::UInt32(value) => avg_accumulate!(self, value, UInt32Array),
395 ScalarValue::UInt64(value) => avg_accumulate!(self, value, UInt64Array),
396 ScalarValue::Float32(value) => avg_accumulate!(self, value, Float32Array),
397 ScalarValue::Float64(value) => avg_accumulate!(self, value, Float64Array),
398 other => {
399 return Err(ExecutionError::General(format!(
400 "AVG does not support {:?}",
401 other
402 )))
403 }
404 }
405 }
406 Ok(())
407 }
408
accumulate_batch(&mut self, array: &ArrayRef) -> Result<()>409 fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
410 for row in 0..array.len() {
411 self.accumulate_scalar(get_scalar_value(array, row)?)?;
412 }
413 Ok(())
414 }
415
get_value(&self) -> Result<Option<ScalarValue>>416 fn get_value(&self) -> Result<Option<ScalarValue>> {
417 match (self.sum, self.count) {
418 (Some(sum), Some(count)) => {
419 Ok(Some(ScalarValue::Float64(sum / count as f64)))
420 }
421 _ => Ok(None),
422 }
423 }
424 }
425
/// Create an AVG aggregate expression over the values produced by `expr`
pub fn avg(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Avg::new(expr))
}
430
/// MAX aggregate expression
pub struct Max {
    /// Expression producing the values to take the maximum of
    expr: Arc<dyn PhysicalExpr>,
}
435
impl Max {
    /// Create a new MAX aggregate function over the values produced by `expr`
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self { expr }
    }
}
442
443 impl AggregateExpr for Max {
name(&self) -> String444 fn name(&self) -> String {
445 "MAX".to_string()
446 }
447
data_type(&self, input_schema: &Schema) -> Result<DataType>448 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
449 match self.expr.data_type(input_schema)? {
450 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
451 Ok(DataType::Int64)
452 }
453 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
454 Ok(DataType::UInt64)
455 }
456 DataType::Float32 => Ok(DataType::Float32),
457 DataType::Float64 => Ok(DataType::Float64),
458 other => Err(ExecutionError::General(format!(
459 "MAX does not support {:?}",
460 other
461 ))),
462 }
463 }
464
evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef>465 fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
466 self.expr.evaluate(batch)
467 }
468
create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>>469 fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
470 Rc::new(RefCell::new(MaxAccumulator { max: None }))
471 }
472
create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr>473 fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
474 Arc::new(Max::new(Arc::new(Column::new(column_index, &self.name()))))
475 }
476 }
477
/// Fold one value of type `$TY` into `$SELF.max`, kept as
/// `ScalarValue::$SCALAR_VARIANT`. The first value initializes the maximum;
/// a mismatched stored variant is an internal error.
/// (`$ARRAY_TYPE` is accepted for call-site symmetry but unused.)
macro_rules! max_accumulate {
    ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{
        let candidate = $VALUE as $TY;
        $SELF.max = match $SELF.max {
            None => Some(ScalarValue::$SCALAR_VARIANT(candidate)),
            Some(ScalarValue::$SCALAR_VARIANT(current)) => {
                // Written as a comparison (not `max()`) to preserve the
                // original behavior for float NaN inputs.
                if current > candidate {
                    Some(ScalarValue::$SCALAR_VARIANT(current))
                } else {
                    Some(ScalarValue::$SCALAR_VARIANT(candidate))
                }
            }
            Some(_) => {
                return Err(ExecutionError::InternalError(
                    "Unexpected ScalarValue variant".to_string(),
                ))
            }
        };
    }};
}
/// Running-maximum state for the MAX aggregate.
struct MaxAccumulator {
    // `None` until the first non-null value arrives
    max: Option<ScalarValue>,
}
500
501 impl Accumulator for MaxAccumulator {
accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()>502 fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
503 if let Some(value) = value {
504 match value {
505 ScalarValue::Int8(value) => {
506 max_accumulate!(self, value, Int8Array, Int64, i64);
507 }
508 ScalarValue::Int16(value) => {
509 max_accumulate!(self, value, Int16Array, Int64, i64)
510 }
511 ScalarValue::Int32(value) => {
512 max_accumulate!(self, value, Int32Array, Int64, i64)
513 }
514 ScalarValue::Int64(value) => {
515 max_accumulate!(self, value, Int64Array, Int64, i64)
516 }
517 ScalarValue::UInt8(value) => {
518 max_accumulate!(self, value, UInt8Array, UInt64, u64)
519 }
520 ScalarValue::UInt16(value) => {
521 max_accumulate!(self, value, UInt16Array, UInt64, u64)
522 }
523 ScalarValue::UInt32(value) => {
524 max_accumulate!(self, value, UInt32Array, UInt64, u64)
525 }
526 ScalarValue::UInt64(value) => {
527 max_accumulate!(self, value, UInt64Array, UInt64, u64)
528 }
529 ScalarValue::Float32(value) => {
530 max_accumulate!(self, value, Float32Array, Float32, f32)
531 }
532 ScalarValue::Float64(value) => {
533 max_accumulate!(self, value, Float64Array, Float64, f64)
534 }
535 other => {
536 return Err(ExecutionError::General(format!(
537 "MAX does not support {:?}",
538 other
539 )))
540 }
541 }
542 }
543 Ok(())
544 }
545
accumulate_batch(&mut self, array: &ArrayRef) -> Result<()>546 fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
547 let max = match array.data_type() {
548 DataType::UInt8 => {
549 match compute::max(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
550 Some(n) => Ok(Some(ScalarValue::UInt8(n))),
551 None => Ok(None),
552 }
553 }
554 DataType::UInt16 => {
555 match compute::max(array.as_any().downcast_ref::<UInt16Array>().unwrap())
556 {
557 Some(n) => Ok(Some(ScalarValue::UInt16(n))),
558 None => Ok(None),
559 }
560 }
561 DataType::UInt32 => {
562 match compute::max(array.as_any().downcast_ref::<UInt32Array>().unwrap())
563 {
564 Some(n) => Ok(Some(ScalarValue::UInt32(n))),
565 None => Ok(None),
566 }
567 }
568 DataType::UInt64 => {
569 match compute::max(array.as_any().downcast_ref::<UInt64Array>().unwrap())
570 {
571 Some(n) => Ok(Some(ScalarValue::UInt64(n))),
572 None => Ok(None),
573 }
574 }
575 DataType::Int8 => {
576 match compute::max(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
577 Some(n) => Ok(Some(ScalarValue::Int8(n))),
578 None => Ok(None),
579 }
580 }
581 DataType::Int16 => {
582 match compute::max(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
583 Some(n) => Ok(Some(ScalarValue::Int16(n))),
584 None => Ok(None),
585 }
586 }
587 DataType::Int32 => {
588 match compute::max(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
589 Some(n) => Ok(Some(ScalarValue::Int32(n))),
590 None => Ok(None),
591 }
592 }
593 DataType::Int64 => {
594 match compute::max(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
595 Some(n) => Ok(Some(ScalarValue::Int64(n))),
596 None => Ok(None),
597 }
598 }
599 DataType::Float32 => {
600 match compute::max(array.as_any().downcast_ref::<Float32Array>().unwrap())
601 {
602 Some(n) => Ok(Some(ScalarValue::Float32(n))),
603 None => Ok(None),
604 }
605 }
606 DataType::Float64 => {
607 match compute::max(array.as_any().downcast_ref::<Float64Array>().unwrap())
608 {
609 Some(n) => Ok(Some(ScalarValue::Float64(n))),
610 None => Ok(None),
611 }
612 }
613 _ => Err(ExecutionError::ExecutionError(
614 "Unsupported data type for MAX".to_string(),
615 )),
616 }?;
617 self.accumulate_scalar(max)
618 }
619
get_value(&self) -> Result<Option<ScalarValue>>620 fn get_value(&self) -> Result<Option<ScalarValue>> {
621 Ok(self.max.clone())
622 }
623 }
624
/// Create a MAX aggregate expression over the values produced by `expr`
pub fn max(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Max::new(expr))
}
629
/// MIN aggregate expression
pub struct Min {
    /// Expression producing the values to take the minimum of
    expr: Arc<dyn PhysicalExpr>,
}
634
impl Min {
    /// Create a new MIN aggregate function over the values produced by `expr`
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self { expr }
    }
}
641
642 impl AggregateExpr for Min {
name(&self) -> String643 fn name(&self) -> String {
644 "MIN".to_string()
645 }
646
data_type(&self, input_schema: &Schema) -> Result<DataType>647 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
648 match self.expr.data_type(input_schema)? {
649 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
650 Ok(DataType::Int64)
651 }
652 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
653 Ok(DataType::UInt64)
654 }
655 DataType::Float32 => Ok(DataType::Float32),
656 DataType::Float64 => Ok(DataType::Float64),
657 other => Err(ExecutionError::General(format!(
658 "MIN does not support {:?}",
659 other
660 ))),
661 }
662 }
663
evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef>664 fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
665 self.expr.evaluate(batch)
666 }
667
create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>>668 fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
669 Rc::new(RefCell::new(MinAccumulator { min: None }))
670 }
671
create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr>672 fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
673 Arc::new(Min::new(Arc::new(Column::new(column_index, &self.name()))))
674 }
675 }
676
/// Fold one value of type `$TY` into `$SELF.min`, kept as
/// `ScalarValue::$SCALAR_VARIANT`. The first value initializes the minimum;
/// a mismatched stored variant is an internal error.
/// (`$ARRAY_TYPE` is accepted for call-site symmetry but unused.)
macro_rules! min_accumulate {
    ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{
        let candidate = $VALUE as $TY;
        $SELF.min = match $SELF.min {
            None => Some(ScalarValue::$SCALAR_VARIANT(candidate)),
            Some(ScalarValue::$SCALAR_VARIANT(current)) => {
                // Written as a comparison (not `min()`) to preserve the
                // original behavior for float NaN inputs.
                if current < candidate {
                    Some(ScalarValue::$SCALAR_VARIANT(current))
                } else {
                    Some(ScalarValue::$SCALAR_VARIANT(candidate))
                }
            }
            Some(_) => {
                return Err(ExecutionError::InternalError(
                    "Unexpected ScalarValue variant".to_string(),
                ))
            }
        };
    }};
}
/// Running-minimum state for the MIN aggregate.
struct MinAccumulator {
    // `None` until the first non-null value arrives
    min: Option<ScalarValue>,
}
699
700 impl Accumulator for MinAccumulator {
accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()>701 fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
702 if let Some(value) = value {
703 match value {
704 ScalarValue::Int8(value) => {
705 min_accumulate!(self, value, Int8Array, Int64, i64);
706 }
707 ScalarValue::Int16(value) => {
708 min_accumulate!(self, value, Int16Array, Int64, i64)
709 }
710 ScalarValue::Int32(value) => {
711 min_accumulate!(self, value, Int32Array, Int64, i64)
712 }
713 ScalarValue::Int64(value) => {
714 min_accumulate!(self, value, Int64Array, Int64, i64)
715 }
716 ScalarValue::UInt8(value) => {
717 min_accumulate!(self, value, UInt8Array, UInt64, u64)
718 }
719 ScalarValue::UInt16(value) => {
720 min_accumulate!(self, value, UInt16Array, UInt64, u64)
721 }
722 ScalarValue::UInt32(value) => {
723 min_accumulate!(self, value, UInt32Array, UInt64, u64)
724 }
725 ScalarValue::UInt64(value) => {
726 min_accumulate!(self, value, UInt64Array, UInt64, u64)
727 }
728 ScalarValue::Float32(value) => {
729 min_accumulate!(self, value, Float32Array, Float32, f32)
730 }
731 ScalarValue::Float64(value) => {
732 min_accumulate!(self, value, Float64Array, Float64, f64)
733 }
734 other => {
735 return Err(ExecutionError::General(format!(
736 "MIN does not support {:?}",
737 other
738 )))
739 }
740 }
741 }
742 Ok(())
743 }
744
accumulate_batch(&mut self, array: &ArrayRef) -> Result<()>745 fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
746 let min = match array.data_type() {
747 DataType::UInt8 => {
748 match compute::min(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
749 Some(n) => Ok(Some(ScalarValue::UInt8(n))),
750 None => Ok(None),
751 }
752 }
753 DataType::UInt16 => {
754 match compute::min(array.as_any().downcast_ref::<UInt16Array>().unwrap())
755 {
756 Some(n) => Ok(Some(ScalarValue::UInt16(n))),
757 None => Ok(None),
758 }
759 }
760 DataType::UInt32 => {
761 match compute::min(array.as_any().downcast_ref::<UInt32Array>().unwrap())
762 {
763 Some(n) => Ok(Some(ScalarValue::UInt32(n))),
764 None => Ok(None),
765 }
766 }
767 DataType::UInt64 => {
768 match compute::min(array.as_any().downcast_ref::<UInt64Array>().unwrap())
769 {
770 Some(n) => Ok(Some(ScalarValue::UInt64(n))),
771 None => Ok(None),
772 }
773 }
774 DataType::Int8 => {
775 match compute::min(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
776 Some(n) => Ok(Some(ScalarValue::Int8(n))),
777 None => Ok(None),
778 }
779 }
780 DataType::Int16 => {
781 match compute::min(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
782 Some(n) => Ok(Some(ScalarValue::Int16(n))),
783 None => Ok(None),
784 }
785 }
786 DataType::Int32 => {
787 match compute::min(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
788 Some(n) => Ok(Some(ScalarValue::Int32(n))),
789 None => Ok(None),
790 }
791 }
792 DataType::Int64 => {
793 match compute::min(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
794 Some(n) => Ok(Some(ScalarValue::Int64(n))),
795 None => Ok(None),
796 }
797 }
798 DataType::Float32 => {
799 match compute::min(array.as_any().downcast_ref::<Float32Array>().unwrap())
800 {
801 Some(n) => Ok(Some(ScalarValue::Float32(n))),
802 None => Ok(None),
803 }
804 }
805 DataType::Float64 => {
806 match compute::min(array.as_any().downcast_ref::<Float64Array>().unwrap())
807 {
808 Some(n) => Ok(Some(ScalarValue::Float64(n))),
809 None => Ok(None),
810 }
811 }
812 _ => Err(ExecutionError::ExecutionError(
813 "Unsupported data type for MIN".to_string(),
814 )),
815 }?;
816 self.accumulate_scalar(min)
817 }
818
get_value(&self) -> Result<Option<ScalarValue>>819 fn get_value(&self) -> Result<Option<ScalarValue>> {
820 Ok(self.min.clone())
821 }
822 }
823
/// Create a MIN aggregate expression over the values produced by `expr`
pub fn min(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Min::new(expr))
}
828
/// COUNT aggregate expression
/// Returns the amount of non-null values of the given expression.
pub struct Count {
    /// Expression whose non-null values are counted
    expr: Arc<dyn PhysicalExpr>,
}
834
835 impl Count {
836 /// Create a new COUNT aggregate function.
new(expr: Arc<dyn PhysicalExpr>) -> Self837 pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
838 Self { expr: expr }
839 }
840 }
841
impl AggregateExpr for Count {
    fn name(&self) -> String {
        "COUNT".to_string()
    }

    /// COUNT always returns UInt64 regardless of the input type
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::UInt64)
    }

    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.expr.evaluate(batch)
    }

    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
        Rc::new(RefCell::new(CountAccumulator { count: 0 }))
    }

    /// The reducer is a SUM, not another COUNT: partial counts from each
    /// partition are combined by adding them together.
    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
        Arc::new(Sum::new(Arc::new(Column::new(column_index, &self.name()))))
    }
}
863
/// Running-count state for the COUNT aggregate.
struct CountAccumulator {
    // Number of non-null values seen so far
    count: u64,
}
867
868 impl Accumulator for CountAccumulator {
accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()>869 fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
870 if value.is_some() {
871 self.count += 1;
872 }
873 Ok(())
874 }
875
accumulate_batch(&mut self, array: &ArrayRef) -> Result<()>876 fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
877 self.count += array.len() as u64 - array.null_count() as u64;
878 Ok(())
879 }
880
get_value(&self) -> Result<Option<ScalarValue>>881 fn get_value(&self) -> Result<Option<ScalarValue>> {
882 Ok(Some(ScalarValue::UInt64(self.count)))
883 }
884 }
885
/// Create a COUNT aggregate expression counting the non-null values of `expr`
pub fn count(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    Arc::new(Count::new(expr))
}
890
/// Invoke a compute kernel on a pair of binary data arrays.
/// Downcasts both sides to `$DT` and calls the `_utf8`-suffixed variant of
/// `$OP` (the suffix is spliced on via `paste`).
macro_rules! compute_utf8_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let left = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        let right = $RIGHT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&left, &right)?))
    }};
}
905
/// Invoke a compute kernel on a pair of arrays.
/// Downcasts both sides to `$DT` and applies `$OP`, wrapping the result in
/// an `Arc`.
macro_rules! compute_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let left = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        let right = $RIGHT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        Ok(Arc::new($OP(&left, &right)?))
    }};
}
920
/// Dispatch a binary kernel over string arrays: only Utf8 inputs are
/// accepted; any other left-hand data type is an error.
macro_rules! binary_string_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
            other => Err(ExecutionError::General(format!(
                "Unsupported data type {:?}",
                other
            ))),
        }
    }};
}
932
/// Invoke a compute kernel on a pair of arrays
/// The binary_primitive_array_op macro only evaluates for primitive types
/// like integers and floats.
// Dispatches on the LEFT array's data type only; callers are expected to
// have verified that both sides share the same type.
macro_rules! binary_primitive_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
            other => Err(ExecutionError::General(format!(
                "Unsupported data type {:?}",
                other
            ))),
        }
    }};
}
956
/// The binary_array_op macro includes types that extend beyond the primitive,
/// such as Utf8 strings.
// Like binary_primitive_array_op, but additionally dispatches Utf8 to the
// `_utf8` kernel variant and supports nanosecond timestamps.
macro_rules! binary_array_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        match $LEFT.data_type() {
            DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
            DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
            DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
            DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
            DataType::Timestamp(TimeUnit::Nanosecond, None) => {
                compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
            }
            other => Err(ExecutionError::General(format!(
                "Unsupported data type {:?}",
                other
            ))),
        }
    }};
}
983
/// Invoke a boolean kernel on a pair of arrays
///
/// Both inputs must already be boolean arrays; the callers (the And/Or arms
/// of BinaryExpr::evaluate) check `data_type() == DataType::Boolean` before
/// expanding this macro, so the downcasts below cannot fail there.
macro_rules! boolean_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean_op failed to downcast array");
        let rr = $RIGHT
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean_op failed to downcast array");
        Ok(Arc::new($OP(&ll, &rr)?))
    }};
}
/// Binary expression
pub struct BinaryExpr {
    /// Left-hand operand
    left: Arc<dyn PhysicalExpr>,
    /// Operator to apply (comparison, arithmetic, boolean, or LIKE)
    op: Operator,
    /// Right-hand operand
    right: Arc<dyn PhysicalExpr>,
}
1004
1005 impl BinaryExpr {
1006 /// Create new binary expression
new( left: Arc<dyn PhysicalExpr>, op: Operator, right: Arc<dyn PhysicalExpr>, ) -> Self1007 pub fn new(
1008 left: Arc<dyn PhysicalExpr>,
1009 op: Operator,
1010 right: Arc<dyn PhysicalExpr>,
1011 ) -> Self {
1012 Self { left, op, right }
1013 }
1014 }
1015
1016 impl PhysicalExpr for BinaryExpr {
name(&self) -> String1017 fn name(&self) -> String {
1018 format!("{:?}", self.op)
1019 }
1020
data_type(&self, input_schema: &Schema) -> Result<DataType>1021 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
1022 self.left.data_type(input_schema)
1023 }
1024
evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>1025 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
1026 let left = self.left.evaluate(batch)?;
1027 let right = self.right.evaluate(batch)?;
1028 if left.data_type() != right.data_type() {
1029 return Err(ExecutionError::General(format!(
1030 "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
1031 self.op,
1032 left.data_type(),
1033 right.data_type()
1034 )));
1035 }
1036 match &self.op {
1037 Operator::Like => binary_string_array_op!(left, right, like),
1038 Operator::NotLike => binary_string_array_op!(left, right, nlike),
1039 Operator::Lt => binary_array_op!(left, right, lt),
1040 Operator::LtEq => binary_array_op!(left, right, lt_eq),
1041 Operator::Gt => binary_array_op!(left, right, gt),
1042 Operator::GtEq => binary_array_op!(left, right, gt_eq),
1043 Operator::Eq => binary_array_op!(left, right, eq),
1044 Operator::NotEq => binary_array_op!(left, right, neq),
1045 Operator::Plus => binary_primitive_array_op!(left, right, add),
1046 Operator::Minus => binary_primitive_array_op!(left, right, subtract),
1047 Operator::Multiply => binary_primitive_array_op!(left, right, multiply),
1048 Operator::Divide => binary_primitive_array_op!(left, right, divide),
1049 Operator::And => {
1050 if left.data_type() == &DataType::Boolean {
1051 boolean_op!(left, right, and)
1052 } else {
1053 return Err(ExecutionError::General(format!(
1054 "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
1055 self.op,
1056 left.data_type(),
1057 right.data_type()
1058 )));
1059 }
1060 }
1061 Operator::Or => {
1062 if left.data_type() == &DataType::Boolean {
1063 boolean_op!(left, right, or)
1064 } else {
1065 return Err(ExecutionError::General(format!(
1066 "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
1067 self.op,
1068 left.data_type(),
1069 right.data_type()
1070 )));
1071 }
1072 }
1073 _ => Err(ExecutionError::General("Unsupported operator".to_string())),
1074 }
1075 }
1076 }
1077
1078 /// Create a binary expression
binary( l: Arc<dyn PhysicalExpr>, op: Operator, r: Arc<dyn PhysicalExpr>, ) -> Arc<dyn PhysicalExpr>1079 pub fn binary(
1080 l: Arc<dyn PhysicalExpr>,
1081 op: Operator,
1082 r: Arc<dyn PhysicalExpr>,
1083 ) -> Arc<dyn PhysicalExpr> {
1084 Arc::new(BinaryExpr::new(l, op, r))
1085 }
1086
/// Not expression
pub struct NotExpr {
    /// The boolean expression to negate
    arg: Arc<dyn PhysicalExpr>,
}
1091
1092 impl NotExpr {
1093 /// Create new not expression
new(arg: Arc<dyn PhysicalExpr>) -> Self1094 pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
1095 Self { arg }
1096 }
1097 }
1098
1099 impl PhysicalExpr for NotExpr {
name(&self) -> String1100 fn name(&self) -> String {
1101 "NOT".to_string()
1102 }
1103
data_type(&self, _input_schema: &Schema) -> Result<DataType>1104 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
1105 return Ok(DataType::Boolean);
1106 }
1107
evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>1108 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
1109 let arg = self.arg.evaluate(batch)?;
1110 if arg.data_type() != &DataType::Boolean {
1111 return Err(ExecutionError::General(format!(
1112 "Cannot evaluate \"not\" expression with type {:?}",
1113 arg.data_type(),
1114 )));
1115 }
1116 let arg = arg
1117 .as_any()
1118 .downcast_ref::<BooleanArray>()
1119 .expect("boolean_op failed to downcast array");
1120 return Ok(Arc::new(arrow::compute::kernels::boolean::not(arg)?));
1121 }
1122 }
1123
1124 /// Create a unary expression
not(arg: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>1125 pub fn not(arg: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
1126 Arc::new(NotExpr::new(arg))
1127 }
1128
/// CAST expression casts an expression to a specific data type
///
/// The set of supported conversions is validated up-front by
/// `CastExpr::try_new`; `evaluate` then delegates to Arrow's cast kernel.
pub struct CastExpr {
    /// The expression to cast
    expr: Arc<dyn PhysicalExpr>,
    /// The data type to cast to
    cast_type: DataType,
}
1136
1137 /// Determine if a DataType is numeric or not
is_numeric(dt: &DataType) -> bool1138 fn is_numeric(dt: &DataType) -> bool {
1139 match dt {
1140 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true,
1141 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true,
1142 DataType::Float16 | DataType::Float32 | DataType::Float64 => true,
1143 _ => false,
1144 }
1145 }
1146
1147 impl CastExpr {
1148 /// Create a CAST expression
try_new( expr: Arc<dyn PhysicalExpr>, input_schema: &Schema, cast_type: DataType, ) -> Result<Self>1149 pub fn try_new(
1150 expr: Arc<dyn PhysicalExpr>,
1151 input_schema: &Schema,
1152 cast_type: DataType,
1153 ) -> Result<Self> {
1154 let expr_type = expr.data_type(input_schema)?;
1155 // numbers can be cast to numbers and strings
1156 if is_numeric(&expr_type)
1157 && (is_numeric(&cast_type) || cast_type == DataType::Utf8)
1158 {
1159 Ok(Self { expr, cast_type })
1160 } else if expr_type == DataType::Binary && cast_type == DataType::Utf8 {
1161 Ok(Self { expr, cast_type })
1162 } else if is_numeric(&expr_type)
1163 && cast_type == DataType::Timestamp(TimeUnit::Nanosecond, None)
1164 {
1165 Ok(Self { expr, cast_type })
1166 } else {
1167 Err(ExecutionError::General(format!(
1168 "Invalid CAST from {:?} to {:?}",
1169 expr_type, cast_type
1170 )))
1171 }
1172 }
1173 }
1174
1175 impl PhysicalExpr for CastExpr {
name(&self) -> String1176 fn name(&self) -> String {
1177 "CAST".to_string()
1178 }
1179
data_type(&self, _input_schema: &Schema) -> Result<DataType>1180 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
1181 Ok(self.cast_type.clone())
1182 }
1183
evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>1184 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
1185 let value = self.expr.evaluate(batch)?;
1186 Ok(cast(&value, &self.cast_type)?)
1187 }
1188 }
1189
/// Represents a non-null literal value
pub struct Literal {
    /// The scalar value to materialize into an array at evaluation time
    value: ScalarValue,
}
1194
1195 impl Literal {
1196 /// Create a literal value expression
new(value: ScalarValue) -> Self1197 pub fn new(value: ScalarValue) -> Self {
1198 Self { value }
1199 }
1200 }
1201
/// Build array containing the same literal value repeated. This is necessary because the Arrow
/// memory model does not have the concept of a scalar value currently.
///
/// The output length matches the row count of `$BATCH` so the literal column
/// lines up with the batch's other columns.
macro_rules! build_literal_array {
    ($BATCH:ident, $BUILDER:ident, $VALUE:expr) => {{
        let mut builder = $BUILDER::new($BATCH.num_rows());
        for _ in 0..$BATCH.num_rows() {
            builder.append_value($VALUE)?;
        }
        Ok(Arc::new(builder.finish()))
    }};
}
1213
1214 impl PhysicalExpr for Literal {
name(&self) -> String1215 fn name(&self) -> String {
1216 "lit".to_string()
1217 }
1218
data_type(&self, _input_schema: &Schema) -> Result<DataType>1219 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
1220 Ok(self.value.get_datatype())
1221 }
1222
evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>1223 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
1224 match &self.value {
1225 ScalarValue::Int8(value) => build_literal_array!(batch, Int8Builder, *value),
1226 ScalarValue::Int16(value) => {
1227 build_literal_array!(batch, Int16Builder, *value)
1228 }
1229 ScalarValue::Int32(value) => {
1230 build_literal_array!(batch, Int32Builder, *value)
1231 }
1232 ScalarValue::Int64(value) => {
1233 build_literal_array!(batch, Int64Builder, *value)
1234 }
1235 ScalarValue::UInt8(value) => {
1236 build_literal_array!(batch, UInt8Builder, *value)
1237 }
1238 ScalarValue::UInt16(value) => {
1239 build_literal_array!(batch, UInt16Builder, *value)
1240 }
1241 ScalarValue::UInt32(value) => {
1242 build_literal_array!(batch, UInt32Builder, *value)
1243 }
1244 ScalarValue::UInt64(value) => {
1245 build_literal_array!(batch, UInt64Builder, *value)
1246 }
1247 ScalarValue::Float32(value) => {
1248 build_literal_array!(batch, Float32Builder, *value)
1249 }
1250 ScalarValue::Float64(value) => {
1251 build_literal_array!(batch, Float64Builder, *value)
1252 }
1253 ScalarValue::Utf8(value) => build_literal_array!(batch, StringBuilder, value),
1254 other => Err(ExecutionError::General(format!(
1255 "Unsupported literal type {:?}",
1256 other
1257 ))),
1258 }
1259 }
1260 }
1261
1262 /// Create a literal expression
lit(value: ScalarValue) -> Arc<dyn PhysicalExpr>1263 pub fn lit(value: ScalarValue) -> Arc<dyn PhysicalExpr> {
1264 Arc::new(Literal::new(value))
1265 }
1266
1267 #[cfg(test)]
1268 mod tests {
1269 use super::*;
1270 use crate::error::Result;
1271 use crate::execution::physical_plan::common::get_scalar_value;
1272 use arrow::array::{PrimitiveArray, StringArray, Time64NanosecondArray};
1273 use arrow::datatypes::*;
1274
1275 #[test]
binary_comparison() -> Result<()>1276 fn binary_comparison() -> Result<()> {
1277 let schema = Schema::new(vec![
1278 Field::new("a", DataType::Int32, false),
1279 Field::new("b", DataType::Int32, false),
1280 ]);
1281 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1282 let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1283 let batch = RecordBatch::try_new(
1284 Arc::new(schema.clone()),
1285 vec![Arc::new(a), Arc::new(b)],
1286 )?;
1287
1288 // expression: "a < b"
1289 let lt = binary(col(0, &schema), Operator::Lt, col(1, &schema));
1290 let result = lt.evaluate(&batch)?;
1291 assert_eq!(result.len(), 5);
1292
1293 let expected = vec![false, false, true, true, true];
1294 let result = result
1295 .as_any()
1296 .downcast_ref::<BooleanArray>()
1297 .expect("failed to downcast to BooleanArray");
1298 for i in 0..5 {
1299 assert_eq!(result.value(i), expected[i]);
1300 }
1301
1302 Ok(())
1303 }
1304
1305 #[test]
binary_nested() -> Result<()>1306 fn binary_nested() -> Result<()> {
1307 let schema = Schema::new(vec![
1308 Field::new("a", DataType::Int32, false),
1309 Field::new("b", DataType::Int32, false),
1310 ]);
1311 let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
1312 let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
1313 let batch = RecordBatch::try_new(
1314 Arc::new(schema.clone()),
1315 vec![Arc::new(a), Arc::new(b)],
1316 )?;
1317
1318 // expression: "a < b OR a == b"
1319 let expr = binary(
1320 binary(col(0, &schema), Operator::Lt, col(1, &schema)),
1321 Operator::Or,
1322 binary(col(0, &schema), Operator::Eq, col(1, &schema)),
1323 );
1324 let result = expr.evaluate(&batch)?;
1325 assert_eq!(result.len(), 5);
1326
1327 let expected = vec![true, true, false, true, false];
1328 let result = result
1329 .as_any()
1330 .downcast_ref::<BooleanArray>()
1331 .expect("failed to downcast to BooleanArray");
1332 for i in 0..5 {
1333 print!("{}", i);
1334 assert_eq!(result.value(i), expected[i]);
1335 }
1336
1337 Ok(())
1338 }
1339
1340 #[test]
literal_i32() -> Result<()>1341 fn literal_i32() -> Result<()> {
1342 // create an arbitrary record bacth
1343 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1344 let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
1345 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1346
1347 // create and evaluate a literal expression
1348 let literal_expr = lit(ScalarValue::Int32(42));
1349 let literal_array = literal_expr.evaluate(&batch)?;
1350 let literal_array = literal_array.as_any().downcast_ref::<Int32Array>().unwrap();
1351
1352 // note that the contents of the literal array are unrelated to the batch contents except for the length of the array
1353 assert_eq!(literal_array.len(), 5); // 5 rows in the batch
1354 for i in 0..literal_array.len() {
1355 assert_eq!(literal_array.value(i), 42);
1356 }
1357
1358 Ok(())
1359 }
1360
1361 #[test]
cast_i32_to_u32() -> Result<()>1362 fn cast_i32_to_u32() -> Result<()> {
1363 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1364 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1365 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1366
1367 let cast = CastExpr::try_new(col(0, &schema), &schema, DataType::UInt32)?;
1368 let result = cast.evaluate(&batch)?;
1369 assert_eq!(result.len(), 5);
1370
1371 let result = result
1372 .as_any()
1373 .downcast_ref::<UInt32Array>()
1374 .expect("failed to downcast to UInt32Array");
1375 assert_eq!(result.value(0), 1_u32);
1376
1377 Ok(())
1378 }
1379
1380 #[test]
cast_i32_to_utf8() -> Result<()>1381 fn cast_i32_to_utf8() -> Result<()> {
1382 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1383 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1384 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1385
1386 let cast = CastExpr::try_new(col(0, &schema), &schema, DataType::Utf8)?;
1387 let result = cast.evaluate(&batch)?;
1388 assert_eq!(result.len(), 5);
1389
1390 let result = result
1391 .as_any()
1392 .downcast_ref::<StringArray>()
1393 .expect("failed to downcast to StringArray");
1394 assert_eq!(result.value(0), "1");
1395
1396 Ok(())
1397 }
1398
1399 #[test]
cast_i64_to_timestamp_nanoseconds() -> Result<()>1400 fn cast_i64_to_timestamp_nanoseconds() -> Result<()> {
1401 let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
1402 let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1403 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1404
1405 let cast = CastExpr::try_new(
1406 col(0, &schema),
1407 &schema,
1408 DataType::Timestamp(TimeUnit::Nanosecond, None),
1409 )?;
1410 let result = cast.evaluate(&batch)?;
1411 assert_eq!(result.len(), 5);
1412 let expected_result = Time64NanosecondArray::from(vec![1, 2, 3, 4]);
1413 let result = result
1414 .as_any()
1415 .downcast_ref::<TimestampNanosecondArray>()
1416 .expect("failed to downcast to TimestampNanosecondArray");
1417 assert_eq!(result.value(0), expected_result.value(0));
1418
1419 Ok(())
1420 }
1421
1422 #[test]
invalid_cast() -> Result<()>1423 fn invalid_cast() -> Result<()> {
1424 let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
1425 match CastExpr::try_new(col(0, &schema), &schema, DataType::Int32) {
1426 Err(ExecutionError::General(ref str)) => {
1427 assert_eq!(str, "Invalid CAST from Utf8 to Int32");
1428 Ok(())
1429 }
1430 _ => panic!(),
1431 }
1432 }
1433
1434 #[test]
sum_contract() -> Result<()>1435 fn sum_contract() -> Result<()> {
1436 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1437
1438 let sum = sum(col(0, &schema));
1439 assert_eq!("SUM".to_string(), sum.name());
1440 assert_eq!(DataType::Int64, sum.data_type(&schema)?);
1441
1442 let combiner = sum.create_reducer(0);
1443 assert_eq!("SUM".to_string(), combiner.name());
1444 assert_eq!(DataType::Int64, combiner.data_type(&schema)?);
1445
1446 Ok(())
1447 }
1448
1449 #[test]
max_contract() -> Result<()>1450 fn max_contract() -> Result<()> {
1451 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1452
1453 let max = max(col(0, &schema));
1454 assert_eq!("MAX".to_string(), max.name());
1455 assert_eq!(DataType::Int64, max.data_type(&schema)?);
1456
1457 let combiner = max.create_reducer(0);
1458 assert_eq!("MAX".to_string(), combiner.name());
1459 assert_eq!(DataType::Int64, combiner.data_type(&schema)?);
1460
1461 Ok(())
1462 }
1463
1464 #[test]
min_contract() -> Result<()>1465 fn min_contract() -> Result<()> {
1466 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1467
1468 let min = min(col(0, &schema));
1469 assert_eq!("MIN".to_string(), min.name());
1470 assert_eq!(DataType::Int64, min.data_type(&schema)?);
1471
1472 let combiner = min.create_reducer(0);
1473 assert_eq!("MIN".to_string(), combiner.name());
1474 assert_eq!(DataType::Int64, combiner.data_type(&schema)?);
1475
1476 Ok(())
1477 }
1478 #[test]
avg_contract() -> Result<()>1479 fn avg_contract() -> Result<()> {
1480 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1481
1482 let avg = avg(col(0, &schema));
1483 assert_eq!("AVG".to_string(), avg.name());
1484 assert_eq!(DataType::Float64, avg.data_type(&schema)?);
1485
1486 let combiner = avg.create_reducer(0);
1487 assert_eq!("AVG".to_string(), combiner.name());
1488 assert_eq!(DataType::Float64, combiner.data_type(&schema)?);
1489
1490 Ok(())
1491 }
1492
1493 #[test]
sum_i32() -> Result<()>1494 fn sum_i32() -> Result<()> {
1495 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1496
1497 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1498 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1499
1500 assert_eq!(do_sum(&batch)?, Some(ScalarValue::Int64(15)));
1501
1502 Ok(())
1503 }
1504
1505 #[test]
avg_i32() -> Result<()>1506 fn avg_i32() -> Result<()> {
1507 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1508
1509 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1510 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1511
1512 assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64)));
1513
1514 Ok(())
1515 }
1516
1517 #[test]
max_i32() -> Result<()>1518 fn max_i32() -> Result<()> {
1519 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1520
1521 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1522 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1523
1524 assert_eq!(do_max(&batch)?, Some(ScalarValue::Int64(5)));
1525
1526 Ok(())
1527 }
1528
1529 #[test]
min_i32() -> Result<()>1530 fn min_i32() -> Result<()> {
1531 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1532
1533 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1534 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1535
1536 assert_eq!(do_min(&batch)?, Some(ScalarValue::Int64(1)));
1537
1538 Ok(())
1539 }
1540
1541 #[test]
sum_i32_with_nulls() -> Result<()>1542 fn sum_i32_with_nulls() -> Result<()> {
1543 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1544
1545 let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
1546 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1547
1548 assert_eq!(do_sum(&batch)?, Some(ScalarValue::Int64(13)));
1549
1550 Ok(())
1551 }
1552
1553 #[test]
avg_i32_with_nulls() -> Result<()>1554 fn avg_i32_with_nulls() -> Result<()> {
1555 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1556
1557 let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
1558 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1559
1560 assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3.25)));
1561
1562 Ok(())
1563 }
1564
1565 #[test]
max_i32_with_nulls() -> Result<()>1566 fn max_i32_with_nulls() -> Result<()> {
1567 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1568
1569 let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
1570 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1571
1572 assert_eq!(do_max(&batch)?, Some(ScalarValue::Int64(5)));
1573
1574 Ok(())
1575 }
1576
1577 #[test]
min_i32_with_nulls() -> Result<()>1578 fn min_i32_with_nulls() -> Result<()> {
1579 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1580
1581 let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
1582 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1583
1584 assert_eq!(do_min(&batch)?, Some(ScalarValue::Int64(1)));
1585
1586 Ok(())
1587 }
1588
1589 #[test]
sum_i32_all_nulls() -> Result<()>1590 fn sum_i32_all_nulls() -> Result<()> {
1591 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1592
1593 let a = Int32Array::from(vec![None, None]);
1594 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1595
1596 assert_eq!(do_sum(&batch)?, None);
1597
1598 Ok(())
1599 }
1600
1601 #[test]
max_i32_all_nulls() -> Result<()>1602 fn max_i32_all_nulls() -> Result<()> {
1603 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1604
1605 let a = Int32Array::from(vec![None, None]);
1606 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1607
1608 assert_eq!(do_max(&batch)?, None);
1609
1610 Ok(())
1611 }
1612
1613 #[test]
min_i32_all_nulls() -> Result<()>1614 fn min_i32_all_nulls() -> Result<()> {
1615 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1616
1617 let a = Int32Array::from(vec![None, None]);
1618 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1619
1620 assert_eq!(do_min(&batch)?, None);
1621
1622 Ok(())
1623 }
1624
1625 #[test]
avg_i32_all_nulls() -> Result<()>1626 fn avg_i32_all_nulls() -> Result<()> {
1627 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1628
1629 let a = Int32Array::from(vec![None, None]);
1630 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1631
1632 assert_eq!(do_avg(&batch)?, None);
1633
1634 Ok(())
1635 }
1636
1637 #[test]
sum_u32() -> Result<()>1638 fn sum_u32() -> Result<()> {
1639 let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1640
1641 let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
1642 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1643
1644 assert_eq!(do_sum(&batch)?, Some(ScalarValue::UInt64(15_u64)));
1645
1646 Ok(())
1647 }
1648
1649 #[test]
avg_u32() -> Result<()>1650 fn avg_u32() -> Result<()> {
1651 let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1652
1653 let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
1654 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1655
1656 assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64)));
1657
1658 Ok(())
1659 }
1660
1661 #[test]
max_u32() -> Result<()>1662 fn max_u32() -> Result<()> {
1663 let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1664
1665 let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
1666 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1667
1668 assert_eq!(do_max(&batch)?, Some(ScalarValue::UInt64(5_u64)));
1669
1670 Ok(())
1671 }
1672
1673 #[test]
min_u32() -> Result<()>1674 fn min_u32() -> Result<()> {
1675 let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
1676
1677 let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
1678 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1679
1680 assert_eq!(do_min(&batch)?, Some(ScalarValue::UInt64(1_u64)));
1681
1682 Ok(())
1683 }
1684
1685 #[test]
sum_f32() -> Result<()>1686 fn sum_f32() -> Result<()> {
1687 let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
1688
1689 let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
1690 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1691
1692 assert_eq!(do_sum(&batch)?, Some(ScalarValue::Float32(15_f32)));
1693
1694 Ok(())
1695 }
1696
1697 #[test]
avg_f32() -> Result<()>1698 fn avg_f32() -> Result<()> {
1699 let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
1700
1701 let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
1702 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1703
1704 assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64)));
1705
1706 Ok(())
1707 }
1708
1709 #[test]
max_f32() -> Result<()>1710 fn max_f32() -> Result<()> {
1711 let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
1712
1713 let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
1714 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1715
1716 assert_eq!(do_max(&batch)?, Some(ScalarValue::Float32(5_f32)));
1717
1718 Ok(())
1719 }
1720
1721 #[test]
min_f32() -> Result<()>1722 fn min_f32() -> Result<()> {
1723 let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
1724
1725 let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
1726 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1727
1728 assert_eq!(do_min(&batch)?, Some(ScalarValue::Float32(1_f32)));
1729
1730 Ok(())
1731 }
1732
1733 #[test]
sum_f64() -> Result<()>1734 fn sum_f64() -> Result<()> {
1735 let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
1736
1737 let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
1738 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1739
1740 assert_eq!(do_sum(&batch)?, Some(ScalarValue::Float64(15_f64)));
1741
1742 Ok(())
1743 }
1744
1745 #[test]
avg_f64() -> Result<()>1746 fn avg_f64() -> Result<()> {
1747 let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
1748
1749 let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
1750 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1751
1752 assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64)));
1753
1754 Ok(())
1755 }
1756
1757 #[test]
max_f64() -> Result<()>1758 fn max_f64() -> Result<()> {
1759 let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
1760
1761 let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
1762 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1763
1764 assert_eq!(do_max(&batch)?, Some(ScalarValue::Float64(5_f64)));
1765
1766 Ok(())
1767 }
1768
1769 #[test]
min_f64() -> Result<()>1770 fn min_f64() -> Result<()> {
1771 let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
1772
1773 let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
1774 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1775
1776 assert_eq!(do_min(&batch)?, Some(ScalarValue::Float64(1_f64)));
1777
1778 Ok(())
1779 }
1780
1781 #[test]
count_elements() -> Result<()>1782 fn count_elements() -> Result<()> {
1783 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1784 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1785 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1786 assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(5)));
1787 Ok(())
1788 }
1789
1790 #[test]
count_with_nulls() -> Result<()>1791 fn count_with_nulls() -> Result<()> {
1792 let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1793 let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), None]);
1794 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1795 assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(3)));
1796 Ok(())
1797 }
1798
1799 #[test]
count_all_nulls() -> Result<()>1800 fn count_all_nulls() -> Result<()> {
1801 let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
1802 let a = BooleanArray::from(vec![None, None, None, None, None, None, None, None]);
1803 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1804 assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0)));
1805 Ok(())
1806 }
1807
1808 #[test]
count_empty() -> Result<()>1809 fn count_empty() -> Result<()> {
1810 let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
1811 let a = BooleanArray::from(Vec::<bool>::new());
1812 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1813 assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0)));
1814 Ok(())
1815 }
1816
do_sum(batch: &RecordBatch) -> Result<Option<ScalarValue>>1817 fn do_sum(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
1818 let sum = sum(col(0, &batch.schema()));
1819 let accum = sum.create_accumulator();
1820 let input = sum.evaluate_input(batch)?;
1821 let mut accum = accum.borrow_mut();
1822 for i in 0..batch.num_rows() {
1823 accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
1824 }
1825 accum.get_value()
1826 }
1827
do_max(batch: &RecordBatch) -> Result<Option<ScalarValue>>1828 fn do_max(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
1829 let max = max(col(0, &batch.schema()));
1830 let accum = max.create_accumulator();
1831 let input = max.evaluate_input(batch)?;
1832 let mut accum = accum.borrow_mut();
1833 for i in 0..batch.num_rows() {
1834 accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
1835 }
1836 accum.get_value()
1837 }
1838
do_min(batch: &RecordBatch) -> Result<Option<ScalarValue>>1839 fn do_min(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
1840 let min = min(col(0, &batch.schema()));
1841 let accum = min.create_accumulator();
1842 let input = min.evaluate_input(batch)?;
1843 let mut accum = accum.borrow_mut();
1844 for i in 0..batch.num_rows() {
1845 accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
1846 }
1847 accum.get_value()
1848 }
1849
do_count(batch: &RecordBatch) -> Result<Option<ScalarValue>>1850 fn do_count(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
1851 let count = count(col(0, &batch.schema()));
1852 let accum = count.create_accumulator();
1853 let input = count.evaluate_input(batch)?;
1854 let mut accum = accum.borrow_mut();
1855 for i in 0..batch.num_rows() {
1856 accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
1857 }
1858 accum.get_value()
1859 }
1860
do_avg(batch: &RecordBatch) -> Result<Option<ScalarValue>>1861 fn do_avg(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
1862 let avg = avg(col(0, &batch.schema()));
1863 let accum = avg.create_accumulator();
1864 let input = avg.evaluate_input(batch)?;
1865 let mut accum = accum.borrow_mut();
1866 for i in 0..batch.num_rows() {
1867 accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
1868 }
1869 accum.get_value()
1870 }
1871
1872 #[test]
plus_op() -> Result<()>1873 fn plus_op() -> Result<()> {
1874 let schema = Schema::new(vec![
1875 Field::new("a", DataType::Int32, false),
1876 Field::new("b", DataType::Int32, false),
1877 ]);
1878 let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1879 let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
1880
1881 apply_arithmetic::<Int32Type>(
1882 Arc::new(schema),
1883 vec![Arc::new(a), Arc::new(b)],
1884 Operator::Plus,
1885 Int32Array::from(vec![2, 4, 7, 12, 21]),
1886 )?;
1887
1888 Ok(())
1889 }
1890
1891 #[test]
minus_op() -> Result<()>1892 fn minus_op() -> Result<()> {
1893 let schema = Arc::new(Schema::new(vec![
1894 Field::new("a", DataType::Int32, false),
1895 Field::new("b", DataType::Int32, false),
1896 ]));
1897 let a = Arc::new(Int32Array::from(vec![1, 2, 4, 8, 16]));
1898 let b = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1899
1900 apply_arithmetic::<Int32Type>(
1901 schema.clone(),
1902 vec![a.clone(), b.clone()],
1903 Operator::Minus,
1904 Int32Array::from(vec![0, 0, 1, 4, 11]),
1905 )?;
1906
1907 // should handle have negative values in result (for signed)
1908 apply_arithmetic::<Int32Type>(
1909 schema.clone(),
1910 vec![b.clone(), a.clone()],
1911 Operator::Minus,
1912 Int32Array::from(vec![0, 0, -1, -4, -11]),
1913 )?;
1914
1915 Ok(())
1916 }
1917
1918 #[test]
multiply_op() -> Result<()>1919 fn multiply_op() -> Result<()> {
1920 let schema = Arc::new(Schema::new(vec![
1921 Field::new("a", DataType::Int32, false),
1922 Field::new("b", DataType::Int32, false),
1923 ]));
1924 let a = Arc::new(Int32Array::from(vec![4, 8, 16, 32, 64]));
1925 let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
1926
1927 apply_arithmetic::<Int32Type>(
1928 schema,
1929 vec![a, b],
1930 Operator::Multiply,
1931 Int32Array::from(vec![8, 32, 128, 512, 2048]),
1932 )?;
1933
1934 Ok(())
1935 }
1936
1937 #[test]
divide_op() -> Result<()>1938 fn divide_op() -> Result<()> {
1939 let schema = Arc::new(Schema::new(vec![
1940 Field::new("a", DataType::Int32, false),
1941 Field::new("b", DataType::Int32, false),
1942 ]));
1943 let a = Arc::new(Int32Array::from(vec![8, 32, 128, 512, 2048]));
1944 let b = Arc::new(Int32Array::from(vec![2, 4, 8, 16, 32]));
1945
1946 apply_arithmetic::<Int32Type>(
1947 schema,
1948 vec![a, b],
1949 Operator::Divide,
1950 Int32Array::from(vec![4, 8, 16, 32, 64]),
1951 )?;
1952
1953 Ok(())
1954 }
1955
apply_arithmetic<T: ArrowNumericType>( schema: Arc<Schema>, data: Vec<ArrayRef>, op: Operator, expected: PrimitiveArray<T>, ) -> Result<()>1956 fn apply_arithmetic<T: ArrowNumericType>(
1957 schema: Arc<Schema>,
1958 data: Vec<ArrayRef>,
1959 op: Operator,
1960 expected: PrimitiveArray<T>,
1961 ) -> Result<()> {
1962 let arithmetic_op = binary(col(0, schema.as_ref()), op, col(1, schema.as_ref()));
1963 let batch = RecordBatch::try_new(schema, data)?;
1964 let result = arithmetic_op.evaluate(&batch)?;
1965
1966 assert_array_eq::<T>(expected, result);
1967
1968 Ok(())
1969 }
1970
assert_array_eq<T: ArrowNumericType>( expected: PrimitiveArray<T>, actual: ArrayRef, )1971 fn assert_array_eq<T: ArrowNumericType>(
1972 expected: PrimitiveArray<T>,
1973 actual: ArrayRef,
1974 ) {
1975 let actual = actual
1976 .as_any()
1977 .downcast_ref::<PrimitiveArray<T>>()
1978 .expect("Actual array should unwrap to type of expected array");
1979
1980 for i in 0..expected.len() {
1981 assert_eq!(expected.value(i), actual.value(i));
1982 }
1983 }
1984
1985 #[test]
neg_op() -> Result<()>1986 fn neg_op() -> Result<()> {
1987 let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
1988 let a = BooleanArray::from(vec![true, false]);
1989 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
1990
1991 // expression: "!a"
1992 let lt = not(col(0, &schema));
1993 let result = lt.evaluate(&batch)?;
1994 assert_eq!(result.len(), 2);
1995
1996 let expected = vec![false, true];
1997 let result = result
1998 .as_any()
1999 .downcast_ref::<BooleanArray>()
2000 .expect("failed to downcast to BooleanArray");
2001 for i in 0..2 {
2002 assert_eq!(result.value(i), expected[i]);
2003 }
2004
2005 Ok(())
2006 }
2007 }
2008