1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 //! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)`
19
20 use std::any::Any;
21 use std::convert::TryFrom;
22 use std::fmt::Debug;
23 use std::hash::Hash;
24 use std::sync::Arc;
25
26 use arrow::datatypes::{DataType, Field};
27
28 use ahash::RandomState;
29 use std::collections::HashSet;
30
31 use crate::error::{DataFusionError, Result};
32 use crate::physical_plan::group_scalar::GroupByScalar;
33 use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
34 use crate::scalar::ScalarValue;
35
36 #[derive(Debug, PartialEq, Eq, Hash, Clone)]
37 struct DistinctScalarValues(Vec<GroupByScalar>);
38
/// Builds the name of an intermediate state field as `"{name}[{state_name}]"`.
///
/// For example, `format_state_name("c", "count distinct")` yields
/// `"c[count distinct]"`.
fn format_state_name(name: &str, state_name: &str) -> String {
    [name, "[", state_name, "]"].concat()
}
42
/// Expression for a COUNT(DISTINCT) aggregation.
///
/// Accepts one or more arguments; the result counts distinct combinations
/// of argument values, ignoring any row in which an argument is NULL.
#[derive(Debug)]
pub struct DistinctCount {
    /// Output column name; also used as the prefix of the state field names.
    name: String,
    /// The DataType for the final count (evaluation only supports `UInt64`).
    data_type: DataType,
    /// The DataType for each input argument; one state field is produced per entry.
    input_data_types: Vec<DataType>,
    /// The input arguments
    exprs: Vec<Arc<dyn PhysicalExpr>>,
}
55
56 impl DistinctCount {
57 /// Create a new COUNT(DISTINCT) aggregate function.
new( input_data_types: Vec<DataType>, exprs: Vec<Arc<dyn PhysicalExpr>>, name: String, data_type: DataType, ) -> Self58 pub fn new(
59 input_data_types: Vec<DataType>,
60 exprs: Vec<Arc<dyn PhysicalExpr>>,
61 name: String,
62 data_type: DataType,
63 ) -> Self {
64 Self {
65 input_data_types,
66 exprs,
67 name,
68 data_type,
69 }
70 }
71 }
72
73 impl AggregateExpr for DistinctCount {
74 /// Return a reference to Any that can be used for downcasting
as_any(&self) -> &dyn Any75 fn as_any(&self) -> &dyn Any {
76 self
77 }
78
field(&self) -> Result<Field>79 fn field(&self) -> Result<Field> {
80 Ok(Field::new(&self.name, self.data_type.clone(), true))
81 }
82
state_fields(&self) -> Result<Vec<Field>>83 fn state_fields(&self) -> Result<Vec<Field>> {
84 Ok(self
85 .input_data_types
86 .iter()
87 .map(|data_type| {
88 Field::new(
89 &format_state_name(&self.name, "count distinct"),
90 DataType::List(Box::new(Field::new("item", data_type.clone(), true))),
91 false,
92 )
93 })
94 .collect::<Vec<_>>())
95 }
96
expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>97 fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
98 self.exprs.clone()
99 }
100
create_accumulator(&self) -> Result<Box<dyn Accumulator>>101 fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
102 Ok(Box::new(DistinctCountAccumulator {
103 values: HashSet::default(),
104 data_types: self.input_data_types.clone(),
105 count_data_type: self.data_type.clone(),
106 }))
107 }
108 }
109
/// Accumulator for COUNT(DISTINCT): collects the distinct argument rows
/// (rows with no NULL argument) seen so far; the result is their number.
#[derive(Debug)]
struct DistinctCountAccumulator {
    /// Distinct rows seen so far, hashed with `ahash`'s `RandomState`.
    values: HashSet<DistinctScalarValues, RandomState>,
    /// Data type of each input argument; `state` emits one list per entry.
    data_types: Vec<DataType>,
    /// Data type of the result; `evaluate` only accepts `DataType::UInt64`.
    count_data_type: DataType,
}
116
117 impl Accumulator for DistinctCountAccumulator {
update(&mut self, values: &[ScalarValue]) -> Result<()>118 fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
119 // If a row has a NULL, it is not included in the final count.
120 if !values.iter().any(|v| v.is_null()) {
121 self.values.insert(DistinctScalarValues(
122 values
123 .iter()
124 .map(GroupByScalar::try_from)
125 .collect::<Result<Vec<_>>>()?,
126 ));
127 }
128
129 Ok(())
130 }
131
merge(&mut self, states: &[ScalarValue]) -> Result<()>132 fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
133 if states.is_empty() {
134 return Ok(());
135 }
136
137 let col_values = states
138 .iter()
139 .map(|state| match state {
140 ScalarValue::List(Some(values), _) => Ok(values),
141 _ => Err(DataFusionError::Internal(format!(
142 "Unexpected accumulator state {:?}",
143 state
144 ))),
145 })
146 .collect::<Result<Vec<_>>>()?;
147
148 (0..col_values[0].len()).try_for_each(|row_index| {
149 let row_values = col_values
150 .iter()
151 .map(|col| col[row_index].clone())
152 .collect::<Vec<_>>();
153 self.update(&row_values)
154 })
155 }
156
state(&self) -> Result<Vec<ScalarValue>>157 fn state(&self) -> Result<Vec<ScalarValue>> {
158 let mut cols_out = self
159 .data_types
160 .iter()
161 .map(|data_type| ScalarValue::List(Some(Vec::new()), data_type.clone()))
162 .collect::<Vec<_>>();
163
164 let mut cols_vec = cols_out
165 .iter_mut()
166 .map(|c| match c {
167 ScalarValue::List(Some(ref mut v), _) => v,
168 _ => unreachable!(),
169 })
170 .collect::<Vec<_>>();
171
172 self.values.iter().for_each(|distinct_values| {
173 distinct_values.0.iter().enumerate().for_each(
174 |(col_index, distinct_value)| {
175 cols_vec[col_index].push(ScalarValue::from(distinct_value));
176 },
177 )
178 });
179
180 Ok(cols_out)
181 }
182
evaluate(&self) -> Result<ScalarValue>183 fn evaluate(&self) -> Result<ScalarValue> {
184 match &self.count_data_type {
185 DataType::UInt64 => Ok(ScalarValue::UInt64(Some(self.values.len() as u64))),
186 t => Err(DataFusionError::Internal(format!(
187 "Invalid data type {:?} for count distinct aggregation",
188 t
189 ))),
190 }
191 }
192 }
193
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::ArrayRef;
    use arrow::array::{
        Int16Array, Int32Array, Int64Array, Int8Array, ListArray, UInt16Array,
        UInt32Array, UInt64Array, UInt8Array,
    };
    use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
    use arrow::datatypes::DataType;

    /// Builds a `ListArray` (as `Result<ArrayRef>`) from an iterable of
    /// `Option<Vec<Option<T>>>`.
    ///
    /// `Some(values)` appends a valid list whose items are `values`, with
    /// `None` items appended as nulls. `None` appends a null list.
    /// `$BUILDER_TYPE` is the primitive builder used for the list items.
    macro_rules! build_list {
        ($LISTS:expr, $BUILDER_TYPE:ident) => {{
            let mut builder = ListBuilder::new($BUILDER_TYPE::new(0));
            for list in $LISTS.iter() {
                match list {
                    Some(values) => {
                        for value in values.iter() {
                            match value {
                                Some(v) => builder.values().append_value((*v).into())?,
                                None => builder.values().append_null()?,
                            }
                        }

                        builder.append(true)?;
                    }
                    None => {
                        builder.append(false)?;
                    }
                }
            }

            let array = Arc::new(builder.finish()) as ArrayRef;

            Ok(array) as Result<ArrayRef>
        }};
    }

    /// Extracts the primitive values of a `ScalarValue::List` state as
    /// `Option<Vec<Option<$PRIM_TY>>>`.
    ///
    /// Evaluates to `None` for a null list. Panics if `$LIST` is not a list,
    /// if its element type is not `DataType::$DATA_TYPE`, or if an element is
    /// a different `ScalarValue` variant.
    macro_rules! state_to_vec {
        ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
            match $LIST {
                ScalarValue::List(_, data_type) => match data_type {
                    DataType::$DATA_TYPE => (),
                    _ => panic!("Unexpected DataType for list"),
                },
                _ => panic!("Expected a ScalarValue::List"),
            }

            match $LIST {
                ScalarValue::List(None, _) => None,
                ScalarValue::List(Some(scalar_values), _) => {
                    let vec = scalar_values
                        .iter()
                        .map(|scalar_value| match scalar_value {
                            ScalarValue::$DATA_TYPE(value) => *value,
                            _ => panic!("Unexpected ScalarValue variant"),
                        })
                        .collect::<Vec<Option<$PRIM_TY>>>();

                    Some(vec)
                }
                _ => unreachable!(),
            }
        }};
    }

    /// Zips two state columns into `(col1, col2)` row pairs and sorts them.
    /// The accumulator emits rows in hash-set order, so sorting makes the
    /// assertions deterministic.
    fn collect_states<T: Ord + Clone, S: Ord + Clone>(
        state1: &[Option<T>],
        state2: &[Option<S>],
    ) -> Vec<(Option<T>, Option<S>)> {
        let mut states = state1
            .iter()
            .zip(state2.iter())
            .map(|(l, r)| (l.clone(), r.clone()))
            .collect::<Vec<(Option<T>, Option<S>)>>();
        states.sort();
        states
    }

    /// Runs `update_batch` over `arrays` on a fresh accumulator and returns
    /// its `(state, evaluate)` output.
    fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)> {
        let agg = DistinctCount::new(
            arrays
                .iter()
                .map(|a| a.data_type().clone())
                .collect::<Vec<_>>(),
            vec![],
            String::from("__col_name__"),
            DataType::UInt64,
        );

        let mut accum = agg.create_accumulator()?;
        accum.update_batch(arrays)?;

        Ok((accum.state()?, accum.evaluate()?))
    }

    /// Feeds `rows` one at a time through `update` on a fresh accumulator and
    /// returns its `(state, evaluate)` output.
    fn run_update(
        data_types: &[DataType],
        rows: &[Vec<ScalarValue>],
    ) -> Result<(Vec<ScalarValue>, ScalarValue)> {
        let agg = DistinctCount::new(
            data_types.to_vec(),
            vec![],
            String::from("__col_name__"),
            DataType::UInt64,
        );

        let mut accum = agg.create_accumulator()?;

        for row in rows.iter() {
            accum.update(row)?
        }

        Ok((accum.state()?, accum.evaluate()?))
    }

    /// Runs `merge_batch` with the given `ListArray` states on a fresh
    /// accumulator and returns its `(state, evaluate)` output. The argument
    /// types are taken from the lists' value types.
    fn run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)> {
        let agg = DistinctCount::new(
            arrays
                .iter()
                .map(|a| a.as_any().downcast_ref::<ListArray>().unwrap())
                .map(|a| a.values().data_type().clone())
                .collect::<Vec<_>>(),
            vec![],
            String::from("__col_name__"),
            DataType::UInt64,
        );

        let mut accum = agg.create_accumulator()?;
        accum.merge_batch(arrays)?;

        Ok((accum.state()?, accum.evaluate()?))
    }

    /// Single-column `update_batch` test for a primitive array type. The
    /// input has duplicates and NULLs; the expected distinct values are 1, 2, 3.
    macro_rules! test_count_distinct_update_batch_numeric {
        ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
            let values: Vec<Option<$PRIM_TYPE>> = vec![
                Some(1),
                Some(1),
                None,
                Some(3),
                Some(2),
                None,
                Some(2),
                Some(3),
                Some(1),
            ];

            let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];

            let (states, result) = run_update_batch(&arrays)?;

            // Check the number of state columns before indexing into them.
            assert_eq!(states.len(), 1);

            let mut state_vec =
                state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap();
            state_vec.sort();

            assert_eq!(state_vec, vec![Some(1), Some(2), Some(3)]);
            assert_eq!(result, ScalarValue::UInt64(Some(3)));

            Ok(())
        }};
    }

    #[test]
    fn count_distinct_update_batch_i8() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int8Array, Int8, i8)
    }

    #[test]
    fn count_distinct_update_batch_i16() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int16Array, Int16, i16)
    }

    #[test]
    fn count_distinct_update_batch_i32() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int32Array, Int32, i32)
    }

    #[test]
    fn count_distinct_update_batch_i64() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int64Array, Int64, i64)
    }

    #[test]
    fn count_distinct_update_batch_u8() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt8Array, UInt8, u8)
    }

    #[test]
    fn count_distinct_update_batch_u16() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt16Array, UInt16, u16)
    }

    #[test]
    fn count_distinct_update_batch_u32() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt32Array, UInt32, u32)
    }

    #[test]
    fn count_distinct_update_batch_u64() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt64Array, UInt64, u64)
    }

    #[test]
    fn count_distinct_update_batch_all_nulls() -> Result<()> {
        let arrays = vec![Arc::new(Int32Array::from(
            vec![None, None, None, None] as Vec<Option<i32>>
        )) as ArrayRef];

        let (states, result) = run_update_batch(&arrays)?;

        assert_eq!(states.len(), 1);
        assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![]));
        assert_eq!(result, ScalarValue::UInt64(Some(0)));

        Ok(())
    }

    #[test]
    fn count_distinct_update_batch_empty() -> Result<()> {
        let arrays =
            vec![Arc::new(Int32Array::from(vec![] as Vec<Option<i32>>)) as ArrayRef];

        let (states, result) = run_update_batch(&arrays)?;

        assert_eq!(states.len(), 1);
        assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![]));
        assert_eq!(result, ScalarValue::UInt64(Some(0)));

        Ok(())
    }

    #[test]
    fn count_distinct_update_batch_multiple_columns() -> Result<()> {
        let array_int8: ArrayRef = Arc::new(Int8Array::from(vec![1, 1, 2]));
        let array_int16: ArrayRef = Arc::new(Int16Array::from(vec![3, 3, 4]));
        let arrays = vec![array_int8, array_int16];

        let (states, result) = run_update_batch(&arrays)?;

        // Check the number of state columns before indexing into them.
        assert_eq!(states.len(), 2);

        let state_vec1 = state_to_vec!(&states[0], Int8, i8).unwrap();
        let state_vec2 = state_to_vec!(&states[1], Int16, i16).unwrap();
        let state_pairs = collect_states::<i8, i16>(&state_vec1, &state_vec2);

        assert_eq!(
            state_pairs,
            vec![(Some(1_i8), Some(3_i16)), (Some(2_i8), Some(4_i16))]
        );

        assert_eq!(result, ScalarValue::UInt64(Some(2)));

        Ok(())
    }

    #[test]
    fn count_distinct_update() -> Result<()> {
        let (states, result) = run_update(
            &[DataType::Int32, DataType::UInt64],
            &[
                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
                vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))],
                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
                vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))],
                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(6))],
                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(7))],
                vec![ScalarValue::Int32(Some(2)), ScalarValue::UInt64(Some(7))],
            ],
        )?;

        // Check the number of state columns before indexing into them.
        assert_eq!(states.len(), 2);

        let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
        let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
        let state_pairs = collect_states::<i32, u64>(&state_vec1, &state_vec2);

        assert_eq!(
            state_pairs,
            vec![
                (Some(-1_i32), Some(5_u64)),
                (Some(-1_i32), Some(6_u64)),
                (Some(-1_i32), Some(7_u64)),
                (Some(2_i32), Some(7_u64)),
                (Some(5_i32), Some(1_u64)),
            ]
        );
        assert_eq!(result, ScalarValue::UInt64(Some(5)));

        Ok(())
    }

    #[test]
    fn count_distinct_update_with_nulls() -> Result<()> {
        let (states, result) = run_update(
            &[DataType::Int32, DataType::UInt64],
            &[
                // None of these updates contains a None, so these are accumulated.
                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
                vec![ScalarValue::Int32(Some(-2)), ScalarValue::UInt64(Some(5))],
                // Each of these updates contains at least one None, so these
                // won't be accumulated.
                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(None)],
                vec![ScalarValue::Int32(None), ScalarValue::UInt64(Some(5))],
                vec![ScalarValue::Int32(None), ScalarValue::UInt64(None)],
            ],
        )?;

        // Check the number of state columns before indexing into them.
        assert_eq!(states.len(), 2);

        let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
        let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
        let state_pairs = collect_states::<i32, u64>(&state_vec1, &state_vec2);

        assert_eq!(
            state_pairs,
            vec![(Some(-2_i32), Some(5_u64)), (Some(-1_i32), Some(5_u64))]
        );

        assert_eq!(result, ScalarValue::UInt64(Some(2)));

        Ok(())
    }

    #[test]
    fn count_distinct_merge_batch() -> Result<()> {
        let state_in1 = build_list!(
            vec![
                Some(vec![Some(-1_i32), Some(-1_i32), Some(-2_i32), Some(-2_i32)]),
                Some(vec![Some(-2_i32), Some(-3_i32)]),
            ],
            Int32Builder
        )?;

        let state_in2 = build_list!(
            vec![
                Some(vec![Some(5_u64), Some(6_u64), Some(5_u64), Some(7_u64)]),
                Some(vec![Some(5_u64), Some(7_u64)]),
            ],
            UInt64Builder
        )?;

        let (states, result) = run_merge_batch(&[state_in1, state_in2])?;

        // One output state column per merged input column.
        assert_eq!(states.len(), 2);

        let state_out_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
        let state_out_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
        let state_pairs = collect_states::<i32, u64>(&state_out_vec1, &state_out_vec2);

        assert_eq!(
            state_pairs,
            vec![
                (Some(-3_i32), Some(7_u64)),
                (Some(-2_i32), Some(5_u64)),
                (Some(-2_i32), Some(7_u64)),
                (Some(-1_i32), Some(5_u64)),
                (Some(-1_i32), Some(6_u64)),
            ]
        );

        assert_eq!(result, ScalarValue::UInt64(Some(5)));

        Ok(())
    }
}
558