# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


#' @include expression.R
NULL

# This environment is an internal cache for things including data mask
# functions. We'll populate it at package load time.
.cache <- NULL
# (Re)creates the cache: replaces the `.cache` binding defined above with a
# fresh, empty, hashed environment.
init_env <- function() {
  .cache <<- new.env(hash = TRUE)
}
init_env()

# nse_funcs is a list of functions that operate on (and return) Expressions.
# These will be the basis for a data_mask inside dplyr methods
# and will be added to .cache at package load time

# Start with mappings from R function name spellings: one wrapper per name in
# `.array_function_map`, each forwarding its arguments to `build_expr()`.
nse_funcs <- lapply(set_names(names(.array_function_map)), function(operator) {
  # force the promise now so each closure keeps its own `operator` value
  force(operator)
  function(...) build_expr(operator, ...)
})

# Now add functions to that list where the mapping from R to Arrow isn't 1:1
# Each of these functions should have the same signature as the R function
# they're replacing.
#
# When to use `build_expr()` vs. `Expression$create()`?
#
# Use `build_expr()` if you need to
# (1) map R function names to Arrow C++ functions
# (2) wrap R inputs (vectors) as Array/Scalar
#
# `Expression$create()` is lower level.
# Most of the functions below use it
# because they manage the preparation of the user-provided inputs
# and don't need to wrap scalars

# cast(): build a "cast" Expression to `target_type`; `safe` and `...` are
# forwarded to cast_options().
nse_funcs$cast <- function(x, target_type, safe = TRUE, ...) {
  opts <- cast_options(safe, ...)
  opts$to_type <- as_type(target_type)
  Expression$create("cast", x, options = opts)
}

# coalesce(): binding for dplyr::coalesce(). Requires at least one argument.
nse_funcs$coalesce <- function(...) {
  args <- list2(...)
  if (length(args) < 1) {
    abort("At least one argument must be supplied to coalesce()")
  }

  # Treat NaN like NA for consistency with dplyr::coalesce(), but if *all*
  # the values are NaN, we should return NaN, not NA, so don't replace
  # NaN with NA in the final (or only) argument
  # TODO: if an option is added to the coalesce kernel to treat NaN as NA,
  # use that to simplify the code here (ARROW-13389)
  #
  # The final argument is identified by position. Tagging it with attr() is
  # avoided because Expressions are reference objects: an attribute set on one
  # is visible through every other reference to the same object, so an
  # Expression passed more than once could be mistaken for the last argument.
  last_index <- length(args)
  prepared <- lapply(seq_along(args), function(i) {
    arg <- args[[i]]
    is_last <- i == last_index

    if (!inherits(arg, "Expression")) {
      arg <- Expression$scalar(arg)
    }

    # coalesce doesn't yet support factors/dictionaries
    # TODO: remove this after ARROW-14167 is merged
    if (nse_funcs$is.factor(arg)) {
      warning("Dictionaries (in R: factors) are currently converted to strings (characters) in coalesce", call. = FALSE)
    }

    if (!is_last && arg$type_id() %in% TYPES_WITH_NAN) {
      # store the NA_real_ in the same type as arg to avoid casting
      # smaller float types to larger float types
      NA_expr <- Expression$scalar(Scalar$create(NA_real_, type = arg$type()))
      Expression$create("if_else", Expression$create("is_nan", arg), NA_expr, arg)
    } else {
      arg
    }
  })
  # keep any argument names, as lapply() over `args` itself would have done
  names(prepared) <- names(args)
  Expression$create("coalesce", args = prepared)
}

# is.na(): "is_null" with nan_is_null = TRUE, so NaN is reported as missing too.
nse_funcs$is.na <- function(x) {
  build_expr("is_null", x, options = list(nan_is_null = TRUE))
}

# is.nan(): only doubles / Expressions whose type is in TYPES_WITH_NAN can hold
# NaN; every other input yields a FALSE scalar Expression.
nse_funcs$is.nan <- function(x) {
  if (is.double(x) || (inherits(x, "Expression") &&
    x$type_id() %in% TYPES_WITH_NAN)) {
    # TODO: if an option is added to the is_nan kernel to treat NA as NaN,
    # use that to simplify the code here (ARROW-13366)
    build_expr("is_nan", x) & build_expr("is_valid", x)
  } else {
    Expression$scalar(FALSE)
  }
}

# is(): `class2` may be a string (an R type name or an Arrow type name) or a
# DataType; anything else is an error.
nse_funcs$is <- function(object, class2) {
  if (is.string(class2)) {
    switch(class2,
      # for R data types, pass off to is.*() functions
      character = nse_funcs$is.character(object),
      numeric = nse_funcs$is.numeric(object),
      integer = nse_funcs$is.integer(object),
      integer64 = nse_funcs$is.integer64(object),
      logical = nse_funcs$is.logical(object),
      factor = nse_funcs$is.factor(object),
      list = nse_funcs$is.list(object),
      # for Arrow data types, compare class2 with object$type()$ToString(),
      # but first strip off any parameters to only compare the top-level data
      # type, and canonicalize class2
      sub("^([^([<]+).*$", "\\1", object$type()$ToString()) ==
        canonical_type_str(class2)
    )
  } else if (inherits(class2, "DataType")) {
    object$type() == as_type(class2)
  } else {
    stop("Second argument to is() is not a string or DataType", call. = FALSE)
  }
}

# dictionary_encode(): `null_encoding_behavior` ("mask" or "encode") is
# upper-cased and looked up in NullEncodingBehavior.
nse_funcs$dictionary_encode <- function(x,
                                        null_encoding_behavior = c("mask", "encode")) {
  behavior <- toupper(match.arg(null_encoding_behavior))
  null_encoding_behavior <- NullEncodingBehavior[[behavior]]
  Expression$create(
    "dictionary_encode",
    x,
    options = list(null_encoding_behavior = null_encoding_behavior)
  )
}

# between(): inclusive on both ends.
nse_funcs$between <- function(x, left, right) {
  x >= left & x <= right
}

nse_funcs$is.finite <- function(x) {
  is_fin <- Expression$create("is_finite", x)
  # for compatibility with base::is.finite(), return FALSE for NA_real_
  is_fin & !nse_funcs$is.na(is_fin)
}

nse_funcs$is.infinite <- function(x) {
  is_inf <- Expression$create("is_inf", x)
  # for compatibility with base::is.infinite(), return FALSE for NA_real_
  is_inf & !nse_funcs$is.na(is_inf)
}

# as.* type casting functions
# as.factor() is mapped in expression.R
nse_funcs$as.character <- function(x) {
  Expression$create("cast", x, options = cast_options(to_type = string()))
}
nse_funcs$as.double <- function(x) {
  Expression$create("cast", x, options = cast_options(to_type = float64()))
}
# as.integer() / as.integer64() allow float and decimal truncation in the cast
nse_funcs$as.integer <- function(x) {
  Expression$create(
    "cast",
    x,
    options = cast_options(
      to_type = int32(),
      allow_float_truncate = TRUE,
      allow_decimal_truncate = TRUE
    )
  )
}
nse_funcs$as.integer64 <- function(x) {
  Expression$create(
    "cast",
    x,
    options = cast_options(
      to_type = int64(),
      allow_float_truncate = TRUE,
      allow_decimal_truncate = TRUE
    )
  )
}
nse_funcs$as.logical <- function(x) {
  Expression$create("cast", x, options = cast_options(to_type = boolean()))
}
nse_funcs$as.numeric <- function(x) {
  Expression$create("cast", x, options = cast_options(to_type = float64()))
}

# is.* type functions
# Each returns TRUE when `x` is an R value of that type, or an Expression
# whose type id is one of the listed Arrow type ids.
nse_funcs$is.character <- function(x) {
  is.character(x) || (inherits(x, "Expression") &&
    x$type_id() %in% Type[c("STRING", "LARGE_STRING")])
}
nse_funcs$is.numeric <- function(x) {
  is.numeric(x) || (inherits(x, "Expression") && x$type_id() %in% Type[c(
    "UINT8", "INT8", "UINT16", "INT16", "UINT32", "INT32",
    "UINT64", "INT64", "HALF_FLOAT", "FLOAT", "DOUBLE",
    "DECIMAL", "DECIMAL256"
  )])
}
nse_funcs$is.double <- function(x) {
  is.double(x) || (inherits(x, "Expression") && x$type_id() == Type["DOUBLE"])
}
nse_funcs$is.integer <- function(x) {
  is.integer(x) || (inherits(x, "Expression") && x$type_id() %in% Type[c(
    "UINT8", "INT8", "UINT16", "INT16", "UINT32", "INT32",
    "UINT64", "INT64"
  )])
}
nse_funcs$is.integer64 <- function(x) {
  is.integer64(x) || (inherits(x, "Expression") && x$type_id() == Type["INT64"])
}
nse_funcs$is.logical <- function(x) {
  is.logical(x) || (inherits(x, "Expression") && x$type_id() == Type["BOOL"])
}
nse_funcs$is.factor <- function(x) {
  is.factor(x) || (inherits(x, "Expression") && x$type_id() == Type["DICTIONARY"])
}
nse_funcs$is.list <- function(x) {
  is.list(x) || (inherits(x, "Expression") && x$type_id() %in% Type[c(
    "LIST", "FIXED_SIZE_LIST", "LARGE_LIST"
  )])
}

# rlang::is_* type functions
# The extra rlang arguments (`n`, `finite`) are accepted only when NULL.
nse_funcs$is_character <- function(x, n = NULL) {
  assert_that(is.null(n))
  nse_funcs$is.character(x)
}
nse_funcs$is_double <- function(x, n = NULL, finite = NULL) {
  assert_that(is.null(n) && is.null(finite))
  nse_funcs$is.double(x)
}
nse_funcs$is_integer <- function(x, n = NULL) {
  assert_that(is.null(n))
  nse_funcs$is.integer(x)
}
nse_funcs$is_list <- function(x, n = NULL) {
  assert_that(is.null(n))
  nse_funcs$is.list(x)
}
nse_funcs$is_logical <- function(x, n = NULL) {
  assert_that(is.null(n))
  nse_funcs$is.logical(x)
}
nse_funcs$is_timestamp <- function(x, n = NULL) {
  assert_that(is.null(n))
  inherits(x, "POSIXt") || (inherits(x, "Expression") && x$type_id() %in% Type[c("TIMESTAMP")])
}

# String functions

# nchar(): type = "bytes" maps to "binary_length", everything else to
# "utf8_length". `allowNA = TRUE` and `keepNA = FALSE` are rejected.
nse_funcs$nchar <- function(x, type = "chars", allowNA = FALSE, keepNA = NA) {
  if (allowNA) {
    arrow_not_supported("allowNA = TRUE")
  }
  if (is.na(keepNA)) {
    # same default resolution as base::nchar(): FALSE only for type = "width"
    keepNA <- !identical(type, "width")
  }
  if (!keepNA) {
    # TODO: I think there is a fill_null kernel we could use, set null to 2
    # It is keepNA = FALSE that reaches this branch, so name that value.
    arrow_not_supported("keepNA = FALSE")
  }
  if (identical(type, "bytes")) {
    Expression$create("binary_length", x)
  } else {
    Expression$create("utf8_length", x)
  }
}

# paste(): `collapse` is not supported; nulls are replaced with "NA".
nse_funcs$paste <- function(..., sep = " ", collapse = NULL, recycle0 = FALSE) {
  assert_that(
    is.null(collapse),
    msg = "paste() with the collapse argument is not yet supported in Arrow"
  )
  if (!inherits(sep, "Expression")) {
    assert_that(!is.na(sep), msg = "Invalid separator")
  }
  arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., sep)
}

# paste0(): like paste() above with an empty separator.
nse_funcs$paste0 <- function(..., collapse = NULL, recycle0 = FALSE) {
  assert_that(
    is.null(collapse),
    msg = "paste0() with the collapse argument is not yet supported in Arrow"
  )
  arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., "")
}

# str_c(): `collapse` is not supported; uses EMIT_NULL null handling.
nse_funcs$str_c <- function(..., sep = "", collapse = NULL) {
  assert_that(
    is.null(collapse),
    msg = "str_c() with the collapse argument is not yet supported in Arrow"
  )
  arrow_string_join_function(NullHandlingBehavior$EMIT_NULL)(..., sep)
}

# Returns a function that joins its arguments element-wise with the given
# null handling options.
arrow_string_join_function <- function(null_handling, null_replacement = NULL) {
  # the `binary_join_element_wise` Arrow C++ compute kernel takes the separator
  # as the last argument, so pass `sep` as the last dots arg to this function
  function(...) {
    args <- lapply(list(...), function(arg) {
      # handle scalar literal args, and cast all args to string for
      # consistency with base::paste(), base::paste0(), and stringr::str_c()
      if (!inherits(arg, "Expression")) {
        assert_that(
          length(arg) == 1,
          msg = "Literal vectors of length != 1 not supported in string concatenation"
        )
        Expression$scalar(as.character(arg))
      } else {
        nse_funcs$as.character(arg)
      }
    })
    Expression$create(
      "binary_join_element_wise",
      args = args,
      options = list(
        null_handling = null_handling,
        null_replacement = null_replacement
      )
    )
  }
}

# Currently, Arrow does not support a locale option for string case conversion
# functions, in contrast to stringr's API, so the 'locale' argument is only
# valid for stringr's default value ("en"). The following are string functions
# that take a 'locale' option as their second argument:
#   str_to_lower
#   str_to_upper
#   str_to_title
#
# Arrow locale will be supported with ARROW-14126
stop_if_locale_provided <- function(locale) {
  if (!identical(locale, "en")) {
    stop("Providing a value for 'locale' other than the default ('en') is not supported by Arrow. ",
      "To change locale, use 'Sys.setlocale()'",
      call. = FALSE
    )
  }
}

nse_funcs$str_to_lower <- function(string, locale = "en") {
  stop_if_locale_provided(locale)
  Expression$create("utf8_lower", string)
}

nse_funcs$str_to_upper <- function(string, locale = "en") {
  stop_if_locale_provided(locale)
  Expression$create("utf8_upper", string)
}

nse_funcs$str_to_title <- function(string, locale = "en") {
  stop_if_locale_provided(locale)
  Expression$create("utf8_title", string)
}

# str_trim(): pick the whitespace-trimming kernel matching `side`.
nse_funcs$str_trim <- function(string, side = c("both", "left", "right")) {
  side <- match.arg(side)
  trim_fun <- switch(side,
    left = "utf8_ltrim_whitespace",
    right = "utf8_rtrim_whitespace",
    both = "utf8_trim_whitespace"
  )
  Expression$create(trim_fun, string)
}

# substr(): `start` and `stop` must each be length 1.
nse_funcs$substr <- function(x, start, stop) {
  assert_that(
    length(start) == 1,
    msg = "`start` must be length 1 - other lengths are not supported in Arrow"
  )
  assert_that(
    length(stop) == 1,
    msg = "`stop` must be length 1 - other lengths are not supported in Arrow"
  )

  # substr treats values as if they're on a continuous number line, so values
  # 0 are effectively blank characters - set `start` to 1 here so Arrow mimics
  # this behavior
  if (start <= 0) {
    start <- 1
  }

  # if `stop` is lower than `start`, this is invalid, so set `stop` to
  # 0 so that an empty string will be returned (consistent with base::substr())
  if (stop < start) {
    stop <- 0
  }

  Expression$create(
    "utf8_slice_codeunits",
    x,
    # we don't need to subtract 1 from `stop` as C++ counts exclusively
    # which effectively cancels out the difference in indexing between R & C++
    options = list(start = start - 1L, stop = stop)
  )
}

# substring(): same as substr() above, with base::substring() argument names.
nse_funcs$substring <- function(text, first, last) {
  nse_funcs$substr(x = text, start = first, stop = last)
}

# str_sub(): `start` and `end` must each be length 1.
nse_funcs$str_sub <- function(string, start = 1L, end = -1L) {
  assert_that(
    length(start) == 1,
    msg = "`start` must be length 1 - other lengths are not supported in Arrow"
  )
  assert_that(
    length(end) == 1,
    msg = "`end` must be length 1 - other lengths are not supported in Arrow"
  )

  # In stringr::str_sub, an `end` value of -1 means the end of the string, so
  # set it to the maximum integer to match this behavior
  if (end == -1) {
    end <- .Machine$integer.max
  }

  # An end value lower than a start value returns an empty string in
  # stringr::str_sub so set end to 0 here to match this behavior
  if (end < start) {
    end <- 0
  }

  # subtract 1 from `start` because C++ is 0-based and R is 1-based
  # str_sub treats a `start` value of 0 or 1 as the same thing so don't
  # subtract 1 when `start` == 0
  # when `start` < 0, both str_sub and utf8_slice_codeunits count backwards
  # from the end
  if (start > 0) {
    start <- start - 1L
  }

  Expression$create(
    "utf8_slice_codeunits",
    string,
    options = list(start = start, stop = end)
  )
}

# grepl(): literal matching when `fixed`, regex matching otherwise.
nse_funcs$grepl <- function(pattern, x, ignore.case = FALSE, fixed = FALSE) {
  arrow_fun <- if (fixed) "match_substring" else "match_substring_regex"
  Expression$create(
    arrow_fun,
    x,
    options = list(pattern = pattern, ignore_case = ignore.case)
  )
}

# str_detect(): resolves stringr pattern modifiers, then defers to grepl().
nse_funcs$str_detect <- function(string, pattern, negate = FALSE) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  out <- nse_funcs$grepl(
    pattern = opts$pattern,
    x = string,
    ignore.case = opts$ignore_case,
    fixed = opts$fixed
  )
  if (negate) {
    out <- !out
  }
  out
}

nse_funcs$str_like <- function(string, pattern, ignore_case = TRUE) {
  Expression$create(
    "match_like",
    string,
    options = list(pattern = pattern, ignore_case = ignore_case)
  )
}

# Encapsulate some common logic for sub/gsub/str_replace/str_replace_all
# `max_replacements` is passed straight through to the kernel options
# (1L for sub/str_replace, -1L for gsub/str_replace_all below).
arrow_r_string_replace_function <- function(max_replacements) {
  function(pattern, replacement, x, ignore.case = FALSE, fixed = FALSE) {
    # the literal kernel is only used for case-sensitive fixed patterns;
    # everything else goes through the regex kernel
    arrow_fun <- if (fixed && !ignore.case) {
      "replace_substring"
    } else {
      "replace_substring_regex"
    }
    Expression$create(
      arrow_fun,
      x,
      options = list(
        pattern = format_string_pattern(pattern, ignore.case, fixed),
        replacement = format_string_replacement(replacement, ignore.case, fixed),
        max_replacements = max_replacements
      )
    )
  }
}

# stringr flavour of the above: resolves pattern modifiers first.
arrow_stringr_string_replace_function <- function(max_replacements) {
  function(string, pattern, replacement) {
    opts <- get_stringr_pattern_options(enexpr(pattern))
    arrow_r_string_replace_function(max_replacements)(
      pattern = opts$pattern,
      replacement = replacement,
      x = string,
      ignore.case = opts$ignore_case,
      fixed = opts$fixed
    )
  }
}

nse_funcs$sub <- arrow_r_string_replace_function(1L)
nse_funcs$gsub <- arrow_r_string_replace_function(-1L)
nse_funcs$str_replace <- arrow_stringr_string_replace_function(1L)
nse_funcs$str_replace_all <- arrow_stringr_string_replace_function(-1L)

# strsplit(): `split` must be a single string.
nse_funcs$strsplit <- function(x,
                               split,
                               fixed = FALSE,
                               perl = FALSE,
                               useBytes = FALSE) {
  assert_that(is.string(split))

  arrow_fun <- if (fixed) "split_pattern" else "split_pattern_regex"
  # warn when the user specifies both fixed = TRUE and perl = TRUE, for
  # consistency with the behavior of base::strsplit()
  if (fixed && perl) {
    warning("Argument 'perl = TRUE' will be ignored", call. = FALSE)
  }
  # since split is not a regex, proceed without any warnings or errors regardless
  # of the value of perl, for consistency with the behavior of base::strsplit()
  Expression$create(
    arrow_fun,
    x,
    options = list(pattern = split, reverse = FALSE, max_splits = -1L)
  )
}

# str_split(): case-insensitive splitting and n = 0 are rejected;
# `simplify = TRUE` is ignored with a warning.
nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  arrow_fun <- if (opts$fixed) "split_pattern" else "split_pattern_regex"
  if (opts$ignore_case) {
    arrow_not_supported("Case-insensitive string splitting")
  }
  if (n == 0) {
    arrow_not_supported("Splitting strings into zero parts")
  }
  if (identical(n, Inf)) {
    # becomes max_splits = -1L after the subtraction below
    n <- 0L
  }
  if (simplify) {
    warning("Argument 'simplify = TRUE' will be ignored", call. = FALSE)
  }
  # The max_splits option in the Arrow C++ library controls the maximum number
  # of places at which the string is split, whereas the argument n to
  # str_split() controls the maximum number of pieces to return. So we must
  # subtract 1 from n to get max_splits.
  Expression$create(
    arrow_fun,
    string,
    options = list(
      pattern = opts$pattern,
      reverse = FALSE,
      max_splits = n - 1L
    )
  )
}

nse_funcs$pmin <- function(..., na.rm = FALSE) {
  build_expr(
    "min_element_wise",
    ...,
    options = list(skip_nulls = na.rm)
  )
}

nse_funcs$pmax <- function(..., na.rm = FALSE) {
  build_expr(
    "max_element_wise",
    ...,
    options = list(skip_nulls = na.rm)
  )
}

# str_pad(): `width` must be integerish and `pad` a single string.
nse_funcs$str_pad <- function(string, width, side = c("left", "right", "both"), pad = " ") {
  assert_that(is_integerish(width))
  side <- match.arg(side)
  assert_that(is.string(pad))

  # match.arg() guarantees `side` is one of the three names below
  pad_func <- switch(side,
    left = "utf8_lpad",
    right = "utf8_rpad",
    both = "utf8_center"
  )

  Expression$create(
    pad_func,
    string,
    options = list(width = width, padding = pad)
  )
}

nse_funcs$startsWith <- function(x, prefix) {
  Expression$create(
    "starts_with",
    x,
    options = list(pattern = prefix)
  )
}

nse_funcs$endsWith <- function(x, suffix) {
  Expression$create(
    "ends_with",
    x,
    options = list(pattern = suffix)
  )
}

# str_starts(): fixed patterns use startsWith(); regex patterns are anchored
# with "^" and matched via grepl(), honouring regex(ignore_case = ).
nse_funcs$str_starts <- function(string, pattern, negate = FALSE) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  if (opts$fixed) {
    out <- nse_funcs$startsWith(x = string, prefix = opts$pattern)
  } else {
    out <- nse_funcs$grepl(
      pattern = paste0("^", opts$pattern),
      x = string,
      ignore.case = opts$ignore_case,
      fixed = FALSE
    )
  }

  if (negate) {
    out <- !out
  }
  out
}

# str_ends(): fixed patterns use endsWith(); regex patterns are anchored
# with "$" and matched via grepl(), honouring regex(ignore_case = ).
nse_funcs$str_ends <- function(string, pattern, negate = FALSE) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  if (opts$fixed) {
    out <- nse_funcs$endsWith(x = string, suffix = opts$pattern)
  } else {
    out <- nse_funcs$grepl(
      pattern = paste0(opts$pattern, "$"),
      x = string,
      ignore.case = opts$ignore_case,
      fixed = FALSE
    )
  }

  if (negate) {
    out <- !out
  }
  out
}

# str_count(): `pattern` must be a single string.
nse_funcs$str_count <- function(string, pattern) {
  opts <- get_stringr_pattern_options(enexpr(pattern))
  if (!is.string(pattern)) {
    arrow_not_supported("`pattern` must be a length 1 character vector; other values")
  }
  arrow_fun <- if (opts$fixed) "count_substring" else "count_substring_regex"
  Expression$create(
    arrow_fun,
    string,
    options = list(pattern = opts$pattern, ignore_case = opts$ignore_case)
  )
}

# String function helpers

# format `pattern` as needed for case insensitivity and literal matching by RE2
format_string_pattern <- function(pattern, ignore.case, fixed) {
  # Arrow lacks native support for case-insensitive literal string matching and
  # replacement, so we use the regular expression engine (RE2) to do this.
  # https://github.com/google/re2/wiki/Syntax
  if (ignore.case) {
    if (fixed) {
      # Everything between "\Q" and "\E" is treated as literal text.
      # If the search text contains any literal "\E" strings, make them
      # lowercase so they won't signal the end of the literal text:
      pattern <- gsub("\\E", "\\e", pattern, fixed = TRUE)
      pattern <- paste0("\\Q", pattern, "\\E")
    }
    # Prepend "(?i)" for case-insensitive matching
    pattern <- paste0("(?i)", pattern)
  }
  pattern
}

# format `replacement` as needed for literal replacement by RE2
format_string_replacement <- function(replacement, ignore.case, fixed) {
  # Arrow lacks native support for case-insensitive literal string
  # replacement, so we use the regular expression engine (RE2) to do this.
  # https://github.com/google/re2/wiki/Syntax
  if (ignore.case && fixed) {
    # Escape single backslashes in the regex replacement text so they are
    # interpreted as literal backslashes:
    replacement <- gsub("\\", "\\\\", replacement, fixed = TRUE)
  }
  replacement
}

#' Get `stringr` pattern options
#'
#' This function assigns definitions for the `stringr` pattern modifier
#' functions (`fixed()`, `regex()`, etc.) inside itself, and uses them to
#' evaluate the quoted expression `pattern`, returning a list that is used
#' to control pattern matching behavior in internal `arrow` functions.
#'
#' @param pattern Unevaluated expression containing a call to a `stringr`
#' pattern modifier function
#'
#' @return List containing elements `pattern`, `fixed`, and `ignore_case`
#' @keywords internal
get_stringr_pattern_options <- function(pattern) {
  fixed <- function(pattern, ignore_case = FALSE, ...) {
    check_dots(...)
    list(pattern = pattern, fixed = TRUE, ignore_case = ignore_case)
  }
  regex <- function(pattern, ignore_case = FALSE, ...) {
    check_dots(...)
    list(pattern = pattern, fixed = FALSE, ignore_case = ignore_case)
  }
  coll <- function(...) {
    arrow_not_supported("Pattern modifier `coll()`")
  }
  boundary <- function(...) {
    arrow_not_supported("Pattern modifier `boundary()`")
  }
  # warn about (and otherwise ignore) any extra modifier arguments
  check_dots <- function(...) {
    dots <- list(...)
    if (length(dots)) {
      warning(
        "Ignoring pattern modifier ",
        ngettext(length(dots), "argument ", "arguments "),
        "not supported in Arrow: ",
        oxford_paste(names(dots)),
        call. = FALSE
      )
    }
  }
  # a bare character pattern is treated as a case-sensitive regex
  ensure_opts <- function(opts) {
    if (is.character(opts)) {
      opts <- list(pattern = opts, fixed = FALSE, ignore_case = FALSE)
    }
    opts
  }
  ensure_opts(eval(pattern))
}

#' Does this string contain regex metacharacters?
#'
#' @param string String to be tested
#' @keywords internal
#' @return Logical: does `string` contain regex metacharacters?
contains_regex <- function(string) {
  grepl("[.\\|()[{^$*+?]", string)
}

# strptime(): a character `tz` is rejected; `unit` is validated with
# make_valid_time_unit() before being passed to the kernel.
nse_funcs$strptime <- function(x, format = "%Y-%m-%d %H:%M:%S", tz = NULL, unit = "ms") {
  # Arrow uses unit for time parsing, strptime() does not.
  # Arrow has no default option for strptime (format, unit),
  # we suggest following format = "%Y-%m-%d %H:%M:%S", unit = MILLI/1L/"ms",
  # (ARROW-12809)

  # ParseTimestampStrptime currently ignores the timezone information (ARROW-12820).
  # Stop if tz is provided.
  if (is.character(tz)) {
    arrow_not_supported("Time zone argument")
  }

  unit <- make_valid_time_unit(unit, c(valid_time64_units, valid_time32_units))

  Expression$create("strptime", x, options = list(format = format, unit = unit))
}

# strftime(): `usetz = TRUE` appends "%Z" to the format; an empty `tz` falls
# back to Sys.timezone().
nse_funcs$strftime <- function(x, format = "", tz = "", usetz = FALSE) {
  if (usetz) {
    format <- paste(format, "%Z")
  }
  if (tz == "") {
    tz <- Sys.timezone()
  }
  # Arrow's strftime prints in timezone of the timestamp. To match R's strftime behavior we first
  # cast the timestamp to desired timezone. This is a metadata only change.
  if (nse_funcs$is_timestamp(x)) {
    ts <- Expression$create("cast", x, options = list(to_type = timestamp(x$type()$unit(), tz)))
  } else {
    ts <- x
  }
  Expression$create("strftime", ts, options = list(format = format, locale = Sys.getlocale("LC_TIME")))
}

# format_ISO8601(): `precision` selects one of the formats below (default
# "ymdhms"); `usetz = TRUE` appends "%z". Always formats with the "C" locale.
nse_funcs$format_ISO8601 <- function(x, usetz = FALSE, precision = NULL, ...) {
  ISO8601_precision_map <-
    list(
      y = "%Y",
      ym = "%Y-%m",
      ymd = "%Y-%m-%d",
      ymdh = "%Y-%m-%dT%H",
      ymdhm = "%Y-%m-%dT%H:%M",
      ymdhms = "%Y-%m-%dT%H:%M:%S"
    )

  if (is.null(precision)) {
    precision <- "ymdhms"
  }
  if (!precision %in% names(ISO8601_precision_map)) {
    abort(
      paste(
        "`precision` must be one of the following values:",
        paste(names(ISO8601_precision_map), collapse = ", "),
        "\nValue supplied was: ",
        precision
      )
    )
  }
  format <- ISO8601_precision_map[[precision]]
  if (usetz) {
    format <- paste0(format, "%z")
  }
  Expression$create("strftime", x, options = list(format = format, locale = "C"))
}

# second(): sum of the "second" and "subsecond" kernels.
nse_funcs$second <- function(x) {
  Expression$create("add", Expression$create("second", x), Expression$create("subsecond", x))
}

nse_funcs$trunc <- function(x, ...) {
  # accepts and ignores ... for consistency with base::trunc()
  build_expr("trunc", x)
}

# round(): always uses RoundMode$HALF_TO_EVEN.
nse_funcs$round <- function(x, digits = 0) {
  build_expr(
    "round",
    x,
    options = list(ndigits = digits, round_mode = RoundMode$HALF_TO_EVEN)
  )
}

# wday(): with `label = TRUE` the day name is produced via "strftime"
# ("%a" when abbreviated, "%A" otherwise); otherwise "day_of_week" is used.
nse_funcs$wday <- function(x,
                           label = FALSE,
                           abbr = TRUE,
                           week_start = getOption("lubridate.week.start", 7),
                           locale = Sys.getlocale("LC_TIME")) {
  if (label) {
    format <- if (abbr) "%a" else "%A"
    return(Expression$create("strftime", x, options = list(format = format, locale = locale)))
  }

  Expression$create("day_of_week", x, options = list(count_from_zero = FALSE, week_start = week_start))
}

# log() / logb(): natural log by default; bases e, 2 and 10 use dedicated
# kernels, any other base goes through "logb_checked".
nse_funcs$log <- nse_funcs$logb <- function(x, base = exp(1)) {
  # like other binary functions, either `x` or `base` can be Expression or double(1)
  if (is.numeric(x) && length(x) == 1) {
    x <- Expression$scalar(x)
  } else if (!inherits(x, "Expression")) {
    arrow_not_supported("x must be a column or a length-1 numeric; other values")
  }

  # handle `base` differently because we use the simpler ln, log2, and log10
  # functions for specific scalar base values
  if (inherits(base, "Expression")) {
    return(Expression$create("logb_checked", x, base))
  }

  if (!is.numeric(base) || length(base) != 1) {
    arrow_not_supported("base must be a column or a length-1 numeric; other values")
  }

  if (base == exp(1)) {
    return(Expression$create("ln_checked", x))
  }

  if (base == 2) {
    return(Expression$create("log2_checked", x))
  }

  if (base == 10) {
    return(Expression$create("log10_checked", x))
  }

  Expression$create("logb_checked", x, Expression$scalar(base))
}

# if_else(): when `missing` is supplied, it is used wherever `condition` is NA
# by wrapping the plain if_else in an outer if_else on is.na(condition).
nse_funcs$if_else <- function(condition, true, false, missing = NULL) {
  if (!is.null(missing)) {
    return(nse_funcs$if_else(
      nse_funcs$is.na(condition),
      missing,
      nse_funcs$if_else(condition, true, false)
    ))
  }

  # if_else doesn't yet support factors/dictionaries
  # TODO: remove this after ARROW-13358 is merged
  # (scalar condition, so use the short-circuiting `||`)
  warn_types <- nse_funcs$is.factor(true) || nse_funcs$is.factor(false)
  if (warn_types) {
    warning(
      "Dictionaries (in R: factors) are currently converted to strings (characters) ",
      "in if_else and ifelse",
      call. = FALSE
    )
  }

  build_expr("if_else", condition, true, false)
}

# Although base R ifelse allows `yes` and `no` to be different classes, this
# binding does no extra coercion of its own: it simply delegates to if_else().
nse_funcs$ifelse <- function(test, yes, no) {
  nse_funcs$if_else(condition = test, true = yes, false = no)
}

# case_when(): each argument must be a two-sided formula whose left side
# evaluates to a logical expression. Both sides are evaluated with
# arrow_eval() in the caller's environment.
nse_funcs$case_when <- function(...) {
  formulas <- list2(...)
  n <- length(formulas)
  if (n == 0) {
    abort("No cases provided in case_when()")
  }
  query <- vector("list", n)
  value <- vector("list", n)
  mask <- caller_env()
  for (i in seq_len(n)) {
    f <- formulas[[i]]
    if (!inherits(f, "formula")) {
      abort("Each argument to case_when() must be a two-sided formula")
    }
    query[[i]] <- arrow_eval(f[[2]], mask)
    value[[i]] <- arrow_eval(f[[3]], mask)
    if (!nse_funcs$is.logical(query[[i]])) {
      abort("Left side of each formula in case_when() must be a logical expression")
    }
    if (inherits(value[[i]], "try-error")) {
      abort(handle_arrow_not_supported(value[[i]], format_expr(f[[3]])))
    }
  }
  # the conditions are packed into one struct (fields named "1", "2", ...),
  # followed by the values
  build_expr(
    "case_when",
    args = c(
      build_expr(
        "make_struct",
        args = query,
        options = list(field_names = as.character(seq_along(query)))
      ),
      value
    )
  )
}

# Aggregation functions
# These all return a list of:
# @param fun string function name
# @param data Expression (these are all currently a single field)
# @param options list of function options, as passed to call_function
# For group-by aggregation, `hash_` gets prepended to the function name.
# So to see a list of available hash aggregation functions,
# you can use list_compute_functions("^hash_")
agg_funcs <- list()

# sum / any / all / min / max accept `...` for signature compatibility with
# base R but require exactly one argument (see ensure_one_arg()).
agg_funcs$sum <- function(..., na.rm = FALSE) {
  spec <- list(fun = "sum", data = ensure_one_arg(list2(...), "sum"))
  spec$options <- list(skip_nulls = na.rm, min_count = 0L)
  spec
}
agg_funcs$any <- function(..., na.rm = FALSE) {
  spec <- list(fun = "any", data = ensure_one_arg(list2(...), "any"))
  spec$options <- list(skip_nulls = na.rm, min_count = 0L)
  spec
}
agg_funcs$all <- function(..., na.rm = FALSE) {
  spec <- list(fun = "all", data = ensure_one_arg(list2(...), "all"))
  spec$options <- list(skip_nulls = na.rm, min_count = 0L)
  spec
}

# mean / sd / var take the data directly; sd and var also forward `ddof`.
agg_funcs$mean <- function(x, na.rm = FALSE) {
  spec <- list(fun = "mean", data = x)
  spec$options <- list(skip_nulls = na.rm, min_count = 0L)
  spec
}
agg_funcs$sd <- function(x, na.rm = FALSE, ddof = 1) {
  spec <- list(fun = "stddev", data = x)
  spec$options <- list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
  spec
}
agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) {
  spec <- list(fun = "variance", data = x)
  spec$options <- list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
  spec
}

# quantile(): only a single probability is supported; uses the "tdigest"
# function and warns that the result is approximate.
agg_funcs$quantile <- function(x, probs, na.rm = FALSE) {
  if (length(probs) != 1) {
    arrow_not_supported("quantile() with length(probs) != 1")
  }
  # TODO: Bind to the Arrow function that returns an exact quantile and remove
  # this warning (ARROW-14021)
  how_often <- if (is_interactive()) "once" else "always"
  warn(
    "quantile() currently returns an approximate quantile in Arrow",
    .frequency = how_often,
    .frequency_id = "arrow.quantile.approximate"
  )
  spec <- list(fun = "tdigest", data = x)
  spec$options <- list(skip_nulls = na.rm, q = probs)
  spec
}

# median(): uses "approximate_median" and warns that the result is approximate.
agg_funcs$median <- function(x, na.rm = FALSE) {
  # TODO: Bind to the Arrow function that returns an exact median and remove
  # this warning (ARROW-14021)
  how_often <- if (is_interactive()) "once" else "always"
  warn(
    "median() currently returns an approximate median in Arrow",
    .frequency = how_often,
    .frequency_id = "arrow.median.approximate"
  )
  spec <- list(fun = "approximate_median", data = x)
  spec$options <- list(skip_nulls = na.rm)
  spec
}

# n_distinct(): single argument only; `na.rm` is passed through under its
# own name.
agg_funcs$n_distinct <- function(..., na.rm = FALSE) {
  spec <- list(
    fun = "count_distinct",
    data = ensure_one_arg(list2(...), "n_distinct")
  )
  spec$options <- list(na.rm = na.rm)
  spec
}

# n(): expressed as the sum of a scalar 1L, with no options.
agg_funcs$n <- function() {
  list(fun = "sum", data = Expression$scalar(1L), options = list())
}

agg_funcs$min <- function(..., na.rm = FALSE) {
  spec <- list(fun = "min", data = ensure_one_arg(list2(...), "min"))
  spec$options <- list(skip_nulls = na.rm, min_count = 0L)
  spec
}
agg_funcs$max <- function(..., na.rm = FALSE) {
  spec <- list(fun = "max", data = ensure_one_arg(list2(...), "max"))
  spec$options <- list(skip_nulls = na.rm, min_count = 0L)
  spec
}

# Returns the single element of `args`, signalling "not supported" (naming
# `fun` in the message) when there are zero or several arguments.
ensure_one_arg <- function(args, fun) {
  n_args <- length(args)
  if (n_args == 0) {
    arrow_not_supported(paste0(fun, "() with 0 arguments"))
  }
  if (n_args > 1) {
    arrow_not_supported(paste0("Multiple arguments to ", fun, "()"))
  }
  args[[1]]
}

# Guess the result type of aggregation `fun` applied to `input_type`;
# `hash` only affects the "tdigest" case.
output_type <- function(fun, input_type, hash) {
  # These are quick and dirty heuristics.
  if (fun %in% c("any", "all")) {
    return(bool())
  }
  if (fun %in% c("mean", "stddev", "variance", "approximate_median")) {
    return(float64())
  }
  if (fun %in% "tdigest") {
    if (hash) {
      return(fixed_size_list_of(float64(), 1L))
    }
    return(float64())
  }
  # "sum" may upcast to a bigger type but the input type is close enough;
  # for everything else, just so things don't error, assume the resulting
  # type is the same as the input
  input_type
}