1library(testthat)
2library(recipes)
3library(dplyr)
4
# Shared fixture: `iris` without its first 45 rows, with every fifth row
# (positions 6, 11, ..., 96) of `Species` blanked out to NA.
iris2 <- iris[-seq_len(45), ]
iris2[seq(from = 6, to = 96, by = 5), "Species"] <- NA

# Two shuffled copies of `Species`: one kept as a factor, one as character.
iris2[["Species2"]] <- sample(iris2[["Species"]])
iris2[["Species3"]] <- as.character(sample(iris2[["Species"]]))

# Base recipe that treats every column of `iris2` as a predictor.
rec <- recipe(~ ., data = iris2)
11
test_that("basic usage", {
  # Skipped when the installed recipes version is newer than 0.1.13.
  skip_if(utils::packageVersion("recipes") > "0.1.13")

  # Select the column whose name ends in "Species"; the id is fixed to ""
  # so the tidy() output can be compared exactly.
  rec1 <- step_upsample(rec, matches("Species$"), id = "")

  # Before prep(), tidy() reports the selector expression as text.
  expected_untrained <- tibble(terms = "matches(\"Species$\")", id = "")
  expect_equivalent(expected_untrained, tidy(rec1, number = 1))

  rec1_p <- prep(rec1, training = iris2)

  # After prep(), tidy() reports the resolved column name.
  expected_trained <- tibble(terms = "Species", id = "")
  expect_equal(expected_trained, tidy(rec1_p, number = 1))

  train_species <- juice(rec1_p)$Species
  baked_species <- bake(rec1_p, new_data = iris2)$Species

  tr_xtab <- table(train_species, useNA = "always")
  te_xtab <- table(baked_species, useNA = "always")
  og_xtab <- table(iris2$Species, useNA = "always")

  # The training set's largest class count is unchanged, and the NA count
  # equals that largest count.
  expect_equal(max(tr_xtab), max(og_xtab))
  expect_equal(sum(is.na(train_species)), max(og_xtab))

  # bake() on new data leaves the class counts untouched.
  expect_equal(te_xtab, og_xtab)

  # Prepping must not raise a warning.
  expect_warning(prep(rec1, training = iris2), NA)
})
42
test_that("ratio value", {
  # Skipped when the installed recipes version is newer than 0.1.13.
  skip_if(utils::packageVersion("recipes") > "0.1.13")

  rec2 <- step_upsample(rec, matches("Species$"), ratio = .25)
  rec2_p <- prep(rec2, training = iris2)

  train_species <- juice(rec2_p)$Species
  baked_species <- bake(rec2_p, new_data = iris2)$Species

  tr_xtab <- table(train_species, useNA = "always")
  te_xtab <- table(baked_species, useNA = "always")
  og_xtab <- table(iris2$Species, useNA = "always")

  # With ratio = 0.25 the smallest class count in the training set is 10.
  expect_equal(min(tr_xtab), 10)

  # The number of missing values is the same as in the input data.
  expect_equal(
    sum(is.na(train_species)),
    sum(is.na(iris2$Species))
  )

  # bake() on new data leaves the class counts untouched.
  expect_equal(te_xtab, og_xtab)
})
59
60
test_that("no skipping", {
  # Skipped when the installed recipes version is newer than 0.1.13.
  skip_if(utils::packageVersion("recipes") > "0.1.13")

  # skip = FALSE: the step is also applied when baking new data.
  rec3 <- step_upsample(rec, matches("Species$"), skip = FALSE)
  rec3_p <- prep(rec3, training = iris2)

  train_species <- juice(rec3_p)$Species
  baked_species <- bake(rec3_p, new_data = iris2)$Species

  tr_xtab <- table(train_species, useNA = "always")
  te_xtab <- table(baked_species, useNA = "always")
  og_xtab <- table(iris2$Species, useNA = "always")

  expect_equal(max(tr_xtab), max(og_xtab))

  # Baked data must show the same class counts as the training data.
  expect_equal(te_xtab, tr_xtab)
})
75
76
77
test_that("bad data", {
  # Skipped when the installed recipes version is newer than 0.1.13.
  skip_if(utils::packageVersion("recipes") > "0.1.13")

  # A numeric column is rejected.
  expect_error(
    prep(step_upsample(rec, Sepal.Width), retain = TRUE)
  )

  # A character column (strings not converted to factors) is rejected.
  expect_error(
    prep(step_upsample(rec, Species3), strings_as_factors = FALSE)
  )

  # Selecting more than one column is rejected.
  expect_error(
    prep(step_upsample(rec, Species, Species2), strings_as_factors = FALSE)
  )
})
96
test_that('printing', {
  # Skipped when the installed recipes version is newer than 0.1.13.
  skip_if(utils::packageVersion("recipes") > "0.1.13")
  rec4 <- rec %>%
    step_upsample(Species)

  # Print the recipe that actually contains the upsampling step. Printing the
  # bare `rec` fixture would never exercise the step's print method.
  expect_output(print(rec4))

  # Verbose prep() must also produce console output.
  expect_output(prep(rec4, training = iris2, verbose = TRUE))
})
105
test_that("`seed` produces identical sampling", {
  # Skipped when the installed recipes version is newer than 0.1.13.
  skip_if(utils::packageVersion("recipes") > "0.1.13")

  # Upsample `Species` with the given seed and return the resulting
  # Petal.Width column of the prepped training data.
  upsample_with_seed <- function(rec, seed = sample.int(10^5, 1)) {
    seeded_rec <- step_upsample(rec, Species, seed = seed)
    prepped <- prep(seeded_rec, training = iris2)
    juice(prepped)[["Petal.Width"]]
  }

  petal_width_1 <- upsample_with_seed(rec, seed = 1234)
  petal_width_2 <- upsample_with_seed(rec, seed = 1234)
  petal_width_3 <- upsample_with_seed(rec, seed = 12345)

  # Same seed gives the same rows; a different seed gives different rows.
  expect_equal(petal_width_1, petal_width_2)
  expect_false(identical(petal_width_1, petal_width_3))
})
124
125
test_that("ratio deprecation", {
  # Skipped when the installed recipes version is newer than 0.1.13.
  skip_if(utils::packageVersion("recipes") > "0.1.13")

  # Using `ratio` must emit the deprecation message. The assignment sits
  # inside expect_message() so the resulting recipe can be inspected below.
  expect_message(
    new_rec <- step_upsample(
      rec,
      tidyselect::matches("Species$"),
      ratio = 2
    ),
    "argument is now deprecated"
  )

  # The value passed as `ratio` ends up in the step's `over_ratio` field.
  expect_equal(new_rec$steps[[1]]$over_ratio, 2)
})
136
137
138
test_that("tunable", {
  # Skipped when the installed recipes version is newer than 0.1.13.
  skip_if(utils::packageVersion("recipes") > "0.1.13")

  # Recipe built on the full iris data, local to this test.
  tune_rec <- step_upsample(recipe(~ ., data = iris), all_predictors())
  upsample_step <- tune_rec$steps[[1]]
  rec_param <- tunable.step_upsample(upsample_step)

  # Exactly one tunable parameter, `over_ratio`, sourced from the recipe.
  expect_equal(rec_param$name, "over_ratio")
  expect_true(all(rec_param$source == "recipe"))
  expect_true(is.list(rec_param$call_info))
  expect_equal(nrow(rec_param), 1)

  # Column layout of the tunable() result.
  expected_cols <- c("name", "call_info", "source", "component", "component_id")
  expect_equal(names(rec_param), expected_cols)
})
154
155
156