1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18
19#' @include expression.R
20NULL
21
# Internal cache (an environment) for things including data mask functions.
# It starts out as NULL; init_env() swaps in a fresh hashed environment and is
# called once right here. We'll populate it at package load time.
.cache <- NULL
init_env <- function() {
  fresh_cache <- new.env(hash = TRUE)
  .cache <<- fresh_cache
}
init_env()
29
# nse_funcs is a list of functions that operate on (and return) Expressions
31# These will be the basis for a data_mask inside dplyr methods
32# and will be added to .cache at package load time
33
# Start with mappings from R function name spellings: one wrapper per name in
# .array_function_map, each forwarding its arguments to build_expr()
nse_funcs <- lapply(
  set_names(names(.array_function_map)),
  function(fun_name) {
    # force the promise so every closure captures its own function name
    force(fun_name)
    function(...) {
      build_expr(fun_name, ...)
    }
  }
)
39
40# Now add functions to that list where the mapping from R to Arrow isn't 1:1
41# Each of these functions should have the same signature as the R function
42# they're replacing.
43#
44# When to use `build_expr()` vs. `Expression$create()`?
45#
46# Use `build_expr()` if you need to
47# (1) map R function names to Arrow C++ functions
48# (2) wrap R inputs (vectors) as Array/Scalar
49#
50# `Expression$create()` is lower level. Most of the functions below use it
51# because they manage the preparation of the user-provided inputs
52# and don't need to wrap scalars
53
# Cast `x` to `target_type`. `safe` and `...` are forwarded to cast_options();
# the destination type is resolved with as_type() and stored in the options.
nse_funcs$cast <- function(x, target_type, safe = TRUE, ...) {
  cast_opts <- cast_options(safe, ...)
  cast_opts$to_type <- as_type(target_type)
  Expression$create("cast", x, options = cast_opts)
}
59
# Binding for coalesce(): builds a "coalesce" Expression from the supplied
# arguments. Aborts when no argument is supplied. Non-Expression arguments are
# wrapped with Expression$scalar().
nse_funcs$coalesce <- function(...) {
  args <- list2(...)
  n_args <- length(args)
  if (n_args < 1) {
    abort("At least one argument must be supplied to coalesce()")
  }

  # Treat NaN like NA for consistency with dplyr::coalesce(), but if *all*
  # the values are NaN, we should return NaN, not NA, so don't replace
  # NaN with NA in the final (or only) argument
  # TODO: if an option is added to the coalesce kernel to treat NaN as NA,
  # use that to simplify the code here (ARROW-13389)
  #
  # The position of each argument is tracked by index rather than by tagging
  # the argument with an attribute: attributes on environment-backed objects
  # are shared by reference, so tagging could mark more than one argument when
  # the same object is passed twice.
  prepare_arg <- function(arg, is_last) {
    if (!inherits(arg, "Expression")) {
      arg <- Expression$scalar(arg)
    }

    # coalesce doesn't yet support factors/dictionaries
    # TODO: remove this after ARROW-14167 is merged
    if (nse_funcs$is.factor(arg)) {
      warning("Dictionaries (in R: factors) are currently converted to strings (characters) in coalesce", call. = FALSE)
    }

    if (is_last || !(arg$type_id() %in% TYPES_WITH_NAN)) {
      return(arg)
    }

    # store the NA_real_ in the same type as arg to avoid casting
    # smaller float types to larger float types
    NA_expr <- Expression$scalar(Scalar$create(NA_real_, type = arg$type()))
    Expression$create("if_else", Expression$create("is_nan", arg), NA_expr, arg)
  }

  for (i in seq_len(n_args)) {
    args[[i]] <- prepare_arg(args[[i]], is_last = (i == n_args))
  }
  Expression$create("coalesce", args = args)
}
97
# Binding for is.na(): calls "is_null" with nan_is_null = TRUE so that NaN
# values are reported as missing as well.
nse_funcs$is.na <- function(x) {
  null_opts <- list(nan_is_null = TRUE)
  build_expr("is_null", x, options = null_opts)
}
101
# Binding for is.nan(): only doubles and Expressions whose type is listed in
# TYPES_WITH_NAN can hold NaN; anything else yields a scalar FALSE Expression.
nse_funcs$is.nan <- function(x) {
  can_hold_nan <- is.double(x) ||
    (inherits(x, "Expression") && x$type_id() %in% TYPES_WITH_NAN)
  if (!can_hold_nan) {
    return(Expression$scalar(FALSE))
  }
  # TODO: if an option is added to the is_nan kernel to treat NA as NaN,
  # use that to simplify the code here (ARROW-13366)
  build_expr("is_nan", x) & build_expr("is_valid", x)
}
112
# Binding for is(): `class2` is either a string (an R type name or an Arrow
# type name) or a DataType; anything else is an error.
nse_funcs$is <- function(object, class2) {
  if (!is.string(class2)) {
    if (inherits(class2, "DataType")) {
      return(object$type() == as_type(class2))
    }
    stop("Second argument to is() is not a string or DataType", call. = FALSE)
  }

  switch(class2,
    # for R data types, pass off to is.*() functions
    character = nse_funcs$is.character(object),
    numeric = nse_funcs$is.numeric(object),
    integer = nse_funcs$is.integer(object),
    integer64 = nse_funcs$is.integer64(object),
    logical = nse_funcs$is.logical(object),
    factor = nse_funcs$is.factor(object),
    list = nse_funcs$is.list(object),
    # for Arrow data types, compare class2 with object$type()$ToString(),
    # but first strip off any parameters to only compare the top-level data
    # type, and canonicalize class2
    {
      top_level_type <- sub("^([^([<]+).*$", "\\1", object$type()$ToString())
      top_level_type == canonical_type_str(class2)
    }
  )
}
136
# Dictionary-encode `x`. `null_encoding_behavior` is matched against
# "mask"/"encode", upper-cased, and looked up in NullEncodingBehavior.
nse_funcs$dictionary_encode <- function(x,
                                        null_encoding_behavior = c("mask", "encode")) {
  behavior_name <- toupper(match.arg(null_encoding_behavior))
  encode_opts <- list(
    null_encoding_behavior = NullEncodingBehavior[[behavior_name]]
  )
  Expression$create("dictionary_encode", x, options = encode_opts)
}
147
# Binding for between(): TRUE where `x` lies in the closed range [left, right].
nse_funcs$between <- function(x, left, right) {
  at_least_left <- x >= left
  at_most_right <- x <= right
  at_least_left & at_most_right
}
151
# Binding for is.finite(), built on the "is_finite" function.
nse_funcs$is.finite <- function(x) {
  finite_expr <- Expression$create("is_finite", x)
  # for compatibility with base::is.finite(), return FALSE for NA_real_
  result_present <- !nse_funcs$is.na(finite_expr)
  finite_expr & result_present
}
157
# Binding for is.infinite(), built on the "is_inf" function.
nse_funcs$is.infinite <- function(x) {
  infinite_expr <- Expression$create("is_inf", x)
  # for compatibility with base::is.infinite(), return FALSE for NA_real_
  result_present <- !nse_funcs$is.na(infinite_expr)
  infinite_expr & result_present
}
163
164# as.* type casting functions
165# as.factor() is mapped in expression.R
# Cast `x` to the Arrow string() type.
nse_funcs$as.character <- function(x) {
  to_string <- cast_options(to_type = string())
  Expression$create("cast", x, options = to_string)
}
# Cast `x` to the Arrow float64() type.
nse_funcs$as.double <- function(x) {
  to_float64 <- cast_options(to_type = float64())
  Expression$create("cast", x, options = to_float64)
}
# Cast `x` to the Arrow int32() type; float and decimal truncation are allowed.
nse_funcs$as.integer <- function(x) {
  to_int32 <- cast_options(
    to_type = int32(),
    allow_float_truncate = TRUE,
    allow_decimal_truncate = TRUE
  )
  Expression$create("cast", x, options = to_int32)
}
# Cast `x` to the Arrow int64() type; float and decimal truncation are allowed.
nse_funcs$as.integer64 <- function(x) {
  to_int64 <- cast_options(
    to_type = int64(),
    allow_float_truncate = TRUE,
    allow_decimal_truncate = TRUE
  )
  Expression$create("cast", x, options = to_int64)
}
# Cast `x` to the Arrow boolean() type.
nse_funcs$as.logical <- function(x) {
  to_boolean <- cast_options(to_type = boolean())
  Expression$create("cast", x, options = to_boolean)
}
# Cast `x` to the Arrow float64() type (same target type as as.double()).
nse_funcs$as.numeric <- function(x) {
  to_float64 <- cast_options(to_type = float64())
  Expression$create("cast", x, options = to_float64)
}
200
201# is.* type functions
# TRUE for R character vectors and for Expressions of type STRING or
# LARGE_STRING.
nse_funcs$is.character <- function(x) {
  if (is.character(x)) {
    return(TRUE)
  }
  inherits(x, "Expression") &&
    x$type_id() %in% Type[c("STRING", "LARGE_STRING")]
}
# TRUE for R numeric vectors and for Expressions whose type is one of the
# integer, floating point or decimal types listed below.
nse_funcs$is.numeric <- function(x) {
  if (is.numeric(x)) {
    return(TRUE)
  }
  inherits(x, "Expression") && x$type_id() %in% Type[c(
    "UINT8", "INT8", "UINT16", "INT16", "UINT32", "INT32",
    "UINT64", "INT64", "HALF_FLOAT", "FLOAT", "DOUBLE",
    "DECIMAL", "DECIMAL256"
  )]
}
# TRUE for R doubles and for Expressions of type DOUBLE.
nse_funcs$is.double <- function(x) {
  if (is.double(x)) {
    return(TRUE)
  }
  # `&&` keeps the result an unnamed scalar even though Type["DOUBLE"] is
  # subset with a name
  inherits(x, "Expression") && x$type_id() == Type["DOUBLE"]
}
# TRUE for R integers and for Expressions of any signed or unsigned integer
# type listed below.
nse_funcs$is.integer <- function(x) {
  if (is.integer(x)) {
    return(TRUE)
  }
  inherits(x, "Expression") && x$type_id() %in% Type[c(
    "UINT8", "INT8", "UINT16", "INT16", "UINT32", "INT32",
    "UINT64", "INT64"
  )]
}
# TRUE when is.integer64(x) holds and for Expressions of type INT64.
nse_funcs$is.integer64 <- function(x) {
  if (is.integer64(x)) {
    return(TRUE)
  }
  inherits(x, "Expression") && x$type_id() == Type["INT64"]
}
# TRUE for R logicals and for Expressions of type BOOL.
nse_funcs$is.logical <- function(x) {
  if (is.logical(x)) {
    return(TRUE)
  }
  inherits(x, "Expression") && x$type_id() == Type["BOOL"]
}
# TRUE for R factors and for Expressions of type DICTIONARY.
nse_funcs$is.factor <- function(x) {
  if (is.factor(x)) {
    return(TRUE)
  }
  inherits(x, "Expression") && x$type_id() == Type["DICTIONARY"]
}
# TRUE for R lists and for Expressions of type LIST, FIXED_SIZE_LIST or
# LARGE_LIST.
nse_funcs$is.list <- function(x) {
  if (is.list(x)) {
    return(TRUE)
  }
  inherits(x, "Expression") &&
    x$type_id() %in% Type[c("LIST", "FIXED_SIZE_LIST", "LARGE_LIST")]
}
236
237# rlang::is_* type functions
# rlang-style spelling of is.character(); only the default `n = NULL` is
# accepted (the assertion fails otherwise).
nse_funcs$is_character <- function(x, n = NULL) {
  assert_that(is.null(n))
  is_chr <- nse_funcs$is.character(x)
  is_chr
}
# rlang-style spelling of is.double(); only the defaults `n = NULL` and
# `finite = NULL` are accepted (the assertion fails otherwise).
nse_funcs$is_double <- function(x, n = NULL, finite = NULL) {
  assert_that(is.null(n) && is.null(finite))
  is_dbl <- nse_funcs$is.double(x)
  is_dbl
}
# rlang-style spelling of is.integer(); only the default `n = NULL` is
# accepted (the assertion fails otherwise).
nse_funcs$is_integer <- function(x, n = NULL) {
  assert_that(is.null(n))
  is_int <- nse_funcs$is.integer(x)
  is_int
}
# rlang-style spelling of is.list(); only the default `n = NULL` is
# accepted (the assertion fails otherwise).
nse_funcs$is_list <- function(x, n = NULL) {
  assert_that(is.null(n))
  is_lst <- nse_funcs$is.list(x)
  is_lst
}
# rlang-style spelling of is.logical(); only the default `n = NULL` is
# accepted (the assertion fails otherwise).
nse_funcs$is_logical <- function(x, n = NULL) {
  assert_that(is.null(n))
  is_lgl <- nse_funcs$is.logical(x)
  is_lgl
}
# TRUE for R objects inheriting from "POSIXt" and for Expressions of type
# TIMESTAMP; only the default `n = NULL` is accepted.
nse_funcs$is_timestamp <- function(x, n = NULL) {
  assert_that(is.null(n))
  if (inherits(x, "POSIXt")) {
    return(TRUE)
  }
  inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP")]
}
262
263# String functions
# Binding for nchar().
#
# `type = "bytes"` maps to "binary_length"; every other `type` maps to
# "utf8_length". `allowNA = TRUE` is not supported. `keepNA = NA` (the default)
# resolves to TRUE unless `type` is "width"; a `keepNA` that resolves to FALSE
# is not supported.
nse_funcs$nchar <- function(x, type = "chars", allowNA = FALSE, keepNA = NA) {
  if (allowNA) {
    arrow_not_supported("allowNA = TRUE")
  }
  if (is.na(keepNA)) {
    keepNA <- !identical(type, "width")
  }
  if (!keepNA) {
    # This branch is reached when keepNA is FALSE, so that is the value the
    # message must name (it previously said "keepNA = TRUE").
    # TODO: I think there is a fill_null kernel we could use, set null to 2
    arrow_not_supported("keepNA = FALSE")
  }
  if (identical(type, "bytes")) {
    Expression$create("binary_length", x)
  } else {
    Expression$create("utf8_length", x)
  }
}
281
# Binding for paste(): `collapse` is not supported and a literal NA `sep` is
# rejected. The join function is created with NullHandlingBehavior$REPLACE and
# the replacement "NA"; `sep` is passed to it as the last argument.
nse_funcs$paste <- function(..., sep = " ", collapse = NULL, recycle0 = FALSE) {
  assert_that(
    is.null(collapse),
    msg = "paste() with the collapse argument is not yet supported in Arrow"
  )
  sep_is_literal <- !inherits(sep, "Expression")
  if (sep_is_literal) {
    assert_that(!is.na(sep), msg = "Invalid separator")
  }
  join_strings <- arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")
  join_strings(..., sep)
}
292
# Binding for paste0(): same as the paste() binding but with an empty
# separator; `collapse` is not supported.
nse_funcs$paste0 <- function(..., collapse = NULL, recycle0 = FALSE) {
  assert_that(
    is.null(collapse),
    msg = "paste0() with the collapse argument is not yet supported in Arrow"
  )
  join_strings <- arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")
  join_strings(..., "")
}
300
# Binding for stringr::str_c(): the join function is created with
# NullHandlingBehavior$EMIT_NULL; `collapse` is not supported.
nse_funcs$str_c <- function(..., sep = "", collapse = NULL) {
  assert_that(
    is.null(collapse),
    msg = "str_c() with the collapse argument is not yet supported in Arrow"
  )
  join_strings <- arrow_string_join_function(NullHandlingBehavior$EMIT_NULL)
  join_strings(..., sep)
}
308
# Returns a function that joins its arguments with the
# `binary_join_element_wise` Arrow C++ compute kernel, using the given
# `null_handling` and `null_replacement` options.
arrow_string_join_function <- function(null_handling, null_replacement = NULL) {
  # handle scalar literal args, and cast all args to string for
  # consistency with base::paste(), base::paste0(), and stringr::str_c()
  as_string_expr <- function(arg) {
    if (inherits(arg, "Expression")) {
      return(nse_funcs$as.character(arg))
    }
    assert_that(
      length(arg) == 1,
      msg = "Literal vectors of length != 1 not supported in string concatenation"
    )
    Expression$scalar(as.character(arg))
  }

  # the kernel takes the separator as the last argument, so pass `sep` as the
  # last dots arg to the returned function
  function(...) {
    string_args <- lapply(list(...), as_string_expr)
    join_opts <- list(
      null_handling = null_handling,
      null_replacement = null_replacement
    )
    Expression$create(
      "binary_join_element_wise",
      args = string_args,
      options = join_opts
    )
  }
}
336
# Currently, Arrow does not support a locale option for string case conversion
# functions, in contrast to stringr's API, so the 'locale' argument is only
# valid for stringr's default value ("en"). The following are string functions
# that take a 'locale' option as their second argument:
341#   str_to_lower
342#   str_to_upper
343#   str_to_title
344#
345# Arrow locale will be supported with ARROW-14126
# Errors unless `locale` is exactly "en"; otherwise returns NULL invisibly.
stop_if_locale_provided <- function(locale) {
  if (identical(locale, "en")) {
    return(invisible(NULL))
  }
  stop(
    "Providing a value for 'locale' other than the default ('en') is not supported by Arrow. ",
    "To change locale, use 'Sys.setlocale()'",
    call. = FALSE
  )
}
354
# Binding for stringr::str_to_lower(); any `locale` other than "en" errors.
nse_funcs$str_to_lower <- function(string, locale = "en") {
  stop_if_locale_provided(locale)
  lowered <- Expression$create("utf8_lower", string)
  lowered
}
359
# Binding for stringr::str_to_upper(); any `locale` other than "en" errors.
nse_funcs$str_to_upper <- function(string, locale = "en") {
  stop_if_locale_provided(locale)
  uppered <- Expression$create("utf8_upper", string)
  uppered
}
364
# Binding for stringr::str_to_title(); any `locale` other than "en" errors.
nse_funcs$str_to_title <- function(string, locale = "en") {
  stop_if_locale_provided(locale)
  titled <- Expression$create("utf8_title", string)
  titled
}
369
# Binding for stringr::str_trim(): picks the whitespace-trimming function that
# corresponds to `side`.
nse_funcs$str_trim <- function(string, side = c("both", "left", "right")) {
  side <- match.arg(side)
  trim_kernels <- c(
    left = "utf8_ltrim_whitespace",
    right = "utf8_rtrim_whitespace",
    both = "utf8_trim_whitespace"
  )
  Expression$create(trim_kernels[[side]], string)
}
379
# Binding for base::substr(), built on "utf8_slice_codeunits". `start` and
# `stop` must each be of length 1.
nse_funcs$substr <- function(x, start, stop) {
  assert_that(
    length(start) == 1,
    msg = "`start` must be length 1 - other lengths are not supported in Arrow"
  )
  assert_that(
    length(stop) == 1,
    msg = "`stop` must be length 1 - other lengths are not supported in Arrow"
  )

  # substr treats values as if they're on a continuous number line, so values
  # 0 are effectively blank characters - set `start` to 1 here so Arrow mimics
  # this behavior
  if (start <= 0) {
    start <- 1
  }

  # if `stop` is lower than `start`, this is invalid, so set `stop` to
  # 0 so that an empty string will be returned (consistent with base::substr())
  if (stop < start) {
    stop <- 0
  }

  # we don't need to subtract 1 from `stop` as C++ counts exclusively
  # which effectively cancels out the difference in indexing between R & C++
  slice_opts <- list(start = start - 1L, stop = stop)
  Expression$create("utf8_slice_codeunits", x, options = slice_opts)
}
411
# Binding for base::substring(): forwards to the substr() binding, mapping
# `text`/`first`/`last` onto `x`/`start`/`stop`.
nse_funcs$substring <- function(text, first, last) {
  nse_funcs$substr(
    x = text,
    start = first,
    stop = last
  )
}
415
# Binding for stringr::str_sub(), built on "utf8_slice_codeunits". `start` and
# `end` must each be of length 1.
nse_funcs$str_sub <- function(string, start = 1L, end = -1L) {
  assert_that(
    length(start) == 1,
    msg = "`start` must be length 1 - other lengths are not supported in Arrow"
  )
  assert_that(
    length(end) == 1,
    msg = "`end` must be length 1 - other lengths are not supported in Arrow"
  )

  slice_stop <- end

  # In stringr::str_sub, an `end` value of -1 means the end of the string, so
  # set it to the maximum integer to match this behavior
  if (slice_stop == -1) {
    slice_stop <- .Machine$integer.max
  }

  # An end value lower than a start value returns an empty string in
  # stringr::str_sub so set end to 0 here to match this behavior
  # NOTE(review): an `end` below -1 combined with a positive `start` also
  # lands in this branch; confirm that an empty result is intended there.
  if (slice_stop < start) {
    slice_stop <- 0
  }

  # subtract 1 from `start` because C++ is 0-based and R is 1-based
  # str_sub treats a `start` value of 0 or 1 as the same thing so don't subtract 1 when `start` == 0
  # when `start` < 0, both str_sub and utf8_slice_codeunits count backwards from the end
  slice_start <- if (start > 0) start - 1L else start

  Expression$create(
    "utf8_slice_codeunits",
    string,
    options = list(start = slice_start, stop = slice_stop)
  )
}
451
# Binding for base::grepl(): uses "match_substring" when `fixed` is TRUE and
# "match_substring_regex" otherwise.
nse_funcs$grepl <- function(pattern, x, ignore.case = FALSE, fixed = FALSE) {
  match_kernel <- ifelse(fixed, "match_substring", "match_substring_regex")
  match_opts <- list(pattern = pattern, ignore_case = ignore.case)
  Expression$create(match_kernel, x, options = match_opts)
}
460
# Binding for stringr::str_detect(): resolves the pattern modifier, delegates
# to the grepl() binding, and negates the result when `negate` is TRUE.
nse_funcs$str_detect <- function(string, pattern, negate = FALSE) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  detected <- nse_funcs$grepl(
    pattern = opts$pattern,
    x = string,
    ignore.case = opts$ignore_case,
    fixed = opts$fixed
  )
  if (negate) !detected else detected
}
474
# Binding for stringr::str_like(), built on the "match_like" function.
nse_funcs$str_like <- function(string, pattern, ignore_case = TRUE) {
  like_opts <- list(pattern = pattern, ignore_case = ignore_case)
  Expression$create("match_like", string, options = like_opts)
}
482
# Encapsulate some common logic for sub/gsub/str_replace/str_replace_all.
# Returns a function with the base R (pattern, replacement, x) argument order
# that performs at most `max_replacements` replacements.
arrow_r_string_replace_function <- function(max_replacements) {
  function(pattern, replacement, x, ignore.case = FALSE, fixed = FALSE) {
    # the literal kernel is only used for case-sensitive fixed patterns
    replace_kernel <- ifelse(
      fixed && !ignore.case,
      "replace_substring",
      "replace_substring_regex"
    )
    replace_opts <- list(
      pattern = format_string_pattern(pattern, ignore.case, fixed),
      replacement = format_string_replacement(replacement, ignore.case, fixed),
      max_replacements = max_replacements
    )
    Expression$create(replace_kernel, x, options = replace_opts)
  }
}
497
# stringr flavour of the replace helper: returns a function with the
# (string, pattern, replacement) argument order that resolves the pattern
# modifier and delegates to the base R style replace function.
arrow_stringr_string_replace_function <- function(max_replacements) {
  function(string, pattern, replacement) {
    opts <- get_stringr_pattern_options(enexpr(pattern))
    replace_fun <- arrow_r_string_replace_function(max_replacements)
    replace_fun(
      pattern = opts$pattern,
      replacement = replacement,
      x = string,
      ignore.case = opts$ignore_case,
      fixed = opts$fixed
    )
  }
}
510
# max_replacements is 1L for the replace-first variants and -1L for the
# replace-all variants
nse_funcs[["sub"]] <- arrow_r_string_replace_function(1L)
nse_funcs[["gsub"]] <- arrow_r_string_replace_function(-1L)
nse_funcs[["str_replace"]] <- arrow_stringr_string_replace_function(1L)
nse_funcs[["str_replace_all"]] <- arrow_stringr_string_replace_function(-1L)
515
# Binding for base::strsplit(): uses "split_pattern" when `fixed` is TRUE and
# "split_pattern_regex" otherwise. `split` must be a single string.
nse_funcs$strsplit <- function(x,
                               split,
                               fixed = FALSE,
                               perl = FALSE,
                               useBytes = FALSE) {
  assert_that(is.string(split))

  split_kernel <- ifelse(fixed, "split_pattern", "split_pattern_regex")
  # warn when the user specifies both fixed = TRUE and perl = TRUE, for
  # consistency with the behavior of base::strsplit()
  if (fixed && perl) {
    warning("Argument 'perl = TRUE' will be ignored", call. = FALSE)
  }
  # since split is not a regex, proceed without any warnings or errors regardless
  # of the value of perl, for consistency with the behavior of base::strsplit()
  split_opts <- list(pattern = split, reverse = FALSE, max_splits = -1L)
  Expression$create(split_kernel, x, options = split_opts)
}
537
# Binding for stringr::str_split(). Case-insensitive splitting and `n = 0` are
# not supported; `simplify = TRUE` is ignored with a warning.
nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  split_kernel <- ifelse(opts$fixed, "split_pattern", "split_pattern_regex")
  if (opts$ignore_case) {
    arrow_not_supported("Case-insensitive string splitting")
  }
  if (n == 0) {
    arrow_not_supported("Splitting strings into zero parts")
  }
  if (simplify) {
    warning("Argument 'simplify = TRUE' will be ignored", call. = FALSE)
  }
  # The max_splits option in the Arrow C++ library controls the maximum number
  # of places at which the string is split, whereas the argument n to
  # str_split() controls the maximum number of pieces to return. So we must
  # subtract 1 from n to get max_splits. An infinite n becomes -1L.
  max_splits <- if (identical(n, Inf)) -1L else n - 1L
  Expression$create(
    split_kernel,
    string,
    options = list(
      pattern = opts$pattern,
      reverse = FALSE,
      max_splits = max_splits
    )
  )
}
567
# Binding for base::pmin(): `na.rm` is passed on as the skip_nulls option.
nse_funcs$pmin <- function(..., na.rm = FALSE) {
  build_expr("min_element_wise", ..., options = list(skip_nulls = na.rm))
}
575
# Binding for base::pmax(): `na.rm` is passed on as the skip_nulls option.
nse_funcs$pmax <- function(..., na.rm = FALSE) {
  build_expr("max_element_wise", ..., options = list(skip_nulls = na.rm))
}
583
# Binding for stringr::str_pad(): `width` must be integer-like and `pad` a
# single string; `side` selects the padding function.
nse_funcs$str_pad <- function(string, width, side = c("left", "right", "both"), pad = " ") {
  assert_that(is_integerish(width))
  side <- match.arg(side)
  assert_that(is.string(pad))

  pad_kernel <- switch(side,
    left = "utf8_lpad",
    right = "utf8_rpad",
    both = "utf8_center"
  )

  pad_opts <- list(width = width, padding = pad)
  Expression$create(pad_kernel, string, options = pad_opts)
}
603
# Binding for base::startsWith(), built on the "starts_with" function.
nse_funcs$startsWith <- function(x, prefix) {
  prefix_opts <- list(pattern = prefix)
  Expression$create("starts_with", x, options = prefix_opts)
}
611
# Binding for base::endsWith(), built on the "ends_with" function.
nse_funcs$endsWith <- function(x, suffix) {
  suffix_opts <- list(pattern = suffix)
  Expression$create("ends_with", x, options = suffix_opts)
}
619
# Binding for stringr::str_starts(): fixed patterns go through the
# startsWith() binding, regex patterns are anchored with "^" and go through
# the grepl() binding.
nse_funcs$str_starts <- function(string, pattern, negate = FALSE) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  starts <- if (opts$fixed) {
    nse_funcs$startsWith(x = string, prefix = opts$pattern)
  } else {
    nse_funcs$grepl(pattern = paste0("^", opts$pattern), x = string, fixed = FALSE)
  }
  if (negate) !starts else starts
}
633
# Binding for stringr::str_ends(): fixed patterns go through the endsWith()
# binding, regex patterns are anchored with "$" and go through the grepl()
# binding.
nse_funcs$str_ends <- function(string, pattern, negate = FALSE) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  ends <- if (opts$fixed) {
    nse_funcs$endsWith(x = string, suffix = opts$pattern)
  } else {
    nse_funcs$grepl(pattern = paste0(opts$pattern, "$"), x = string, fixed = FALSE)
  }
  if (negate) !ends else ends
}
647
# Binding for stringr::str_count(): uses "count_substring" for fixed patterns
# and "count_substring_regex" otherwise. The pattern must resolve to a single
# string; other values are reported as not supported.
nse_funcs$str_count <- function(string, pattern) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  # Validate the pattern that was already extracted from the captured
  # expression. Testing the raw `pattern` argument would evaluate it a second
  # time, outside get_stringr_pattern_options(), where modifier calls such as
  # fixed() or regex() are not guaranteed to resolve.
  if (!is.list(opts) || !is.string(opts$pattern)) {
    arrow_not_supported("`pattern` must be a length 1 character vector; other values")
  }
  count_kernel <- if (opts$fixed) "count_substring" else "count_substring_regex"
  Expression$create(
    count_kernel,
    string,
    options = list(pattern = opts$pattern, ignore_case = opts$ignore_case)
  )
}
660
661# String function helpers
662
# format `pattern` as needed for case insensitivity and literal matching by RE2
#
# Arrow lacks native support for case-insensitive literal string matching and
# replacement, so we use the regular expression engine (RE2) to do this.
# https://github.com/google/re2/wiki/Syntax
format_string_pattern <- function(pattern, ignore.case, fixed) {
  # case-sensitive matching needs no rewriting at all
  if (!ignore.case) {
    return(pattern)
  }
  if (fixed) {
    # Everything between "\Q" and "\E" is treated as literal text.
    # If the search text contains any literal "\E" strings, make them
    # lowercase so they won't signal the end of the literal text:
    without_quote_end <- gsub("\\E", "\\e", pattern, fixed = TRUE)
    pattern <- paste0("\\Q", without_quote_end, "\\E")
  }
  # Prepend "(?i)" for case-insensitive matching
  paste0("(?i)", pattern)
}
681
# format `replacement` as needed for literal replacement by RE2
#
# Arrow lacks native support for case-insensitive literal string
# replacement, so we use the regular expression engine (RE2) to do this.
# https://github.com/google/re2/wiki/Syntax
format_string_replacement <- function(replacement, ignore.case, fixed) {
  literal_via_regex <- ignore.case && fixed
  if (!literal_via_regex) {
    return(replacement)
  }
  # Escape single backslashes in the regex replacement text so they are
  # interpreted as literal backslashes:
  gsub("\\", "\\\\", replacement, fixed = TRUE)
}
694
#' Get `stringr` pattern options
#'
#' This function defines local stand-ins for the `stringr` pattern modifier
#' functions (`fixed()`, `regex()`, `coll()`, `boundary()`) and evaluates the
#' quoted expression `pattern` against them. The result is a list that is
#' used to control pattern matching behavior in internal `arrow` functions.
#' A bare character pattern is treated as a case-sensitive regular expression.
#'
#' @param pattern Unevaluated expression containing a call to a `stringr`
#' pattern modifier function
#'
#' @return List containing elements `pattern`, `fixed`, and `ignore_case`
#' @keywords internal
get_stringr_pattern_options <- function(pattern) {
  # warn about (and otherwise ignore) any extra modifier arguments
  check_dots <- function(...) {
    dots <- list(...)
    if (length(dots) > 0) {
      warning(
        "Ignoring pattern modifier ",
        ngettext(length(dots), "argument ", "arguments "),
        "not supported in Arrow: ",
        oxford_paste(names(dots)),
        call. = FALSE
      )
    }
  }
  fixed <- function(pattern, ignore_case = FALSE, ...) {
    check_dots(...)
    list(pattern = pattern, fixed = TRUE, ignore_case = ignore_case)
  }
  regex <- function(pattern, ignore_case = FALSE, ...) {
    check_dots(...)
    list(pattern = pattern, fixed = FALSE, ignore_case = ignore_case)
  }
  coll <- function(...) {
    arrow_not_supported("Pattern modifier `coll()`")
  }
  boundary <- function(...) {
    arrow_not_supported("Pattern modifier `boundary()`")
  }
  # a plain character pattern becomes a case-sensitive regex option list;
  # anything else is passed through untouched
  ensure_opts <- function(opts) {
    if (!is.character(opts)) {
      return(opts)
    }
    list(pattern = opts, fixed = FALSE, ignore_case = FALSE)
  }
  ensure_opts(eval(pattern))
}
742
#' Does this string contain regex metacharacters?
#'
#' @param string String to be tested
#' @keywords internal
#' @return Logical: does `string` contain regex metacharacters?
contains_regex <- function(string) {
  # bracket expression listing the characters treated as regex metacharacters
  metachar_class <- "[.\\|()[{^$*+?]"
  grepl(metachar_class, string)
}
751
# Binding for base::strptime(), built on the "strptime" function.
nse_funcs$strptime <- function(x, format = "%Y-%m-%d %H:%M:%S", tz = NULL, unit = "ms") {
  # Arrow uses unit for time parsing, strptime() does not.
  # Arrow has no default option for strptime (format, unit),
  # we suggest following format = "%Y-%m-%d %H:%M:%S", unit = MILLI/1L/"ms",
  # (ARROW-12809)

  # ParseTimestampStrptime currently ignores the timezone information (ARROW-12820).
  # Stop if tz is provided.
  if (is.character(tz)) {
    arrow_not_supported("Time zone argument")
  }

  accepted_units <- c(valid_time64_units, valid_time32_units)
  unit <- make_valid_time_unit(unit, accepted_units)

  parse_opts <- list(format = format, unit = unit)
  Expression$create("strptime", x, options = parse_opts)
}
768
# Binding for base::strftime(). `usetz = TRUE` appends "%Z" to the format and
# an empty `tz` falls back to Sys.timezone().
nse_funcs$strftime <- function(x, format = "", tz = "", usetz = FALSE) {
  if (usetz) {
    format <- paste(format, "%Z")
  }
  if (tz == "") {
    tz <- Sys.timezone()
  }
  # Arrow's strftime prints in timezone of the timestamp. To match R's strftime behavior we first
  # cast the timestamp to desired timezone. This is a metadata only change.
  ts <- x
  if (nse_funcs$is_timestamp(x)) {
    target_type <- timestamp(x$type()$unit(), tz)
    ts <- Expression$create("cast", x, options = list(to_type = target_type))
  }
  format_opts <- list(format = format, locale = Sys.getlocale("LC_TIME"))
  Expression$create("strftime", ts, options = format_opts)
}
785
# Binding for format_ISO8601(): maps `precision` to a strftime format string
# (default "ymdhms"), appends "%z" when `usetz` is TRUE, and formats with the
# "C" locale. `...` is accepted and ignored.
nse_funcs$format_ISO8601 <- function(x, usetz = FALSE, precision = NULL, ...) {
  ISO8601_precision_map <- c(
    y = "%Y",
    ym = "%Y-%m",
    ymd = "%Y-%m-%d",
    ymdh = "%Y-%m-%dT%H",
    ymdhm = "%Y-%m-%dT%H:%M",
    ymdhms = "%Y-%m-%dT%H:%M:%S"
  )
  known_precisions <- names(ISO8601_precision_map)

  if (is.null(precision)) {
    precision <- "ymdhms"
  }
  if (!precision %in% known_precisions) {
    abort(
      paste(
        "`precision` must be one of the following values:",
        paste(known_precisions, collapse = ", "),
        "\nValue supplied was: ",
        precision
      )
    )
  }

  format <- ISO8601_precision_map[[precision]]
  if (usetz) {
    format <- paste0(format, "%z")
  }
  Expression$create("strftime", x, options = list(format = format, locale = "C"))
}
816
# Binding for second(): adds the "second" and "subsecond" components of `x`.
nse_funcs$second <- function(x) {
  whole_seconds <- Expression$create("second", x)
  sub_seconds <- Expression$create("subsecond", x)
  Expression$create("add", whole_seconds, sub_seconds)
}
820
# Binding for base::trunc(); `...` is accepted and ignored for consistency
# with base::trunc().
nse_funcs$trunc <- function(x, ...) {
  truncated <- build_expr("trunc", x)
  truncated
}
825
# Binding for base::round(): `digits` is passed as ndigits and the rounding
# mode is RoundMode$HALF_TO_EVEN.
nse_funcs$round <- function(x, digits = 0) {
  round_opts <- list(ndigits = digits, round_mode = RoundMode$HALF_TO_EVEN)
  build_expr("round", x, options = round_opts)
}
833
# Binding for wday(). Without `label` it returns "day_of_week" counted from 1
# relative to `week_start`; with `label` it formats the day name through
# "strftime" ("%a" when `abbr` is TRUE, "%A" otherwise) in `locale`.
nse_funcs$wday <- function(x,
                           label = FALSE,
                           abbr = TRUE,
                           week_start = getOption("lubridate.week.start", 7),
                           locale = Sys.getlocale("LC_TIME")) {
  if (!label) {
    day_opts <- list(count_from_zero = FALSE, week_start = week_start)
    return(Expression$create("day_of_week", x, options = day_opts))
  }

  format <- if (abbr) "%a" else "%A"
  Expression$create("strftime", x, options = list(format = format, locale = locale))
}
850
# Binding for base::log() and base::logb(). Uses the dedicated "ln_checked",
# "log2_checked" and "log10_checked" functions for those scalar bases and
# "logb_checked" for everything else.
nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) {
  # like other binary functions, either `x` or `base` can be Expression or double(1)
  if (is.numeric(x) && length(x) == 1) {
    x <- Expression$scalar(x)
  } else if (!inherits(x, "Expression")) {
    arrow_not_supported("x must be a column or a length-1 numeric; other values")
  }

  # handle `base` differently because we use the simpler ln, log2, and log10
  # functions for specific scalar base values
  if (inherits(base, "Expression")) {
    return(Expression$create("logb_checked", x, base))
  }

  if (!is.numeric(base) || length(base) != 1) {
    arrow_not_supported("base must be a column or a length-1 numeric; other values")
  }

  dedicated_kernel <- if (base == exp(1)) {
    "ln_checked"
  } else if (base == 2) {
    "log2_checked"
  } else if (base == 10) {
    "log10_checked"
  } else {
    NULL
  }

  if (is.null(dedicated_kernel)) {
    Expression$create("logb_checked", x, Expression$scalar(base))
  } else {
    Expression$create(dedicated_kernel, x)
  }
}
883
# Binding for dplyr::if_else(). When `missing` is supplied, the result for a
# missing `condition` is `missing` and the ordinary if_else result is used
# everywhere else.
nse_funcs$if_else <- function(condition, true, false, missing = NULL) {
  if (!is.null(missing)) {
    when_present <- nse_funcs$if_else(condition, true, false)
    return(nse_funcs$if_else(
      nse_funcs$is.na(condition),
      missing,
      when_present
    ))
  }

  # if_else doesn't yet support factors/dictionaries
  # TODO: remove this after ARROW-13358 is merged
  has_factor_branch <- nse_funcs$is.factor(true) || nse_funcs$is.factor(false)
  if (has_factor_branch) {
    warning(
      "Dictionaries (in R: factors) are currently converted to strings (characters) ",
      "in if_else and ifelse",
      call. = FALSE
    )
  }

  build_expr("if_else", condition, true, false)
}
906
# Base R ifelse() allows `yes` and `no` to be different classes; this binding
# simply forwards its arguments to the if_else() binding above.
nse_funcs$ifelse <- function(test, yes, no) {
  # map the base R argument names onto the if_else() binding's names
  nse_funcs$if_else(
    condition = test,
    true = yes,
    false = no
  )
}
911
# Binding for dplyr::case_when(). Every argument must be a two-sided formula
# whose left side evaluates to a logical expression. The conditions are packed
# into a struct with "make_struct" and passed to "case_when" together with the
# values.
nse_funcs$case_when <- function(...) {
  cases <- list2(...)
  n_cases <- length(cases)
  if (n_cases == 0) {
    abort("No cases provided in case_when()")
  }
  conditions <- vector("list", n_cases)
  results <- vector("list", n_cases)
  # both sides of each formula are evaluated in the caller's environment
  mask <- caller_env()
  for (idx in seq_len(n_cases)) {
    case_formula <- cases[[idx]]
    if (!inherits(case_formula, "formula")) {
      abort("Each argument to case_when() must be a two-sided formula")
    }
    conditions[[idx]] <- arrow_eval(case_formula[[2]], mask)
    results[[idx]] <- arrow_eval(case_formula[[3]], mask)
    if (!nse_funcs$is.logical(conditions[[idx]])) {
      abort("Left side of each formula in case_when() must be a logical expression")
    }
    if (inherits(results[[idx]], "try-error")) {
      abort(handle_arrow_not_supported(results[[idx]], format_expr(case_formula[[3]])))
    }
  }

  condition_struct <- build_expr(
    "make_struct",
    args = conditions,
    options = list(field_names = as.character(seq_along(conditions)))
  )
  build_expr("case_when", args = c(condition_struct, results))
}
947
# Aggregation functions
# Each of these returns a list with the following elements:
#   fun: string function name
#   data: Expression (these are all currently a single field)
#   options: list of function options, as passed to call_function
# For group-by aggregation, `hash_` gets prepended to the function name,
# so to see a list of the available hash aggregation functions,
# you can use list_compute_functions("^hash_")
956agg_funcs <- list()
# Binding for sum(): describes an aggregation with the "sum" function.
# @param ... the input; ensure_one_arg() reports zero or multiple arguments
#   as not supported
# @param na.rm forwarded as the `skip_nulls` option
agg_funcs$sum <- function(..., na.rm = FALSE) {
  input <- ensure_one_arg(list2(...), "sum")
  opts <- list(skip_nulls = na.rm, min_count = 0L)
  list(fun = "sum", data = input, options = opts)
}
# Binding for any(): describes an aggregation with the "any" function.
# @param ... the input; ensure_one_arg() reports zero or multiple arguments
#   as not supported
# @param na.rm forwarded as the `skip_nulls` option
agg_funcs$any <- function(..., na.rm = FALSE) {
  input <- ensure_one_arg(list2(...), "any")
  opts <- list(skip_nulls = na.rm, min_count = 0L)
  list(fun = "any", data = input, options = opts)
}
# Binding for all(): describes an aggregation with the "all" function.
# @param ... the input; ensure_one_arg() reports zero or multiple arguments
#   as not supported
# @param na.rm forwarded as the `skip_nulls` option
agg_funcs$all <- function(..., na.rm = FALSE) {
  input <- ensure_one_arg(list2(...), "all")
  opts <- list(skip_nulls = na.rm, min_count = 0L)
  list(fun = "all", data = input, options = opts)
}
# Binding for mean(): describes an aggregation with the "mean" function.
# @param x the input to aggregate
# @param na.rm forwarded as the `skip_nulls` option
agg_funcs$mean <- function(x, na.rm = FALSE) {
  input <- x
  opts <- list(skip_nulls = na.rm, min_count = 0L)
  list(fun = "mean", data = input, options = opts)
}
# Binding for sd(): describes an aggregation with the "stddev" function.
# @param x the input to aggregate
# @param na.rm forwarded as the `skip_nulls` option
# @param ddof forwarded as the `ddof` option (defaults to 1)
agg_funcs$sd <- function(x, na.rm = FALSE, ddof = 1) {
  input <- x
  stddev_opts <- list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
  list(fun = "stddev", data = input, options = stddev_opts)
}
# Binding for var(): describes an aggregation with the "variance" function.
# @param x the input to aggregate
# @param na.rm forwarded as the `skip_nulls` option
# @param ddof forwarded as the `ddof` option (defaults to 1)
agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) {
  input <- x
  variance_opts <- list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
  list(fun = "variance", data = input, options = variance_opts)
}
# Binding for quantile(): describes an aggregation with the "tdigest" function.
# @param x the input to aggregate
# @param probs a single probability, forwarded as the `q` option; any other
#   length is reported as not supported
# @param na.rm forwarded as the `skip_nulls` option
agg_funcs$quantile <- function(x, probs, na.rm = FALSE) {
  if (length(probs) != 1) {
    arrow_not_supported("quantile() with length(probs) != 1")
  }
  # TODO: Bind to the Arrow function that returns an exact quantile and remove
  # this warning (ARROW-14021)
  warn(
    "quantile() currently returns an approximate quantile in Arrow",
    # Scalar condition, so use plain if/else rather than the vectorized ifelse():
    # warn once per session when interactive, every time otherwise
    .frequency = if (is_interactive()) "once" else "always",
    .frequency_id = "arrow.quantile.approximate"
  )
  list(
    fun = "tdigest",
    data = x,
    options = list(skip_nulls = na.rm, q = probs)
  )
}
# Binding for median(): describes an aggregation with the "approximate_median"
# function.
# @param x the input to aggregate
# @param na.rm forwarded as the `skip_nulls` option
agg_funcs$median <- function(x, na.rm = FALSE) {
  # TODO: Bind to the Arrow function that returns an exact median and remove
  # this warning (ARROW-14021)
  warn(
    "median() currently returns an approximate median in Arrow",
    # Scalar condition, so use plain if/else rather than the vectorized ifelse():
    # warn once per session when interactive, every time otherwise
    .frequency = if (is_interactive()) "once" else "always",
    .frequency_id = "arrow.median.approximate"
  )
  list(
    fun = "approximate_median",
    data = x,
    options = list(skip_nulls = na.rm)
  )
}
# Binding for n_distinct(): describes an aggregation with the "count_distinct"
# function.
# @param ... the input; ensure_one_arg() reports zero or multiple arguments
#   as not supported
# @param na.rm forwarded as-is; note the option is named `na.rm` here, not
#   `skip_nulls` as in the other aggregation bindings
agg_funcs$n_distinct <- function(..., na.rm = FALSE) {
  input <- ensure_one_arg(list2(...), "n_distinct")
  opts <- list(na.rm = na.rm)
  list(fun = "count_distinct", data = input, options = opts)
}
# Binding for n(): expressed as a "sum" over the scalar 1L, with no options.
agg_funcs$n <- function() {
  one <- Expression$scalar(1L)
  list(fun = "sum", data = one, options = list())
}
# Binding for min(): describes an aggregation with the "min" function.
# @param ... the input; ensure_one_arg() reports zero or multiple arguments
#   as not supported
# @param na.rm forwarded as the `skip_nulls` option
agg_funcs$min <- function(..., na.rm = FALSE) {
  input <- ensure_one_arg(list2(...), "min")
  opts <- list(skip_nulls = na.rm, min_count = 0L)
  list(fun = "min", data = input, options = opts)
}
# Binding for max(): describes an aggregation with the "max" function.
# @param ... the input; ensure_one_arg() reports zero or multiple arguments
#   as not supported
# @param na.rm forwarded as the `skip_nulls` option
agg_funcs$max <- function(..., na.rm = FALSE) {
  input <- ensure_one_arg(list2(...), "max")
  opts <- list(skip_nulls = na.rm, min_count = 0L)
  list(fun = "max", data = input, options = opts)
}
1058
# Check that exactly one argument was supplied and return it.
# @param args list of the arguments that were passed to the binding
# @param fun string function name, used only to build the message
# @return the first (and only) element of `args`; zero or multiple arguments
#   are reported through arrow_not_supported()
ensure_one_arg <- function(args, fun) {
  n_args <- length(args)
  if (n_args > 1) {
    arrow_not_supported(paste0("Multiple arguments to ", fun, "()"))
  }
  if (n_args == 0) {
    arrow_not_supported(paste0(fun, "() with 0 arguments"))
  }
  args[[1]]
}
1067
# Guess the type an aggregation will produce. These are quick and dirty
# heuristics.
# @param fun string function name
# @param input_type type of the aggregation's input
# @param hash only consulted for "tdigest": when TRUE the guess is a
#   fixed-size list of one float64 instead of a plain float64
# @return the guessed output type
output_type <- function(fun, input_type, hash) {
  boolean_funs <- c("any", "all")
  double_funs <- c("mean", "stddev", "variance", "approximate_median")

  if (fun %in% boolean_funs) {
    return(bool())
  }
  if (fun %in% double_funs) {
    return(float64())
  }
  if (fun %in% "tdigest") {
    if (hash) {
      return(fixed_size_list_of(float64(), 1L))
    }
    return(float64())
  }

  # "sum" may upcast to a bigger type but the input type is close enough.
  # For everything else, just so things don't error, assume the resulting
  # type is the same as the input.
  input_type
}
1088