1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
# Gate the whole file on the "dataset" feature.
skip_if_not_available("dataset")

library(dplyr, warn.conflicts = FALSE)

# Shared fixtures -------------------------------------------------------------

# Left-hand data frame: the example data plus a numeric join key that
# alternates 1, 2, 1, 2, ...
left <- example_data
left$some_grouping <- rep(c(1, 2), times = 5)

# Right-hand data frame: one row per key value, with two payload columns
# (the scalar TRUE is recycled to both rows).
to_join <- tibble::tibble(
  some_grouping = c(1, 2),
  capital_letters = c("A", "B"),
  another_column = TRUE
)

# Arrow Table versions of both data frames.
left_tab <- Table$create(left)
to_join_tab <- Table$create(to_join)
33
34
test_that("left_join", {
  # No `by` is supplied, so the join key is inferred from the column the two
  # inputs share, and a message announcing that choice is expected.
  #
  # The wording of that message depends on the dplyr release:
  #   dplyr <  1.1.0:  Joining, by = "some_grouping"
  #   dplyr >= 1.1.0:  Joining with `by = join_by(some_grouping)`
  # Match only what both spellings have in common so the test does not break
  # when dplyr is upgraded (or downgraded).
  expect_message(
    compare_dplyr_binding(
      .input %>%
        left_join(to_join) %>%
        collect(),
      left
    ),
    "Joining.*by = .*some_grouping"
  )
})
46
test_that("left_join `by` args", {
  # Case 1: the key column has the same name on both sides.
  compare_dplyr_binding(
    .input %>% left_join(to_join, by = "some_grouping") %>% collect(),
    left
  )

  # Case 2: the key is renamed on the right-hand table only, so `by` maps
  # the left name to the right name.
  compare_dplyr_binding(
    .input %>%
      left_join(
        rename(to_join, the_grouping = some_grouping),
        by = c(some_grouping = "the_grouping")
      ) %>%
      collect(),
    left
  )

  # Case 3: the key is renamed on the left-hand input only.
  compare_dplyr_binding(
    .input %>%
      rename(the_grouping = some_grouping) %>%
      left_join(to_join, by = c(the_grouping = "some_grouping")) %>%
      collect(),
    left
  )
})
76
test_that("join two tables", {
  # Both sides of the join are Arrow Tables here; the collected result must
  # be identical to joining the data frames those Tables were created from.
  from_tables <- left_tab %>%
    left_join(to_join_tab, by = "some_grouping") %>%
    collect()

  from_data_frames <- left %>%
    left_join(to_join, by = "some_grouping") %>%
    collect()

  expect_identical(from_tables, from_data_frames)
})
87
test_that("Error handling", {
  # Joining on a column name that `to_join` does not contain (and that `left`
  # is not expected to contain either) must fail. The expected message holds
  # regex metacharacters such as `(` and `%`, so it is matched literally via
  # `fixed = TRUE`.
  expect_error(
    collect(left_join(left_tab, to_join, by = "not_a_col")),
    "all(names(by) %in% names(x)) is not TRUE",
    fixed = TRUE
  )
})
97
# TODO: test joins where the inputs have duplicate column names
# TODO: test type casting, e.g. joining int and float columns?
100
test_that("right_join", {
  # Right join of the shared `left` data with `to_join` on the common key.
  compare_dplyr_binding(
    .input %>% right_join(to_join, by = "some_grouping") %>% collect(),
    left
  )
})
109
test_that("inner_join", {
  # Inner join of the shared `left` data with `to_join` on the common key.
  compare_dplyr_binding(
    .input %>% inner_join(to_join, by = "some_grouping") %>% collect(),
    left
  )
})
118
test_that("full_join", {
  # Full join of the shared `left` data with `to_join` on the common key.
  compare_dplyr_binding(
    .input %>% full_join(to_join, by = "some_grouping") %>% collect(),
    left
  )
})
127
test_that("semi_join", {
  # Filtering join: keeps the rows of the input whose key appears in
  # `to_join`.
  compare_dplyr_binding(
    .input %>% semi_join(to_join, by = "some_grouping") %>% collect(),
    left
  )
})
136
test_that("anti_join", {
  # Every `some_grouping` value in `left` (1 or 2) also occurs in `to_join`,
  # so this anti join keeps no rows. Factor levels don't match when there are
  # no rows in the data, hence `fct` is dropped before joining.
  # TODO: use better anti_join test data
  compare_dplyr_binding(
    .input %>%
      select(-fct) %>%
      anti_join(to_join, by = "some_grouping") %>%
      collect(),
    left
  )
})
148
test_that("mutate then join", {
  # Two small Arrow Tables local to this test.
  lhs <- Table$create(one = c("a", "b"), two = 1:2)
  rhs <- Table$create(three = TRUE, dos = 2L)

  # Both inputs are transformed before being joined: the left side has `two`
  # renamed to `dos` and `one` upper-cased, the right side has `three`
  # negated. No `by` is given, so the join uses the shared column `dos`.
  joined <- lhs %>%
    rename(dos = two) %>%
    mutate(one = toupper(one)) %>%
    left_join(mutate(rhs, three = !three)) %>%
    arrange(dos) %>%
    collect()

  # Only dos == 2 has a match on the right, so `three` is NA for dos == 1 and
  # !TRUE (FALSE) for dos == 2.
  expected <- tibble(
    one = c("A", "B"),
    dos = 1:2,
    three = c(NA, FALSE)
  )

  expect_equal(joined, expected)
})
176