# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


# The following S3 methods are registered on load if dplyr is present

# summarise() method for arrow_dplyr_query objects (also assigned below to
# Dataset and ArrowTabular).
#
# Arguments:
#   .data  the object to summarise; it is converted with as_adq()
#   ...    summarise expressions, forwarded to do_arrow_summarize()
#
# Returns the result of do_arrow_summarize(); if that raises an error, the
# result of abandon_ship() called with the original call, the (column-pruned)
# data and the formatted error.
summarise.arrow_dplyr_query <- function(.data, ...) {
  call <- match.call()
  .data <- as_adq(.data)
  exprs <- quos(...)
  # Only retain the columns we need to do our aggregations
  vars_to_keep <- unique(c(
    unlist(lapply(exprs, all.vars)), # vars referenced in summarise
    dplyr::group_vars(.data) # vars needed for grouping
  ))
  # If exprs rely on the results of previous exprs
  # (total = sum(x), mean = total / n())
  # then not all vars will correspond to columns in the data,
  # so don't try to select() them (use intersect() to exclude them)
  # Note that this select() isn't useful for the Arrow summarize implementation
  # because it will effectively project to keep what it needs anyway,
  # but the data.frame fallback version does benefit from select here
  .data <- dplyr::select(.data, intersect(vars_to_keep, names(.data)))

  # Try the Arrow implementation; on failure hand over to abandon_ship()
  out <- try(do_arrow_summarize(.data, ...), silent = TRUE)
  if (inherits(out, "try-error")) {
    return(abandon_ship(call, .data, format(out)))
  }
  out
}
summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query

# This is the Arrow summarize implementation
#
# Arguments:
#   .data    an arrow_dplyr_query-like object (its group_by_vars,
#            selected_columns and drop_empty_groups fields are read here)
#   ...      summarise expressions
#   .groups  NULL (treated as "drop_last") or one of "drop_last", "keep",
#            "drop"; "rowwise" is rejected as not supported. Only consulted
#            when .data has grouping variables.
#
# Returns the collapsed query object with aggregations (and any
# post-aggregation projections) applied.
do_arrow_summarize <- function(.data, ..., .groups = NULL) {
  exprs <- ensure_named_exprs(quos(...))

  # Create a stateful environment for recording our evaluated expressions
  # It's more complex than other places because a single summarize() expr
  # may result in multiple query nodes (Aggregate, Project),
  # and we have to walk through the expressions to disentangle them.
  ctx <- env(
    mask = arrow_mask(.data, aggregation = TRUE),
    aggregations = empty_named_list(),
    post_mutate = empty_named_list()
  )
  # Whether this is a grouped (hash) aggregation; same for every expression
  has_groups <- length(.data$group_by_vars) > 0
  for (i in seq_along(exprs)) {
    # Iterate over the indices and not the names because names may be repeated
    # (which overwrites the previous name)
    summarize_eval(
      names(exprs)[i],
      exprs[[i]],
      ctx,
      has_groups
    )
  }

  # Apply the results to the .data object.
  # First, the aggregations
  .data$aggregations <- ctx$aggregations
  # Then collapse the query so that the resulting query object can have
  # additional operations applied to it
  out <- collapse.arrow_dplyr_query(.data)
  # The expressions may have been translated into
  # "first, aggregate, then transform the result further"
  # nolint start
  # For example,
  #   summarize(mean = sum(x) / n())
  # is effectively implemented as
  #   summarize(..temp0 = sum(x), ..temp1 = n()) %>%
  #   mutate(mean = ..temp0 / ..temp1) %>%
  #   select(-starts_with("..temp"))
  # If this is the case, there will be expressions in post_mutate
  # nolint end
  if (length(ctx$post_mutate) > 0) {
    # Append post_mutate, and make sure order is correct
    # according to input exprs (also dropping ..temp columns)
    out$selected_columns <- c(
      out$selected_columns,
      ctx$post_mutate
    )[c(.data$group_by_vars, names(exprs))]
  }

  # If the object has .drop = FALSE and any group vars are dictionaries,
  # we can't (currently) preserve the empty rows that dplyr does,
  # so give a warning about that.
  if (!dplyr::group_by_drop_default(.data)) {
    group_by_exprs <- .data$selected_columns[.data$group_by_vars]
    if (any(map_lgl(group_by_exprs, ~ inherits(.$type(), "DictionaryType")))) {
      warning(
        ".drop = FALSE currently not supported in Arrow aggregation",
        call. = FALSE
      )
    }
  }

  # Handle .groups argument
  if (has_groups) {
    if (is.null(.groups)) {
      # dplyr docs say:
      # When `.groups` is not specified, it is chosen based on the
      # number of rows of the results:
      # - If all the results have 1 row, you get "drop_last".
      # - If the number of rows varies, you get "keep".
      #
      # But we don't support anything that returns multiple rows now
      .groups <- "drop_last"
    } else {
      assert_that(is.string(.groups))
    }
    if (.groups == "drop_last") {
      out$group_by_vars <- head(.data$group_by_vars, -1)
    } else if (.groups == "keep") {
      out$group_by_vars <- .data$group_by_vars
    } else if (.groups == "rowwise") {
      stop(arrow_not_supported('.groups = "rowwise"'))
    } else if (.groups == "drop") {
      # collapse() preserves groups so remove them
      out <- dplyr::ungroup(out)
    } else {
      stop(paste("Invalid .groups argument:", .groups))
    }
    # TODO: shouldn't we be doing something with `drop_empty_groups` in summarize? (ARROW-14044)
    out$drop_empty_groups <- .data$drop_empty_groups
  }
  out
}

# Evaluate `expr` with arrow_eval() against `mask`; if the result is a
# "try-error", raise it as an error whose message comes from
# handle_arrow_not_supported(). Otherwise return the evaluated result.
arrow_eval_or_stop <- function(expr, mask) {
  # TODO: change arrow_eval error handling behavior?
  out <- arrow_eval(expr, mask)
  if (inherits(out, "try-error")) {
    msg <- handle_arrow_not_supported(out, format_expr(expr))
    stop(msg, call. = FALSE)
  }
  out
}

# Returns a list made of the `data` element of each aggregation in
# .data$aggregations, followed by the selected columns that are group-by vars.
summarize_projection <- function(.data) {
  c(
    map(.data$aggregations, ~ .$data),
    .data$selected_columns[.data$group_by_vars]
  )
}

# Returns a string of the form "fun(data)" for an aggregation `x`, built from
# x$fun and x$data$ToString().
format_aggregation <- function(x) {
  paste0(x$fun, "(", x$data$ToString(), ")")
}

# This function handles each summarize expression and turns it into the
# appropriate combination of (1) aggregations (possibly temporary) and
# (2) post-aggregation transformations (mutate)
# The function returns nothing: it assigns into the `ctx` environment
#
# Arguments:
#   name     output column name for this expression
#   quosure  the summarize expression, as a quosure
#   ctx      environment with `mask`, `aggregations` and `post_mutate`;
#            `quo_env` is (re)assigned here
#   hash     TRUE when the aggregation is grouped (the caller passes whether
#            there are any group_by vars)
#   recurse  not referenced in the body; kept so the signature is unchanged
summarize_eval <- function(name, quosure, ctx, hash, recurse = FALSE) {
  expr <- quo_get_expr(quosure)
  ctx$quo_env <- quo_get_env(quosure)

  funs_in_expr <- all_funs(expr)
  if (length(funs_in_expr) == 0) {
    # If it is a scalar or field ref, no special handling required
    ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask)
    return()
  }

  # For the quantile() binding in the hash aggregation case, we need to mutate
  # the list output from the Arrow hash_tdigest kernel to flatten it into a
  # column of type float64. We do that by modifying the unevaluated expression
  # to replace quantile(...) with arrow_list_element(quantile(...), 0L)
  if (hash && "quantile" %in% funs_in_expr) {
    expr <- wrap_hash_quantile(expr)
    funs_in_expr <- all_funs(expr)
  }

  # Start inspecting the expr to see what aggregations it involves
  agg_funs <- names(agg_funcs)
  outer_agg <- funs_in_expr[1] %in% agg_funs
  inner_agg <- funs_in_expr[-1] %in% agg_funs

  # First, pull out any aggregations wrapped in other function calls
  if (any(inner_agg)) {
    expr <- extract_aggregations(expr, ctx)
  }

  # By this point, there are no more aggregation functions in expr
  # except for possibly the outer function call:
  # they've all been pulled out to ctx$aggregations, and in their place in expr
  # there are variable names, which will correspond to field refs in the
  # query object after aggregation and collapse().
  # So if we want to know if there are any aggregations inside expr,
  # we have to look for them by their new var names
  inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)

  if (outer_agg) {
    # This is something like agg(fun(x, y))
    # It just works by normal arrow_eval, unless there's a mix of aggs and
    # columns in the original data like agg(fun(x, agg(x)))
    # (but that will have been caught in extract_aggregations())
    ctx$aggregations[[name]] <- arrow_eval_or_stop(
      as_quosure(expr, ctx$quo_env),
      ctx$mask
    )
    return()
  } else if (all(inner_agg_exprs)) {
    # Something like: fun(agg(x), agg(y))
    # So based on the aggregations that have been extracted, mutate after
    mutate_mask <- arrow_mask(
      list(selected_columns = make_field_refs(names(ctx$aggregations)))
    )
    ctx$post_mutate[[name]] <- arrow_eval_or_stop(
      as_quosure(expr, ctx$quo_env),
      mutate_mask
    )
    return()
  }

  # Backstop for any other odd cases, like fun(x, y) (i.e. no aggregation),
  # or aggregation functions that aren't supported in Arrow (not in agg_funcs)
  stop(
    handle_arrow_not_supported(quo_get_expr(quosure), format_expr(quosure)),
    call. = FALSE
  )
}

# This function recurses through expr, pulls out any aggregation expressions,
# and inserts a variable name (field ref) in place of the aggregation
#
# Arguments:
#   expr  an unevaluated expression
#   ctx   the summarize context environment; extracted aggregations are
#         assigned into ctx$aggregations under "..temp<N>" names
#
# Returns the expression with aggregation calls replaced by symbols.
extract_aggregations <- function(expr, ctx) {
  # Keep the input in case we need to raise an error message with it
  original_expr <- expr
  funs <- all_funs(expr)
  if (length(funs) == 0) {
    return(expr)
  } else if (length(funs) > 1) {
    # Recurse more
    expr[-1] <- lapply(expr[-1], extract_aggregations, ctx)
  }
  if (funs[1] %in% names(agg_funcs)) {
    inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
    # any()/all() are length-1, so use the scalar, short-circuiting `&&`
    if (any(inner_agg_exprs) && !all(inner_agg_exprs)) {
      # We can't aggregate over a combination of dataset columns and other
      # aggregations (e.g. sum(x - mean(x)))
      # TODO: support in ARROW-13926
      # TODO: Add "because" arg to explain _why_ it's not supported?
      # TODO: this message could also say "not supported in summarize()"
      #       since some of these expressions may be legal elsewhere
      stop(
        handle_arrow_not_supported(original_expr, format_expr(original_expr)),
        call. = FALSE
      )
    }

    # We have an aggregation expression with no other aggregations inside it,
    # so arrow_eval the expression on the data and give it a ..temp name prefix,
    # then insert that name (symbol) back into the expression so that we can
    # mutate() on the result of the aggregation and reference this field.
    tmpname <- paste0("..temp", length(ctx$aggregations))
    ctx$aggregations[[tmpname]] <- arrow_eval_or_stop(
      as_quosure(expr, ctx$quo_env),
      ctx$mask
    )
    expr <- as.symbol(tmpname)
  }
  expr
}

# This function recurses through expr and wraps each call to quantile() with a
# call to arrow_list_element()
#
# The wrapping call is built directly with call() rather than by deparsing
# and re-parsing text, so the original quantile() call object is embedded
# unchanged. The head of the call is compared with identical() so that only
# the bare symbol `quantile` matches.
wrap_hash_quantile <- function(expr) {
  if (length(expr) == 1) {
    return(expr)
  }
  if (is.call(expr) && identical(expr[[1]], quote(quantile))) {
    return(call("arrow_list_element", expr, 0L))
  }
  as.call(lapply(expr, wrap_hash_quantile))
}