# Tests for step_upsample().
#
# Every test below is skipped when the installed version of recipes is newer
# than 0.1.13 (see the skip_if() guard at the top of each test).

library(testthat)
library(recipes)
library(dplyr)

# Shared fixture -------------------------------------------------------------

# Drop the first 45 rows of iris so that the class counts are unbalanced.
iris2 <- iris[-(1:45), ]
# Set every fifth Species value, from position 6 through 96, to NA.
iris2$Species[seq(6, 96, by = 5)] <- NA
# A shuffled copy of Species (same type as Species) and a shuffled character
# copy; both are only used by the 'bad data' test.
iris2$Species2 <- sample(iris2$Species)
iris2$Species3 <- as.character(sample(iris2$Species))

# Base recipe shared by the tests; each test adds its own step to it.
rec <- recipe( ~ ., data = iris2)

# Tests ----------------------------------------------------------------------

test_that('basic usage', {
  skip_if(utils::packageVersion("recipes") > "0.1.13")
  rec1 <- rec %>%
    step_upsample(matches("Species$"), id = "")

  # Before prep(), tidy() is expected to report the selector as written.
  untrained <- tibble(
    terms = "matches(\"Species$\")",
    id = ""
  )

  expect_equivalent(untrained, tidy(rec1, number = 1))

  rec1_p <- prep(rec1, training = iris2)

  # After prep(), tidy() is expected to report the selected column name.
  trained <- tibble(
    terms = "Species",
    id = ""
  )

  expect_equal(trained, tidy(rec1_p, number = 1))

  # Class counts (NA included) for the juiced training data, for baked new
  # data, and for the untouched input.
  tr_xtab <- table(juice(rec1_p)$Species, useNA = "always")
  te_xtab <- table(bake(rec1_p, new_data = iris2)$Species, useNA = "always")
  og_xtab <- table(iris2$Species, useNA = "always")

  # The largest count must not change, the NA rows must be brought up to that
  # same count, and bake() must leave the counts of new data as they were.
  expect_equal(max(tr_xtab), max(og_xtab))
  expect_equal(sum(is.na(juice(rec1_p)$Species)), max(og_xtab))
  expect_equal(te_xtab, og_xtab)

  # `NA` as the pattern asserts that no warning is raised.
  expect_warning(prep(rec1, training = iris2), NA)
})

test_that('ratio value', {
  skip_if(utils::packageVersion("recipes") > "0.1.13")
  rec2 <- rec %>%
    step_upsample(matches("Species$"), ratio = .25)

  rec2_p <- prep(rec2, training = iris2)

  tr_xtab <- table(juice(rec2_p)$Species, useNA = "always")
  te_xtab <- table(bake(rec2_p, new_data = iris2)$Species, useNA = "always")
  og_xtab <- table(iris2$Species, useNA = "always")

  # With ratio = .25 the smallest count is expected to be 10, the number of
  # NA rows is expected to be unchanged, and baked new data is expected to
  # keep its original counts.
  expect_equal(min(tr_xtab), 10)
  expect_equal(sum(is.na(juice(rec2_p)$Species)),
               sum(is.na(iris2$Species)))
  expect_equal(te_xtab, og_xtab)
})

test_that('no skipping', {
  skip_if(utils::packageVersion("recipes") > "0.1.13")
  rec3 <- rec %>%
    step_upsample(matches("Species$"), skip = FALSE)

  rec3_p <- prep(rec3, training = iris2)

  tr_xtab <- table(juice(rec3_p)$Species, useNA = "always")
  te_xtab <- table(bake(rec3_p, new_data = iris2)$Species, useNA = "always")
  og_xtab <- table(iris2$Species, useNA = "always")

  # With skip = FALSE, baking new data is expected to give the same counts as
  # the juiced training data.
  expect_equal(max(tr_xtab), max(og_xtab))
  expect_equal(te_xtab, tr_xtab)
})

test_that('bad data', {
  skip_if(utils::packageVersion("recipes") > "0.1.13")
  # A numeric column is selected.
  expect_error(
    rec %>%
      step_upsample(Sepal.Width) %>%
      prep(retain = TRUE)
  )
  # A character column is selected (kept as character by
  # strings_as_factors = FALSE).
  expect_error(
    rec %>%
      step_upsample(Species3) %>%
      prep(strings_as_factors = FALSE)
  )
  # More than one column is selected.
  expect_error(
    rec %>%
      step_upsample(Species, Species2) %>%
      prep(strings_as_factors = FALSE)
  )
})

test_that('printing', {
  skip_if(utils::packageVersion("recipes") > "0.1.13")
  rec4 <- rec %>%
    step_upsample(Species)

  # Print the recipe that actually carries the up-sampling step; printing the
  # bare `rec` would never show the step under test.
  expect_output(print(rec4))
  expect_output(prep(rec4, training = iris2, verbose = TRUE))
})

test_that('`seed` produces identical sampling', {
  skip_if(utils::packageVersion("recipes") > "0.1.13")

  # Up-sample on Species with the given seed and return the Petal.Width
  # column of the juiced training data.
  upsample_with_seed <- function(rec, seed = sample.int(10^5, 1)) {
    rec %>%
      step_upsample(Species, seed = seed) %>%
      prep(training = iris2) %>%
      juice() %>%
      pull(Petal.Width)
  }

  petal_width_1 <- upsample_with_seed(rec, seed = 1234)
  petal_width_2 <- upsample_with_seed(rec, seed = 1234)
  petal_width_3 <- upsample_with_seed(rec, seed = 12345)

  # Same seed gives the same result; a different seed gives a different one.
  expect_equal(petal_width_1, petal_width_2)
  expect_false(identical(petal_width_1, petal_width_3))
})

test_that('ratio deprecation', {
  skip_if(utils::packageVersion("recipes") > "0.1.13")
  # Passing `ratio` is expected to emit a deprecation message ...
  expect_message(
    new_rec <-
      rec %>%
      step_upsample(tidyselect::matches("Species$"), ratio = 2),
    "argument is now deprecated"
  )
  # ... and its value is expected to end up in `over_ratio`.
  expect_equal(new_rec$steps[[1]]$over_ratio, 2)
})

test_that('tunable', {
  skip_if(utils::packageVersion("recipes") > "0.1.13")
  # Local `rec`; it shadows the file-level recipe inside this test only.
  rec <-
    recipe(~ ., data = iris) %>%
    step_upsample(all_predictors())
  rec_param <- tunable.step_upsample(rec$steps[[1]])
  # Exactly one tunable parameter, `over_ratio`, with the expected columns.
  expect_equal(rec_param$name, c("over_ratio"))
  expect_true(all(rec_param$source == "recipe"))
  expect_true(is.list(rec_param$call_info))
  expect_equal(nrow(rec_param), 1)
  expect_equal(
    names(rec_param),
    c('name', 'call_info', 'source', 'component', 'component_id')
  )
})