#' High Correlation Filter
#'
#' `step_corr` creates a *specification* of a recipe
#' step that will potentially remove variables that have large
#' absolute correlations with other variables.
#'
#' @inheritParams step_center
#' @param threshold A value for the threshold of absolute
#' correlation values. The step will try to remove the minimum
#' number of columns so that all the resulting absolute
#' correlations are less than this value.
#' @param use A character string for the `use` argument to
#' the [stats::cor()] function.
#' @param method A character string for the `method` argument
#' to the [stats::cor()] function.
#' @param removals A character vector that contains the names of
#' columns that should be removed. These values are not determined
#' until [prep.recipe()] is called.
#' @template step-return
#' @author Original R code for filtering algorithm by Dong Li,
#' modified by Max Kuhn. Contributions by Reynald Lescarbeau (for
#' original in `caret` package). Max Kuhn for the `step`
#' function.
#' @family variable filter steps
#' @export
#'
#' @details This step attempts to remove variables to keep the
#' largest absolute correlation between the variables less than
#' `threshold`.
#'
#' When a column has a single unique value, that column will be
#' excluded from the correlation analysis. Also, if the data set
#' has sporadic missing values (and an inappropriate value of `use`
#' is chosen), some columns will also be excluded from the filter.
#'
#' When you [`tidy()`] this step, a tibble with column `terms` (the columns
#' that will be removed) is returned.
#'
#' @examples
#' library(modeldata)
#' data(biomass)
#'
#' set.seed(3535)
#' biomass$duplicate <- biomass$carbon + rnorm(nrow(biomass))
#'
#' biomass_tr <- biomass[biomass$dataset == "Training",]
#' biomass_te <- biomass[biomass$dataset == "Testing",]
#'
#' rec <- recipe(HHV ~ carbon + hydrogen + oxygen + nitrogen +
#'               sulfur + duplicate,
#'               data = biomass_tr)
#'
#' corr_filter <- rec %>%
#'   step_corr(all_numeric_predictors(), threshold = .5)
#'
#' filter_obj <- prep(corr_filter, training = biomass_tr)
#'
#' filtered_te <- bake(filter_obj, biomass_te)
#' round(abs(cor(biomass_tr[, c(3:7, 9)])), 2)
#' round(abs(cor(filtered_te)), 2)
#'
#' tidy(corr_filter, number = 1)
#' tidy(filter_obj, number = 1)
step_corr <- function(recipe,
                      ...,
                      role = NA,
                      trained = FALSE,
                      threshold = 0.9,
                      use = "pairwise.complete.obs",
                      method = "pearson",
                      removals = NULL,
                      skip = FALSE,
                      id = rand_id("corr")
                      ) {
  add_step(
    recipe,
    step_corr_new(
      terms = ellipse_check(...),
      role = role,
      trained = trained,
      threshold = threshold,
      use = use,
      method = method,
      removals = removals,
      skip = skip,
      id = id
    )
  )
}

# Internal constructor: forwards every field unchanged to `step()` with
# `subclass = "corr"`.
step_corr_new <-
  function(terms, role, trained, threshold, use, method, removals, skip, id) {
    step(
      subclass = "corr",
      terms = terms,
      role = role,
      trained = trained,
      threshold = threshold,
      use = use,
      method = method,
      removals = removals,
      skip = skip,
      id = id
    )
  }

# Train the step: resolve the selected columns, then run `corr_filter()` on
# them to decide which columns to drop. With fewer than two selected columns
# there is nothing to correlate, so `removals` is an empty character vector.
#' @export
prep.step_corr <- function(x, training, info = NULL, ...) {
  col_names <- recipes_eval_select(x$terms, training, info)
  # `drop = FALSE` keeps a single selected column as a data frame even when
  # `training` is a plain data.frame rather than a tibble.
  check_type(training[, col_names, drop = FALSE])

  removals <- character(0)
  if (length(col_names) > 1) {
    removals <- corr_filter(
      x = training[, col_names, drop = FALSE],
      cutoff = x$threshold,
      use = x$use,
      method = x$method
    )
  }

  step_corr_new(
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    threshold = x$threshold,
    use = x$use,
    method = x$method,
    removals = removals,
    skip = x$skip,
    id = x$id
  )
}

# Apply the trained step: drop every column of `new_data` whose name is in
# `object$removals` and return the result as a tibble.
#' @export
bake.step_corr <- function(object, new_data, ...) {
  if (length(object$removals) > 0) {
    keep <- !(colnames(new_data) %in% object$removals)
    # `drop = FALSE` so a single surviving column is not collapsed to a vector
    # when `new_data` is a plain data.frame.
    new_data <- new_data[, keep, drop = FALSE]
  }
  as_tibble(new_data)
}

# Print a one-line summary: the removed columns once trained, otherwise the
# selectors the filter will be applied to. Returns `x` invisibly.
print.step_corr <-
  function(x, width = max(20, options()$width - 36), ...) {
    if (x$trained) {
      if (length(x$removals) > 0) {
        cat("Correlation filter removed ")
        cat(format_ch_vec(x$removals, width = width))
      } else {
        cat("Correlation filter removed no terms")
      }
      cat(" [trained]\n")
    } else {
      cat("Correlation filter on ", sep = "")
      cat(format_selectors(x$terms, width = width))
      cat("\n")
    }
    invisible(x)
  }


# Decide which columns of `x` to drop so that no remaining pair has an
# absolute correlation above `cutoff`.
#
# x       Data whose columns are correlated via `cor()`.
# cutoff  Absolute-correlation threshold.
# use     Passed to `cor()` as `use`.
# method  Passed to `cor()` as `method`.
#
# Returns the names of the columns to remove; `character(0)` when no pair
# exceeds the cutoff or when the filter is skipped because too many
# correlations are `NA`.
corr_filter <-
  function(x,
           cutoff = .90,
           use = "pairwise.complete.obs",
           method = "pearson") {
    x <- cor(x, use = use, method = method)

    if (any(!complete.cases(x))) {
      all_na <- apply(x, 2, function(col) all(is.na(col)))
      if (sum(all_na) >= nrow(x) - 1) {
        rlang::warn("Too many correlations are `NA`; skipping correlation filter.")
        return(character(0))
      }

      # Columns whose correlations are entirely NA are neutralised by zeroing
      # their row and column, so they can never exceed the cutoff.
      na_cols <- which(all_na)
      if (length(na_cols) > 0) {
        x[na_cols, ] <- 0
        x[, na_cols] <- 0
        rlang::warn(
          paste0(
            "The correlation matrix has missing values. ",
            length(na_cols),
            " columns were excluded from the filter."
          )
        )
      }

      # Any NA still present is an isolated pair; treat it as uncorrelated.
      if (any(is.na(x))) {
        rlang::warn(
          paste0(
            "The correlation matrix has sporadic missing values. ",
            "Some columns were excluded from the filter."
          )
        )
        x[is.na(x)] <- 0
      }
      diag(x) <- 1
    }

    # Rank the columns by mean absolute correlation; going through a factor
    # gives tied means the same rank.
    avg_rank <- as.numeric(as.factor(colMeans(abs(x))))

    # Keep only the strict upper triangle so each pair is examined once.
    x[lower.tri(x, diag = TRUE)] <- NA
    hits <- which(abs(x) > cutoff, arr.ind = TRUE)
    if (nrow(hits) == 0) {
      return(character(0))
    }
    pair_rows <- hits[, 1]
    pair_cols <- hits[, 2]

    # Within each offending pair, drop the member with the larger mean
    # absolute correlation; on a tie the row member is dropped.
    drop_col <- avg_rank[pair_cols] > avg_rank[pair_rows]
    doomed <- unique(c(pair_cols[drop_col], pair_rows[!drop_col]))
    colnames(x)[doomed]
  }

# Shared `tidy()` implementation for filter steps: one row per removed column
# once trained, otherwise a single `NA` placeholder row. The step `id` is
# appended as a column in both cases.
tidy_filter <- function(x, ...) {
  if (is_trained(x)) {
    res <- tibble(terms = x$removals)
  } else {
    res <- tibble(terms = na_chr)
  }
  res$id <- x$id
  res
}

#' @rdname tidy.recipe
#' @export
tidy.step_corr <- tidy_filter


#' @rdname tunable.recipe
#' @export
tunable.step_corr <- function(x, ...) {
  tibble::tibble(
    name = "threshold",
    call_info = list(
      list(pkg = "dials", fun = "threshold")
    ),
    source = "recipe",
    component = "step_corr",
    component_id = x$id
  )
}