1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 //! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)`
19 
20 use std::any::Any;
21 use std::convert::TryFrom;
22 use std::fmt::Debug;
23 use std::hash::Hash;
24 use std::sync::Arc;
25 
26 use arrow::datatypes::{DataType, Field};
27 
28 use ahash::RandomState;
29 use std::collections::HashSet;
30 
31 use crate::error::{DataFusionError, Result};
32 use crate::physical_plan::group_scalar::GroupByScalar;
33 use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
34 use crate::scalar::ScalarValue;
35 
/// One row of argument values, stored as the element type of the hash set
/// held by `DistinctCountAccumulator`.
///
/// The derived `Eq` and `Hash` impls are what let a whole row be
/// de-duplicated by that set.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct DistinctScalarValues(Vec<GroupByScalar>);
38 
/// Builds the name of an intermediate state field as `name[state_name]`.
fn format_state_name(name: &str, state_name: &str) -> String {
    [name, "[", state_name, "]"].concat()
}
42 
/// Expression for a COUNT(DISTINCT) aggregation.
///
/// This type only carries the configuration; the counting itself is done by
/// the accumulator returned from `create_accumulator`.
#[derive(Debug)]
pub struct DistinctCount {
    /// Name of the output field; also the prefix of the state field names
    name: String,
    /// The DataType for the final count
    data_type: DataType,
    /// The DataType for each input argument
    input_data_types: Vec<DataType>,
    /// The input arguments
    exprs: Vec<Arc<dyn PhysicalExpr>>,
}
55 
56 impl DistinctCount {
57     /// Create a new COUNT(DISTINCT) aggregate function.
new( input_data_types: Vec<DataType>, exprs: Vec<Arc<dyn PhysicalExpr>>, name: String, data_type: DataType, ) -> Self58     pub fn new(
59         input_data_types: Vec<DataType>,
60         exprs: Vec<Arc<dyn PhysicalExpr>>,
61         name: String,
62         data_type: DataType,
63     ) -> Self {
64         Self {
65             input_data_types,
66             exprs,
67             name,
68             data_type,
69         }
70     }
71 }
72 
73 impl AggregateExpr for DistinctCount {
74     /// Return a reference to Any that can be used for downcasting
as_any(&self) -> &dyn Any75     fn as_any(&self) -> &dyn Any {
76         self
77     }
78 
field(&self) -> Result<Field>79     fn field(&self) -> Result<Field> {
80         Ok(Field::new(&self.name, self.data_type.clone(), true))
81     }
82 
state_fields(&self) -> Result<Vec<Field>>83     fn state_fields(&self) -> Result<Vec<Field>> {
84         Ok(self
85             .input_data_types
86             .iter()
87             .map(|data_type| {
88                 Field::new(
89                     &format_state_name(&self.name, "count distinct"),
90                     DataType::List(Box::new(Field::new("item", data_type.clone(), true))),
91                     false,
92                 )
93             })
94             .collect::<Vec<_>>())
95     }
96 
expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>97     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
98         self.exprs.clone()
99     }
100 
create_accumulator(&self) -> Result<Box<dyn Accumulator>>101     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
102         Ok(Box::new(DistinctCountAccumulator {
103             values: HashSet::default(),
104             data_types: self.input_data_types.clone(),
105             count_data_type: self.data_type.clone(),
106         }))
107     }
108 }
109 
/// Accumulator for COUNT(DISTINCT): keeps the set of distinct rows seen so far.
#[derive(Debug)]
struct DistinctCountAccumulator {
    /// Distinct rows accumulated so far; `update` never inserts a row that
    /// contains a NULL
    values: HashSet<DistinctScalarValues, RandomState>,
    /// Data type of each input column; used to type the lists built by `state`
    data_types: Vec<DataType>,
    /// Requested type of the final count; `evaluate` only accepts `UInt64`
    count_data_type: DataType,
}
116 
117 impl Accumulator for DistinctCountAccumulator {
update(&mut self, values: &[ScalarValue]) -> Result<()>118     fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
119         // If a row has a NULL, it is not included in the final count.
120         if !values.iter().any(|v| v.is_null()) {
121             self.values.insert(DistinctScalarValues(
122                 values
123                     .iter()
124                     .map(GroupByScalar::try_from)
125                     .collect::<Result<Vec<_>>>()?,
126             ));
127         }
128 
129         Ok(())
130     }
131 
merge(&mut self, states: &[ScalarValue]) -> Result<()>132     fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
133         if states.is_empty() {
134             return Ok(());
135         }
136 
137         let col_values = states
138             .iter()
139             .map(|state| match state {
140                 ScalarValue::List(Some(values), _) => Ok(values),
141                 _ => Err(DataFusionError::Internal(format!(
142                     "Unexpected accumulator state {:?}",
143                     state
144                 ))),
145             })
146             .collect::<Result<Vec<_>>>()?;
147 
148         (0..col_values[0].len()).try_for_each(|row_index| {
149             let row_values = col_values
150                 .iter()
151                 .map(|col| col[row_index].clone())
152                 .collect::<Vec<_>>();
153             self.update(&row_values)
154         })
155     }
156 
state(&self) -> Result<Vec<ScalarValue>>157     fn state(&self) -> Result<Vec<ScalarValue>> {
158         let mut cols_out = self
159             .data_types
160             .iter()
161             .map(|data_type| ScalarValue::List(Some(Vec::new()), data_type.clone()))
162             .collect::<Vec<_>>();
163 
164         let mut cols_vec = cols_out
165             .iter_mut()
166             .map(|c| match c {
167                 ScalarValue::List(Some(ref mut v), _) => v,
168                 _ => unreachable!(),
169             })
170             .collect::<Vec<_>>();
171 
172         self.values.iter().for_each(|distinct_values| {
173             distinct_values.0.iter().enumerate().for_each(
174                 |(col_index, distinct_value)| {
175                     cols_vec[col_index].push(ScalarValue::from(distinct_value));
176                 },
177             )
178         });
179 
180         Ok(cols_out)
181     }
182 
evaluate(&self) -> Result<ScalarValue>183     fn evaluate(&self) -> Result<ScalarValue> {
184         match &self.count_data_type {
185             DataType::UInt64 => Ok(ScalarValue::UInt64(Some(self.values.len() as u64))),
186             t => Err(DataFusionError::Internal(format!(
187                 "Invalid data type {:?} for count distinct aggregation",
188                 t
189             ))),
190         }
191     }
192 }
193 
194 #[cfg(test)]
195 mod tests {
196     use super::*;
197 
198     use arrow::array::ArrayRef;
199     use arrow::array::{
200         Int16Array, Int32Array, Int64Array, Int8Array, ListArray, UInt16Array,
201         UInt32Array, UInt64Array, UInt8Array,
202     };
203     use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
204     use arrow::datatypes::DataType;
205 
    // Builds a list array from `$LISTS`, a collection of
    // `Option<collection of Option<value>>`, using `$BUILDER_TYPE` as the
    // builder for the list items.
    //
    // An outer `None` is appended as an invalid (null) list entry; an inner
    // `None` is appended as a null item. The expansion evaluates to a
    // `Result<ArrayRef>` and uses `?` internally, so it can only be invoked
    // inside a function that returns a compatible `Result`.
    macro_rules! build_list {
        ($LISTS:expr, $BUILDER_TYPE:ident) => {{
            let mut builder = ListBuilder::new($BUILDER_TYPE::new(0));
            for list in $LISTS.iter() {
                match list {
                    Some(values) => {
                        for value in values.iter() {
                            match value {
                                Some(v) => builder.values().append_value((*v).into())?,
                                None => builder.values().append_null()?,
                            }
                        }

                        // Finish the current list entry and mark it valid.
                        builder.append(true)?;
                    }
                    None => {
                        // No items were pushed: record a null list entry.
                        builder.append(false)?;
                    }
                }
            }

            let array = Arc::new(builder.finish()) as ArrayRef;

            Ok(array) as Result<ArrayRef>
        }};
    }
232 
    // Unpacks a `ScalarValue::List` state into `Option<Vec<Option<$PRIM_TY>>>`.
    //
    // Panics if `$LIST` is not a `ScalarValue::List`, if the list's data type
    // is not `DataType::$DATA_TYPE`, or if any element is not a
    // `ScalarValue::$DATA_TYPE`. A `List(None, _)` yields `None`.
    // `$LIST` is expanded twice, so it should be a side-effect-free expression.
    macro_rules! state_to_vec {
        ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
            // First pass: check the list's declared item type.
            match $LIST {
                ScalarValue::List(_, data_type) => match data_type {
                    DataType::$DATA_TYPE => (),
                    _ => panic!("Unexpected DataType for list"),
                },
                _ => panic!("Expected a ScalarValue::List"),
            }

            // Second pass: extract the primitive values.
            match $LIST {
                ScalarValue::List(None, _) => None,
                ScalarValue::List(Some(scalar_values), _) => {
                    let vec = scalar_values
                        .iter()
                        .map(|scalar_value| match scalar_value {
                            ScalarValue::$DATA_TYPE(value) => *value,
                            _ => panic!("Unexpected ScalarValue variant"),
                        })
                        .collect::<Vec<Option<$PRIM_TY>>>();

                    Some(vec)
                }
                // Non-list values already panicked in the first pass.
                _ => unreachable!(),
            }
        }};
    }
260 
collect_states<T: Ord + Clone, S: Ord + Clone>( state1: &[Option<T>], state2: &[Option<S>], ) -> Vec<(Option<T>, Option<S>)>261     fn collect_states<T: Ord + Clone, S: Ord + Clone>(
262         state1: &[Option<T>],
263         state2: &[Option<S>],
264     ) -> Vec<(Option<T>, Option<S>)> {
265         let mut states = state1
266             .iter()
267             .zip(state2.iter())
268             .map(|(l, r)| (l.clone(), r.clone()))
269             .collect::<Vec<(Option<T>, Option<S>)>>();
270         states.sort();
271         states
272     }
273 
run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)>274     fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)> {
275         let agg = DistinctCount::new(
276             arrays
277                 .iter()
278                 .map(|a| a.data_type().clone())
279                 .collect::<Vec<_>>(),
280             vec![],
281             String::from("__col_name__"),
282             DataType::UInt64,
283         );
284 
285         let mut accum = agg.create_accumulator()?;
286         accum.update_batch(arrays)?;
287 
288         Ok((accum.state()?, accum.evaluate()?))
289     }
290 
run_update( data_types: &[DataType], rows: &[Vec<ScalarValue>], ) -> Result<(Vec<ScalarValue>, ScalarValue)>291     fn run_update(
292         data_types: &[DataType],
293         rows: &[Vec<ScalarValue>],
294     ) -> Result<(Vec<ScalarValue>, ScalarValue)> {
295         let agg = DistinctCount::new(
296             data_types.to_vec(),
297             vec![],
298             String::from("__col_name__"),
299             DataType::UInt64,
300         );
301 
302         let mut accum = agg.create_accumulator()?;
303 
304         for row in rows.iter() {
305             accum.update(row)?
306         }
307 
308         Ok((accum.state()?, accum.evaluate()?))
309     }
310 
run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)>311     fn run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)> {
312         let agg = DistinctCount::new(
313             arrays
314                 .iter()
315                 .map(|a| a.as_any().downcast_ref::<ListArray>().unwrap())
316                 .map(|a| a.values().data_type().clone())
317                 .collect::<Vec<_>>(),
318             vec![],
319             String::from("__col_name__"),
320             DataType::UInt64,
321         );
322 
323         let mut accum = agg.create_accumulator()?;
324         accum.merge_batch(arrays)?;
325 
326         Ok((accum.state()?, accum.evaluate()?))
327     }
328 
    // Shared body for the per-type `update_batch` tests.
    //
    // Builds a single `$ARRAY_TYPE` column containing the values 1, 2 and 3
    // (with repeats) plus two nulls, runs it through `run_update_batch`, and
    // checks that the state holds exactly {1, 2, 3} and the count is 3.
    // `$DATA_TYPE` is the matching `DataType`/`ScalarValue` variant name and
    // `$PRIM_TYPE` the Rust primitive. Expands to `Ok(())` and uses `?`, so
    // it must be the body of a function returning `Result<()>`.
    macro_rules! test_count_distinct_update_batch_numeric {
        ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
            let values: Vec<Option<$PRIM_TYPE>> = vec![
                Some(1),
                Some(1),
                None,
                Some(3),
                Some(2),
                None,
                Some(2),
                Some(3),
                Some(1),
            ];

            let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];

            let (states, result) = run_update_batch(&arrays)?;

            // The state comes out of a hash set, so sort before comparing.
            let mut state_vec =
                state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap();
            state_vec.sort();

            assert_eq!(states.len(), 1);
            assert_eq!(state_vec, vec![Some(1), Some(2), Some(3)]);
            assert_eq!(result, ScalarValue::UInt64(Some(3)));

            Ok(())
        }};
    }
358 
    /// Runs the shared numeric `update_batch` scenario on an `Int8Array`.
    #[test]
    fn count_distinct_update_batch_i8() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int8Array, Int8, i8)
    }
363 
    /// Runs the shared numeric `update_batch` scenario on an `Int16Array`.
    #[test]
    fn count_distinct_update_batch_i16() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int16Array, Int16, i16)
    }
368 
    /// Runs the shared numeric `update_batch` scenario on an `Int32Array`.
    #[test]
    fn count_distinct_update_batch_i32() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int32Array, Int32, i32)
    }
373 
    /// Runs the shared numeric `update_batch` scenario on an `Int64Array`.
    #[test]
    fn count_distinct_update_batch_i64() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int64Array, Int64, i64)
    }
378 
    /// Runs the shared numeric `update_batch` scenario on a `UInt8Array`.
    #[test]
    fn count_distinct_update_batch_u8() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt8Array, UInt8, u8)
    }
383 
    /// Runs the shared numeric `update_batch` scenario on a `UInt16Array`.
    #[test]
    fn count_distinct_update_batch_u16() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt16Array, UInt16, u16)
    }
388 
    /// Runs the shared numeric `update_batch` scenario on a `UInt32Array`.
    #[test]
    fn count_distinct_update_batch_u32() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt32Array, UInt32, u32)
    }
393 
    /// Runs the shared numeric `update_batch` scenario on a `UInt64Array`.
    #[test]
    fn count_distinct_update_batch_u64() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt64Array, UInt64, u64)
    }
398 
399     #[test]
count_distinct_update_batch_all_nulls() -> Result<()>400     fn count_distinct_update_batch_all_nulls() -> Result<()> {
401         let arrays = vec![Arc::new(Int32Array::from(
402             vec![None, None, None, None] as Vec<Option<i32>>
403         )) as ArrayRef];
404 
405         let (states, result) = run_update_batch(&arrays)?;
406 
407         assert_eq!(states.len(), 1);
408         assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![]));
409         assert_eq!(result, ScalarValue::UInt64(Some(0)));
410 
411         Ok(())
412     }
413 
414     #[test]
count_distinct_update_batch_empty() -> Result<()>415     fn count_distinct_update_batch_empty() -> Result<()> {
416         let arrays =
417             vec![Arc::new(Int32Array::from(vec![] as Vec<Option<i32>>)) as ArrayRef];
418 
419         let (states, result) = run_update_batch(&arrays)?;
420 
421         assert_eq!(states.len(), 1);
422         assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![]));
423         assert_eq!(result, ScalarValue::UInt64(Some(0)));
424 
425         Ok(())
426     }
427 
428     #[test]
count_distinct_update_batch_multiple_columns() -> Result<()>429     fn count_distinct_update_batch_multiple_columns() -> Result<()> {
430         let array_int8: ArrayRef = Arc::new(Int8Array::from(vec![1, 1, 2]));
431         let array_int16: ArrayRef = Arc::new(Int16Array::from(vec![3, 3, 4]));
432         let arrays = vec![array_int8, array_int16];
433 
434         let (states, result) = run_update_batch(&arrays)?;
435 
436         let state_vec1 = state_to_vec!(&states[0], Int8, i8).unwrap();
437         let state_vec2 = state_to_vec!(&states[1], Int16, i16).unwrap();
438         let state_pairs = collect_states::<i8, i16>(&state_vec1, &state_vec2);
439 
440         assert_eq!(states.len(), 2);
441         assert_eq!(
442             state_pairs,
443             vec![(Some(1_i8), Some(3_i16)), (Some(2_i8), Some(4_i16))]
444         );
445 
446         assert_eq!(result, ScalarValue::UInt64(Some(2)));
447 
448         Ok(())
449     }
450 
451     #[test]
count_distinct_update() -> Result<()>452     fn count_distinct_update() -> Result<()> {
453         let (states, result) = run_update(
454             &[DataType::Int32, DataType::UInt64],
455             &[
456                 vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
457                 vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))],
458                 vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
459                 vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))],
460                 vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(6))],
461                 vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(7))],
462                 vec![ScalarValue::Int32(Some(2)), ScalarValue::UInt64(Some(7))],
463             ],
464         )?;
465 
466         let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
467         let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
468         let state_pairs = collect_states::<i32, u64>(&state_vec1, &state_vec2);
469 
470         assert_eq!(states.len(), 2);
471         assert_eq!(
472             state_pairs,
473             vec![
474                 (Some(-1_i32), Some(5_u64)),
475                 (Some(-1_i32), Some(6_u64)),
476                 (Some(-1_i32), Some(7_u64)),
477                 (Some(2_i32), Some(7_u64)),
478                 (Some(5_i32), Some(1_u64)),
479             ]
480         );
481         assert_eq!(result, ScalarValue::UInt64(Some(5)));
482 
483         Ok(())
484     }
485 
486     #[test]
count_distinct_update_with_nulls() -> Result<()>487     fn count_distinct_update_with_nulls() -> Result<()> {
488         let (states, result) = run_update(
489             &[DataType::Int32, DataType::UInt64],
490             &[
491                 // None of these updates contains a None, so these are accumulated.
492                 vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
493                 vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
494                 vec![ScalarValue::Int32(Some(-2)), ScalarValue::UInt64(Some(5))],
495                 // Each of these updates contains at least one None, so these
496                 // won't be accumulated.
497                 vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(None)],
498                 vec![ScalarValue::Int32(None), ScalarValue::UInt64(Some(5))],
499                 vec![ScalarValue::Int32(None), ScalarValue::UInt64(None)],
500             ],
501         )?;
502 
503         let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
504         let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
505         let state_pairs = collect_states::<i32, u64>(&state_vec1, &state_vec2);
506 
507         assert_eq!(states.len(), 2);
508         assert_eq!(
509             state_pairs,
510             vec![(Some(-2_i32), Some(5_u64)), (Some(-1_i32), Some(5_u64))]
511         );
512 
513         assert_eq!(result, ScalarValue::UInt64(Some(2)));
514 
515         Ok(())
516     }
517 
518     #[test]
count_distinct_merge_batch() -> Result<()>519     fn count_distinct_merge_batch() -> Result<()> {
520         let state_in1 = build_list!(
521             vec![
522                 Some(vec![Some(-1_i32), Some(-1_i32), Some(-2_i32), Some(-2_i32)]),
523                 Some(vec![Some(-2_i32), Some(-3_i32)]),
524             ],
525             Int32Builder
526         )?;
527 
528         let state_in2 = build_list!(
529             vec![
530                 Some(vec![Some(5_u64), Some(6_u64), Some(5_u64), Some(7_u64)]),
531                 Some(vec![Some(5_u64), Some(7_u64)]),
532             ],
533             UInt64Builder
534         )?;
535 
536         let (states, result) = run_merge_batch(&[state_in1, state_in2])?;
537 
538         let state_out_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
539         let state_out_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
540         let state_pairs = collect_states::<i32, u64>(&state_out_vec1, &state_out_vec2);
541 
542         assert_eq!(
543             state_pairs,
544             vec![
545                 (Some(-3_i32), Some(7_u64)),
546                 (Some(-2_i32), Some(5_u64)),
547                 (Some(-2_i32), Some(7_u64)),
548                 (Some(-1_i32), Some(5_u64)),
549                 (Some(-1_i32), Some(6_u64)),
550             ]
551         );
552 
553         assert_eq!(result, ScalarValue::UInt64(Some(5)));
554 
555         Ok(())
556     }
557 }
558