1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
skip_if_not_available("dataset")

library(dplyr, warn.conflicts = FALSE)
library(stringr)

# Shared fixture for every test below: example_data plus a few extra columns.
tbl <- example_data
# Add some better string data
# NOTE(review): `verses` is presumably a fixture from the test helpers.
tbl$verses <- verses[[1]]
# c(" a ", "  b  ", "   c   ", ...) increasing padding
# nchar =   3  5  7  9 11 13 15 17 19 21
tbl$padded_strings <- stringr::str_pad(
  letters[1:10],
  width = 2 * (1:10) + 1,
  side = "both"
)
# `int` with alternating sign: odd rows negated, even rows unchanged.
# seq_len() (not 1:nrow()) so a zero-row table gives an empty vector.
tbl$some_negative <- tbl$int * (-1)^seq_len(nrow(tbl))
30
test_that("filter() on is.na()", {
  # Keep only the rows whose logical column is missing
  compare_dplyr_binding(
    .input %>% filter(is.na(lgl)) %>% select(chr, int, lgl) %>% collect(),
    tbl
  )
})
40
test_that("filter() with NAs in selection", {
  # The logical column is used directly as the predicate; any NA entries in
  # it must be handled exactly as dplyr handles them.
  compare_dplyr_binding(
    .input %>% filter(lgl) %>% select(chr, int, lgl) %>% collect(),
    tbl
  )
})
50
test_that("Filter returning an empty Table should not segfault (ARROW-8354)", {
  # Regression test: a predicate that keeps no rows must produce an empty
  # result instead of crashing.
  # NOTE(review): `false` is not the literal FALSE; it presumably names an
  # all-FALSE column of example_data. Confirm before "fixing" it.
  compare_dplyr_binding(
    .input %>% filter(false) %>% select(chr, int, lgl) %>% collect(),
    tbl
  )
})
60
test_that("filtering with expression", {
  # The comparison value comes from a local variable rather than a literal
  wanted_chr <- "b"
  compare_dplyr_binding(
    .input %>%
      filter(chr == wanted_chr) %>%
      select(string = chr, int) %>% collect(),
    tbl
  )
})
71
test_that("filtering with arithmetic", {
  # --- Arithmetic on the double column ---
  compare_dplyr_binding(
    .input %>%
      filter(dbl + 1 > 3) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      filter(dbl / 2 > 3) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      filter(dbl / 2L > 3) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )
  # Integer division and exponentiation
  compare_dplyr_binding(
    .input %>%
      filter(dbl %/% 2 > 3) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      filter(dbl^2 > 3) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )

  # --- Division on the integer column, by double and by integer ---
  compare_dplyr_binding(
    .input %>%
      filter(int / 2 > 3) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      filter(int / 2L > 3) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )
})
129
test_that("filtering with expression + autocasting", {
  # Double expression compared against an integer literal (3L)
  compare_dplyr_binding(
    .input %>%
      filter(dbl + 1 > 3L) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )

  # Integer column combined with double literals
  compare_dplyr_binding(
    .input %>%
      filter(int + 1 > 3) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      filter(int^2 > 3) %>%
      select(string = chr, int, dbl) %>% collect(),
    tbl
  )
})
155
test_that("More complex select/filter", {
  # filter() and select() interleaved: the second filter() and select() act
  # on the columns left by the first select().
  compare_dplyr_binding(
    .input %>%
      filter(
        dbl > 2,
        chr == "d" | chr == "f"
      ) %>%
      select(chr, int, lgl) %>%
      filter(int < 5) %>%
      select(int, chr) %>%
      collect(),
    tbl
  )
})
167
test_that("filter() with %in%", {
  # A numeric condition and a set-membership condition in one filter() call
  compare_dplyr_binding(
    .input %>% filter(dbl > 2, chr %in% c("d", "f")) %>% collect(),
    tbl
  )
})
176
test_that("Negative scalar values", {
  # Unary minus applied to a column
  compare_dplyr_binding(
    .input %>% filter(int == -some_negative) %>% collect(),
    tbl
  )
  # Negative literal as a comparison operand
  compare_dplyr_binding(
    .input %>% filter(some_negative > -2) %>% collect(),
    tbl
  )
  # Negative literal as the %in% set
  compare_dplyr_binding(
    .input %>% filter(some_negative %in% -1) %>% collect(),
    tbl
  )
})
197
test_that("filter() with between()", {
  # Literal bounds, checked against dplyr
  compare_dplyr_binding(
    .input %>% filter(between(dbl, 1, 2)) %>% collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>% filter(between(dbl, 0.5, 2)) %>% collect(),
    tbl
  )

  batch <- record_batch(tbl)

  # Column bounds: must match the equivalent pair of comparisons in dplyr
  expect_identical(
    batch %>% filter(between(dbl, int, dbl2)) %>% collect(),
    tbl %>% filter(dbl >= int, dbl <= dbl2)
  )

  # Invalid input must raise an error: a character bound, an NA bound, and
  # numeric bounds applied to a character column
  expect_error(batch %>% filter(between(dbl, 1, "2")) %>% collect())
  expect_error(batch %>% filter(between(dbl, 1, NA)) %>% collect())
  expect_error(batch %>% filter(between(chr, 1, 2)) %>% collect())
})
243
test_that("filter() with string ops", {
  # Skipped unless the utf8proc feature is available
  skip_if_not_available("utf8proc")

  # Length of the raw string column
  compare_dplyr_binding(
    .input %>% filter(dbl > 2, str_length(verses) > 25) %>% collect(),
    tbl
  )

  # Length after trimming whitespace on the left only
  compare_dplyr_binding(
    .input %>%
      filter(
        dbl > 2,
        str_length(str_trim(padded_strings, "left")) > 5
      ) %>%
      collect(),
    tbl
  )
})
260
test_that("filter environment scope", {
  # Checks how filter() resolves names that are not columns: undefined
  # variables/functions must error like dplyr, and once defined in this
  # environment they must be picked up. Statement order below is deliberate.
  # "object 'b_var' not found"
  compare_dplyr_error(.input %>% filter(chr == b_var), tbl)

  # Now b_var exists in the calling environment and is used as a value
  b_var <- "b"
  compare_dplyr_binding(
    .input %>%
      filter(chr == b_var) %>%
      collect(),
    tbl
  )
  # Also for functions
  # 'could not find function "isEqualTo"' because we haven't defined it yet
  compare_dplyr_error(.input %>% filter(isEqualTo(int, 4)), tbl)

  # This works but only because there are S3 methods for those operations
  # NOTE(review): presumably means ==, & and ! dispatch on Arrow objects, so a
  # user-defined R function built only from them can be evaluated; confirm.
  isEqualTo <- function(x, y) x == y & !is.na(x)
  compare_dplyr_binding(
    .input %>%
      select(-fct) %>% # factor levels aren't identical
      filter(isEqualTo(int, 4)) %>%
      collect(),
    tbl
  )
  # Try something that needs to call another nse_func
  # Baseline: nchar() called directly inside filter(). The skipped check below
  # calls it through a user-defined wrapper instead.
  compare_dplyr_binding(
    .input %>%
      select(-fct) %>%
      filter(nchar(padded_strings) < 10) %>%
      collect(),
    tbl
  )
  isShortString <- function(x) nchar(x) < 10
  # Wrapping nchar() in a user function is not supported yet; the tracking
  # issue is presumably ARROW-14071.
  skip("TODO: 14071")
  compare_dplyr_binding(
    .input %>%
      select(-fct) %>%
      filter(isShortString(padded_strings)) %>%
      collect(),
    tbl
  )
})
303
test_that("Filtering on a column that doesn't exist errors correctly", {
  # Referencing a missing column must raise R's own "object not found" error
  # (checked in French and in English) without any fallback warning.
  # NOTE(review): with_language() presumably switches the message language for
  # the enclosed block only; confirm in the test helpers.
  with_language("fr", {
    # expect_warning(., NA) because the usual behavior when it hits a filter
    # that it can't evaluate is to raise a warning, collect() to R, and retry
    # the filter. But we want this to error the first time because it's
    # a user error, not solvable by retrying in R
    expect_warning(
      expect_error(
        tbl %>% record_batch() %>% filter(not_a_col == 42) %>% collect(),
        "objet 'not_a_col' introuvable"
      ),
      NA
    )
  })
  # Same check with the English message
  with_language("en", {
    expect_warning(
      expect_error(
        tbl %>% record_batch() %>% filter(not_a_col == 42) %>% collect(),
        "object 'not_a_col' not found"
      ),
      NA
    )
  })
})
328
test_that("Filtering with unsupported functions", {
  # A single unsupported expression: warn, then evaluate the filter in R
  compare_dplyr_binding(
    .input %>% filter(int > 2, pnorm(dbl) > .99) %>% collect(),
    tbl,
    warning = "Expression pnorm\\(dbl\\) > 0.99 not supported in Arrow; pulling data into R"
  )

  # Several unsupported expressions: the warning lists each one as a bullet.
  # The regex deliberately spans several lines.
  multi_warning <- '\\* In nchar\\(chr, type = "bytes", allowNA = TRUE\\) == 1, allowNA = TRUE not supported by Arrow
\\* Expression pnorm\\(dbl\\) > 0.99 not supported in Arrow
pulling data into R'
  compare_dplyr_binding(
    .input %>%
      filter(
        nchar(chr, type = "bytes", allowNA = TRUE) == 1, # bad, Arrow msg
        int > 2, # good
        pnorm(dbl) > .99 # bad, opaque
      ) %>%
      collect(),
    tbl,
    warning = multi_warning
  )
})
351
test_that("Calling Arrow compute functions 'directly'", {
  # Reference result computed with plain dplyr on the in-memory tibble.
  # Both Arrow queries below must reproduce it.
  expected <- tbl %>%
    filter(dbl + 1 > 3L) %>%
    select(string = chr, int, dbl)

  # arrow_add() calls the Arrow compute function directly inside filter()
  expect_equal(
    tbl %>%
      record_batch() %>%
      filter(arrow_add(dbl, 1) > 3L) %>%
      select(string = chr, int, dbl) %>%
      collect(),
    expected
  )

  # Nested direct calls. This uses expect_equal() rather than
  # compare_dplyr_binding(): that helper is driven by an `.input` placeholder,
  # and an expression that never references `.input` would only be compared
  # with itself, never with the dplyr result.
  # NOTE(review): arrow_*() helpers are presumably available only inside
  # Arrow's dplyr evaluation, so dplyr cannot run them on a data.frame.
  expect_equal(
    tbl %>%
      record_batch() %>%
      filter(arrow_greater(arrow_add(dbl, 1), 3L)) %>%
      select(string = chr, int, dbl) %>%
      collect(),
    expected
  )
})
375
test_that("filter() with .data pronoun", {
  # .data$col in both filter() and select()
  compare_dplyr_binding(
    .input %>%
      filter(.data$dbl > 4) %>%
      select(.data$chr, .data$int, .data$lgl) %>% collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      filter(is.na(.data$lgl)) %>%
      select(.data$chr, .data$int, .data$lgl) %>% collect(),
    tbl
  )

  # and the .env pronoun too! The local `chr` shares its name with a column;
  # .env$chr explicitly reads the local value (4).
  chr <- 4
  compare_dplyr_binding(
    .input %>%
      filter(.data$dbl > .env$chr) %>%
      select(.data$chr, .data$int, .data$lgl) %>% collect(),
    tbl
  )

  # Without `.env`, bare `chr` is expected to resolve to the column. This check
  # used to expect an error, but per the skip message the code now returns an
  # empty tibble instead, so it is skipped until the expectation is updated.
  skip("test now faulty - code no longer gives error & outputs a empty tibble")
  compare_dplyr_error(
    .input %>%
      filter(.data$dbl > chr) %>%
      select(.data$chr, .data$int, .data$lgl) %>% collect(),
    tbl
  )
})
413