1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18#' @include expression.R
19#' @include record-batch.R
20#' @include table.R
21
arrow_dplyr_query <- function(.data) {
  # Container pairing an Arrow data object (Table, RecordBatch, or Dataset)
  # with the accumulated state of the user's dplyr query: selected columns,
  # filters, group vars, and so on. The .data of an arrow_dplyr_query may
  # itself be another arrow_dplyr_query.

  # Read the grouping variables first, while .data is still the object the
  # caller passed in
  group_names <- dplyr::group_vars(.data) %||% character()

  # Anything that is not one of these classes gets wrapped in an
  # InMemoryDataset
  passthrough_classes <- c("Dataset", "arrow_dplyr_query", "RecordBatchReader")
  if (!inherits(.data, passthrough_classes)) {
    .data <- InMemoryDataset$create(.data)
  }

  # Evaluating expressions on a dataset with duplicated fieldnames will error,
  # so refuse such data up front
  field_names <- names(.data)
  is_repeat <- duplicated(field_names)
  if (any(is_repeat)) {
    abort(c(
      "Duplicated field names",
      x = paste0(
        "The following field names were found more than once in the data: ",
        oxford_paste(field_names[is_repeat])
      )
    ))
  }

  query <- list(
    .data = .data,
    # Named list of references/expressions pointing into the data; the names
    # are what the columns should be called in the end, so renames are
    # recorded here
    selected_columns = make_field_refs(names(.data$schema)),
    # TRUE until a filter is applied, after which it is an Expression
    filtered_rows = TRUE,
    # Character vector of (renamed) columns to group by; these are kept when
    # data is pulled into R
    group_by_vars = group_names,
    # Logical: drop groups formed by factor levels absent from the data?
    # Should be non-null only when the data is grouped
    drop_empty_groups = NULL,
    # List of sort expressions, named by their associated column names
    arrange_vars = list(),
    # One logical per entry of arrange_vars: FALSE ascending, TRUE descending
    arrange_desc = logical()
  )
  class(query) <- "arrow_dplyr_query"
  query
}
70
# `as_adq()` and `arrow_dplyr_query()` differ in exactly one way: when `.data`
# is already an `arrow_dplyr_query`, `as_adq()` hands it back untouched, while
# `arrow_dplyr_query()` nests it inside a new `arrow_dplyr_query`. Call
# `arrow_dplyr_query()` directly only from `collapse()` methods; everywhere
# else, use `as_adq()`.
as_adq <- function(.data) {
  # For most dplyr methods,
  # method.Table == method.RecordBatch == method.Dataset == method.arrow_dplyr_query
  # which holds because every one of them routes .data through this function.
  if (!inherits(.data, "arrow_dplyr_query")) {
    .data <- arrow_dplyr_query(.data)
  }
  .data
}
85
# Call Expression$field_ref on each entry of `field_names` and return the
# results as a list named by those same field names.
make_field_refs <- function(field_names) {
  refs <- lapply(field_names, Expression$field_ref)
  set_names(refs, field_names)
}
89
#' @export
print.arrow_dplyr_query <- function(x, ...) {
  source_schema <- x$.data$schema

  # Render one selected column as text. A non-empty field_name marks a plain
  # field reference, whose type is looked up in the source schema; otherwise
  # the expression's own type is shown, followed by the expression itself in
  # parentheses.
  describe_column <- function(expr) {
    field <- expr$field_name
    if (!nzchar(field)) {
      computed_type <- expr$type(source_schema)$ToString()
      return(paste0(computed_type, " (", expr$ToString(), ")"))
    }
    source_schema$GetFieldByName(field)$type$ToString()
  }
  col_types <- map_chr(x$selected_columns, describe_column)
  col_lines <- paste(names(col_types), col_types, sep = ": ", collapse = "\n")

  # Header: class of the underlying source, then one "name: type" line per
  # selected column, then a blank line
  cat(class(source_data(x))[1], " (query)\n", sep = "")
  cat(col_lines, "\n\n", sep = "")

  aggregations <- x$aggregations
  if (length(aggregations) > 0) {
    cat("* Aggregations:\n")
    agg_text <- map_chr(aggregations, format_aggregation)
    agg_lines <- paste0(names(aggregations), ": ", agg_text, collapse = "\n")
    cat(agg_lines, "\n", sep = "")
  }

  # filtered_rows stays TRUE until a filter expression replaces it
  if (!isTRUE(x$filtered_rows)) {
    cat("* Filter: ", x$filtered_rows$ToString(), "\n", sep = "")
  }

  if (length(x$group_by_vars) > 0) {
    group_text <- paste(x$group_by_vars, collapse = ", ")
    cat("* Grouped by ", group_text, "\n", sep = "")
  }

  if (length(x$arrange_vars) > 0) {
    sort_exprs <- map_chr(x$arrange_vars, function(expr) expr$ToString())
    sort_dirs <- ifelse(x$arrange_desc, "desc", "asc")
    sort_terms <- paste0(sort_exprs, " [", sort_dirs, "]")
    cat("* Sorted by ", paste(sort_terms, collapse = ", "), "\n", sep = "")
  }

  cat("See $.data for the source Arrow object\n")
  invisible(x)
}
140
# Names of the query's selected columns, i.e. the names after every
# select/rename has been applied, not the field names stored in Arrow
#' @export
names.arrow_dplyr_query <- function(x) {
  selected <- x$selected_columns
  names(selected)
}
144
#' @export
dim.arrow_dplyr_query <- function(x) {
  n_cols <- length(names(x))

  n_rows <- if (is_collapsed(x)) {
    # Don't evaluate a nested query just to learn the row count
    NA_integer_
  } else if (isTRUE(x$filtered_rows)) {
    # No filter recorded: take the row count straight from the source
    x$.data$num_rows
  } else {
    # A filter is present, so the rows have to be counted by a scan
    Scanner$create(x)$CountRows()
  }

  c(n_rows, n_cols)
}
159
#' @export
as.data.frame.arrow_dplyr_query <- function(x,
                                            row.names = NULL,
                                            optional = FALSE,
                                            ...) {
  # `row.names` and `optional` are accepted but not used here; only `x` and
  # `...` are forwarded to collect.arrow_dplyr_query().
  collect.arrow_dplyr_query(x, as_data_frame = TRUE, ...)
}
164
#' @export
head.arrow_dplyr_query <- function(x, n = 6L, ...) {
  # Store the requested row count in the query's `head` field, then return
  # whatever collapse.arrow_dplyr_query() makes of the updated query. `...` is
  # accepted but not used.
  limited <- x
  limited$head <- n
  collapse.arrow_dplyr_query(limited)
}
170
#' @export
tail.arrow_dplyr_query <- function(x, n = 6L, ...) {
  # Store the requested row count in the query's `tail` field, then return
  # whatever collapse.arrow_dplyr_query() makes of the updated query. `...` is
  # accepted but not used.
  limited <- x
  limited$tail <- n
  collapse.arrow_dplyr_query(limited)
}
176
#' @export
`[.arrow_dplyr_query` <- function(x, i, j, ..., drop = FALSE) {
  query <- ensure_group_vars(x)

  # Exactly two arguments means list-like column extraction, x[i]: treat the
  # single index as a column selection
  if (nargs() == 2L) {
    return(query[, i])
  }

  # Column index supplied: narrow the projection to those columns
  if (!missing(j)) {
    query <- select.arrow_dplyr_query(query, all_of(j))
  }

  # No row index: the (possibly narrowed) query is the result
  if (missing(i)) {
    return(query)
  }

  # Row index supplied: take those rows, then carry the dplyr state of the
  # query over onto the result
  taken <- take_dataset_rows(query, i)
  query <- restore_dplyr_features(taken, query)
  query
}
194
# Before pulling data from Arrow, make sure every group var of an
# arrow_dplyr_query is part of its projection. Any other input is returned
# unchanged.
ensure_group_vars <- function(x) {
  if (!inherits(x, "arrow_dplyr_query")) {
    return(x)
  }

  # Group vars that are not among the currently selected column names
  missing_groups <- set_names(setdiff(dplyr::group_vars(x), names(x)))
  if (length(missing_groups) == 0) {
    return(x)
  }

  # Add them back, after the columns that are already selected
  x$selected_columns <- c(
    x$selected_columns,
    make_field_refs(missing_groups)
  )
  x
}
209
ensure_arrange_vars <- function(x) {
  # arrange() is deferred until late in the query because:
  # - it has to run after mutate(), so that new columns can be sorted on;
  # - it ought to run after filter() and select(), for efficiency.
  # Users must still be able to arrange() by columns and expressions that are
  # *not* part of the query result, so those are *temporarily* added to the
  # projection via x$temp_columns and omitted again once the sort has been
  # performed. Columns in x$group_by_vars are different: they *are* returned
  # in the result.
  sort_names <- names(x$arrange_vars)
  already_selected <- sort_names %in% names(x$selected_columns)
  x$temp_columns <- x$arrange_vars[!already_selected]
  x
}
223
# Fallback for dplyr features that are not supported here:
# * when query_on_dataset(.data) is TRUE (a Dataset source), raise an error
# * otherwise (Table/RecordBatch), warn, collect() the data and rerun the call
#   with the dplyr function of the same name
abandon_ship <- function(call, .data, msg) {
  msg <- trimws(msg)
  # Keep only what precedes the first "." in the called function's name
  # (e.g. "filter.arrow_dplyr_query" becomes "filter")
  verb <- sub("^(.*?)\\..*", "\\1", as.character(call[[1]]))

  if (query_on_dataset(.data)) {
    stop(msg, "\nCall collect() first to pull data into R.", call. = FALSE)
  } else {
    warning(msg, "; pulling data into R", immediate. = TRUE, call. = FALSE)
  }

  # Swap in the collected data and the dplyr function looked up by name, then
  # evaluate the rewritten call two frames up from here
  call$.data <- dplyr::collect(.data)
  call[[1]] <- get(verb, envir = asNamespace("dplyr"))
  eval.parent(call, 2)
}
239
# TRUE unless the object returned by source_data(x) is an InMemoryDataset
query_on_dataset <- function(x) {
  origin <- source_data(x)
  !inherits(origin, "InMemoryDataset")
}
241
# Walk down through nested (collapsed) queries and return the .data of the
# innermost one, i.e. the first .data that is not itself an arrow_dplyr_query.
source_data <- function(x) {
  node <- x
  while (is_collapsed(node)) {
    node <- node$.data
  }
  node$.data
}
249
# A query counts as collapsed when its .data is itself an arrow_dplyr_query
is_collapsed <- function(x) {
  inner <- x$.data
  inherits(inner, "arrow_dplyr_query")
}
251
# Does this query, or any query nested inside it, carry aggregations?
has_aggregation <- function(x) {
  # TODO: update with joins (check right side data too)
  if (!is.null(x$aggregations)) {
    return(TRUE)
  }
  # Nothing at this level: only a nested query can still have some
  if (!is_collapsed(x)) {
    return(FALSE)
  }
  has_aggregation(x$.data)
}
256
# Does this query, or any query nested inside it, have `head` or `tail` set?
has_head_tail <- function(x) {
  if (!is.null(x$head)) {
    return(TRUE)
  }
  if (!is.null(x$tail)) {
    return(TRUE)
  }
  # Nothing at this level: only a nested query can still have one
  if (!is_collapsed(x)) {
    return(FALSE)
  }
  has_head_tail(x$.data)
}
260