1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18
19# The following S3 methods are registered on load if dplyr is present
20
summarise.arrow_dplyr_query <- function(.data, ...) {
  # Capture the call first: if Arrow can't evaluate the summary, the call is
  # handed to abandon_ship() together with the (narrowed) data
  original_call <- match.call()
  .data <- as_adq(.data)
  summary_quos <- quos(...)

  # The aggregation only needs the columns referenced in the summary
  # expressions plus the grouping columns.
  referenced_vars <- unlist(lapply(summary_quos, all.vars))
  needed_vars <- unique(c(referenced_vars, dplyr::group_vars(.data)))
  # Expressions may build on earlier results (total = sum(x),
  # mean = total / n()), and those names are not columns of the data, so
  # only keep the ones that really exist (hence intersect()).
  # The Arrow implementation projects to what it needs anyway; it is the
  # data.frame fallback that benefits from this select().
  .data <- dplyr::select(.data, intersect(needed_vars, names(.data)))

  # Attempt the Arrow implementation; on failure, fall back via
  # abandon_ship() with the formatted error message
  result <- try(do_arrow_summarize(.data, ...), silent = TRUE)
  if (!inherits(result, "try-error")) {
    return(result)
  }
  abandon_ship(original_call, .data, format(result))
}
summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query
48
# This is the Arrow summarize implementation.
# Each summarize() expression is evaluated into Arrow aggregations (and, if
# needed, post-aggregation projections); the query is then collapsed and
# dplyr's `.groups` semantics are applied to the grouping of the result.
do_arrow_summarize <- function(.data, ..., .groups = NULL) {
  summary_exprs <- ensure_named_exprs(quos(...))
  expr_names <- names(summary_exprs)
  is_grouped <- length(.data$group_by_vars) > 0

  # Stateful environment that summarize_eval() records into. This is more
  # involved than elsewhere because one summarize() expression may need
  # several query nodes (Aggregate, then Project), so the expressions have to
  # be walked and disentangled.
  ctx <- env(
    mask = arrow_mask(.data, aggregation = TRUE),
    aggregations = empty_named_list(),
    post_mutate = empty_named_list()
  )
  # Walk by position rather than by name: names may be repeated (a later
  # expression with the same name overwrites the earlier one)
  for (idx in seq_along(summary_exprs)) {
    summarize_eval(expr_names[idx], summary_exprs[[idx]], ctx, is_grouped)
  }

  # Attach the aggregations, then collapse so that further verbs can be
  # applied to the resulting query object
  .data$aggregations <- ctx$aggregations
  out <- collapse.arrow_dplyr_query(.data)

  # Some expressions are translated into "aggregate first, then transform
  # the result further".
  # nolint start
  # For example,
  #   summarize(mean = sum(x) / n())
  # is effectively implemented as
  #   summarize(..temp0 = sum(x), ..temp1 = n()) %>%
  #   mutate(mean = ..temp0 / ..temp1) %>%
  #   select(-starts_with("..temp"))
  # Those transformations are what ends up in ctx$post_mutate.
  # nolint end
  if (length(ctx$post_mutate) > 0) {
    # Append the post-aggregation projections; subsetting by the grouping
    # vars and expression names both restores the input order and drops the
    # ..temp columns
    all_columns <- c(out$selected_columns, ctx$post_mutate)
    out$selected_columns <- all_columns[c(.data$group_by_vars, expr_names)]
  }

  # With .drop = FALSE and dictionary group columns, dplyr keeps empty
  # groups, which we can't (currently) reproduce, so warn about it
  if (!dplyr::group_by_drop_default(.data)) {
    group_exprs <- .data$selected_columns[.data$group_by_vars]
    is_dictionary <- map_lgl(
      group_exprs,
      ~ inherits(.$type(), "DictionaryType")
    )
    if (any(is_dictionary)) {
      warning(
        ".drop = FALSE currently not supported in Arrow aggregation",
        call. = FALSE
      )
    }
  }

  # `.groups` only affects grouped input
  if (!is_grouped) {
    return(out)
  }

  if (is.null(.groups)) {
    # dplyr picks "drop_last" when every result has 1 row and "keep" when the
    # number of rows varies; Arrow summarize only yields one row per group
    # for now, so this is always "drop_last"
    .groups <- "drop_last"
  } else {
    assert_that(is.string(.groups))
  }

  if (.groups == "drop_last") {
    out$group_by_vars <- head(.data$group_by_vars, -1)
  } else if (.groups == "keep") {
    out$group_by_vars <- .data$group_by_vars
  } else if (.groups == "rowwise") {
    stop(arrow_not_supported('.groups = "rowwise"'))
  } else if (.groups == "drop") {
    # collapse() preserves groups, so remove them explicitly
    out <- dplyr::ungroup(out)
  } else {
    stop(paste("Invalid .groups argument:", .groups))
  }
  # TODO: shouldn't we be doing something with `drop_empty_groups` in summarize? (ARROW-14044)
  out$drop_empty_groups <- .data$drop_empty_groups
  out
}
143
# Evaluate `expr` with arrow_eval() against `mask` and return the result.
# arrow_eval() reports failure by returning a "try-error" object; in that
# case an error is raised (without the call) whose message is built by
# handle_arrow_not_supported() from that object and the formatted expression.
arrow_eval_or_stop <- function(expr, mask) {
  # TODO: change arrow_eval error handling behavior?
  evaluated <- arrow_eval(expr, mask)
  if (!inherits(evaluated, "try-error")) {
    return(evaluated)
  }
  stop(
    handle_arrow_not_supported(evaluated, format_expr(expr)),
    call. = FALSE
  )
}
153
# Build the projection for a summarize query: the `data` element of each
# aggregation (keeping the aggregation names), followed by the selected
# expressions for the group-by variables.
summarize_projection <- function(.data) {
  agg_inputs <- lapply(.data$aggregations, function(agg) agg$data)
  group_exprs <- .data$selected_columns[.data$group_by_vars]
  c(agg_inputs, group_exprs)
}
160
# Render an aggregation as "fun(<data>)", using the ToString() method of the
# aggregation's `data` element for the argument text.
format_aggregation <- function(x) {
  arg_text <- x$data$ToString()
  paste0(x$fun, "(", arg_text, ")")
}
164
# This function handles each summarize expression and turns it into the
# appropriate combination of (1) aggregations (possibly temporary) and
# (2) post-aggregation transformations (mutate)
# The function returns nothing: it assigns into the `ctx` environment
#
# Arguments:
# * name: the output name of this expression; used as the key in
#   `ctx$aggregations` or `ctx$post_mutate`
# * quosure: the summarize() expression, as a quosure
# * ctx: the environment created in do_arrow_summarize(), holding `mask`,
#   `aggregations`, and `post_mutate`; this function also stores the
#   quosure's environment in `ctx$quo_env`
# * hash: do_arrow_summarize() passes `length(.data$group_by_vars) > 0`;
#   when TRUE, quantile() calls are rewritten (see below)
# * recurse: not referenced in this function body
#
# Raises an error (call. = FALSE) when the expression cannot be expressed as
# Arrow aggregations optionally followed by a post-aggregation projection.
summarize_eval <- function(name, quosure, ctx, hash, recurse = FALSE) {
  expr <- quo_get_expr(quosure)
  # Stored on ctx so that rebuilt quosures (here and in
  # extract_aggregations()) are evaluated in the original environment
  ctx$quo_env <- quo_get_env(quosure)

  funs_in_expr <- all_funs(expr)
  if (length(funs_in_expr) == 0) {
    # If it is a scalar or field ref, no special handling required
    ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask)
    return()
  }

  # For the quantile() binding in the hash aggregation case, we need to mutate
  # the list output from the Arrow hash_tdigest kernel to flatten it into a
  # column of type float64. We do that by modifying the unevaluated expression
  # to replace quantile(...) with arrow_list_element(quantile(...), 0L)
  # (funs_in_expr is recomputed because the rewrite adds a function call)
  if (hash && "quantile" %in% funs_in_expr) {
    expr <- wrap_hash_quantile(expr)
    funs_in_expr <- all_funs(expr)
  }

  # Start inspecting the expr to see what aggregations it involves:
  # funs_in_expr[1] is treated as the outermost call, the rest as nested calls
  agg_funs <- names(agg_funcs)
  outer_agg <- funs_in_expr[1] %in% agg_funs
  inner_agg <- funs_in_expr[-1] %in% agg_funs

  # First, pull out any aggregations wrapped in other function calls
  if (any(inner_agg)) {
    expr <- extract_aggregations(expr, ctx)
  }

  # By this point, there are no more aggregation functions in expr
  # except for possibly the outer function call:
  # they've all been pulled out to ctx$aggregations, and in their place in expr
  # there are variable names, which will correspond to field refs in the
  # query object after aggregation and collapse().
  # So if we want to know if there are any aggregations inside expr,
  # we have to look for them by their new var names
  inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)

  if (outer_agg) {
    # This is something like agg(fun(x, y))
    # It just works by normal arrow_eval, unless there's a mix of aggs and
    # columns in the original data like agg(fun(x, agg(x)))
    # (but that will have been caught in extract_aggregations())
    ctx$aggregations[[name]] <- arrow_eval_or_stop(
      as_quosure(expr, ctx$quo_env),
      ctx$mask
    )
    return()
  } else if (all(inner_agg_exprs)) {
    # Something like: fun(agg(x), agg(y))
    # So based on the aggregations that have been extracted, mutate after:
    # the mask exposes each aggregation (by name) as a field ref
    mutate_mask <- arrow_mask(
      list(selected_columns = make_field_refs(names(ctx$aggregations)))
    )
    ctx$post_mutate[[name]] <- arrow_eval_or_stop(
      as_quosure(expr, ctx$quo_env),
      mutate_mask
    )
    return()
  }

  # Backstop for any other odd cases, like fun(x, y) (i.e. no aggregation),
  # or aggregation functions that aren't supported in Arrow (not in agg_funcs)
  stop(
    handle_arrow_not_supported(quo_get_expr(quosure), format_expr(quosure)),
    call. = FALSE
  )
}
238
# This function recurses through expr, pulls out any aggregation expressions,
# and inserts a variable name (field ref) in place of the aggregation.
#
# Arguments:
# * expr: an unevaluated R expression (call, symbol, or constant)
# * ctx: the summarize context environment; `ctx$quo_env` and `ctx$mask` are
#   read, and each extracted aggregation is added to `ctx$aggregations`
#   under a "..tempN" name
#
# Returns `expr` with every aggregation call replaced by the symbol of its
# "..tempN" entry. Raises an error (call. = FALSE) when an aggregation mixes
# data columns with other aggregations, e.g. sum(x - mean(x)).
extract_aggregations <- function(expr, ctx) {
  # Keep the input in case we need to raise an error message with it
  original_expr <- expr
  funs <- all_funs(expr)
  if (length(funs) == 0) {
    # No function calls, so nothing to extract
    return(expr)
  } else if (length(funs) > 1) {
    # Recurse into the arguments first so that nested aggregations are
    # replaced by their ..temp symbols before this call is inspected
    expr[-1] <- lapply(expr[-1], extract_aggregations, ctx)
  }
  if (funs[1] %in% names(agg_funcs)) {
    inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
    # Both operands are scalars, so use the short-circuiting `&&`
    if (any(inner_agg_exprs) && !all(inner_agg_exprs)) {
      # We can't aggregate over a combination of dataset columns and other
      # aggregations (e.g. sum(x - mean(x)))
      # TODO: support in ARROW-13926
      # TODO: Add "because" arg to explain _why_ it's not supported?
      # TODO: this message could also say "not supported in summarize()"
      #       since some of these expressions may be legal elsewhere
      stop(
        handle_arrow_not_supported(original_expr, format_expr(original_expr)),
        call. = FALSE
      )
    }

    # We have an aggregation expression with no other aggregations inside it,
    # so arrow_eval the expression on the data and give it a ..temp name prefix,
    # then insert that name (symbol) back into the expression so that we can
    # mutate() on the result of the aggregation and reference this field.
    tmpname <- paste0("..temp", length(ctx$aggregations))
    ctx$aggregations[[tmpname]] <- arrow_eval_or_stop(
      as_quosure(expr, ctx$quo_env),
      ctx$mask
    )
    expr <- as.symbol(tmpname)
  }
  expr
}
276
# This function recurses through expr and wraps each call to quantile() with a
# call to arrow_list_element(<quantile call>, 0L).
#
# Only calls are traversed: symbols, constants, and any other non-call values
# (including multi-element vectors injected with `!!`) are returned unchanged.
# Calls nested inside a quantile() call are not rewritten.
# The wrapper is constructed directly with call() instead of a deparse/parse
# round trip, so inlined values are preserved exactly (deparse() would round
# doubles to 15 significant digits).
wrap_hash_quantile <- function(expr) {
  if (!is.call(expr)) {
    return(expr)
  }
  # identical() (rather than ==) also copes with non-symbol call heads,
  # such as an inlined function object
  if (identical(expr[[1]], quote(quantile))) {
    return(call("arrow_list_element", expr, 0L))
  }
  as.call(lapply(expr, wrap_hash_quantile))
}
290