1#' High Correlation Filter
2#'
3#' `step_corr` creates a *specification* of a recipe
4#'  step that will potentially remove variables that have large
5#'  absolute correlations with other variables.
6#'
7#' @inheritParams step_center
#' @param threshold A value for the threshold of absolute
#'  correlation values. The step will try to remove as few columns
#'  as possible so that all the resulting absolute correlations are
#'  no greater than this value.
12#' @param use A character string for the `use` argument to
13#'  the [stats::cor()] function.
14#' @param method A character string for the `method` argument
15#'  to the [stats::cor()] function.
#' @param removals A character vector that contains the names of
#'  columns that should be removed. These values are not determined
#'  until [prep.recipe()] is called.
19#' @template step-return
20#' @author Original R code for filtering algorithm by Dong Li,
21#'  modified by Max Kuhn. Contributions by Reynald Lescarbeau (for
22#'  original in `caret` package). Max Kuhn for the `step`
23#'  function.
24#' @family variable filter steps
25#' @export
26#'
#' @details This step attempts to remove variables so that the
#'  largest absolute correlation between the remaining variables is
#'  no greater than `threshold`.
30#'
31#' When a column has a single unique value, that column will be
32#'  excluded from the correlation analysis. Also, if the data set
33#'  has sporadic missing values (and an inappropriate value of `use`
34#'  is chosen), some columns will also be excluded from the filter.
35#'
#' When you [`tidy()`] this step, a tibble with columns `terms` (the columns
#'  that will be removed) and `id` is returned.
38#'
39#' @examples
40#' library(modeldata)
41#' data(biomass)
42#'
43#' set.seed(3535)
44#' biomass$duplicate <- biomass$carbon + rnorm(nrow(biomass))
45#'
46#' biomass_tr <- biomass[biomass$dataset == "Training",]
47#' biomass_te <- biomass[biomass$dataset == "Testing",]
48#'
49#' rec <- recipe(HHV ~ carbon + hydrogen + oxygen + nitrogen +
50#'                     sulfur + duplicate,
51#'               data = biomass_tr)
52#'
53#' corr_filter <- rec %>%
54#'   step_corr(all_numeric_predictors(), threshold = .5)
55#'
56#' filter_obj <- prep(corr_filter, training = biomass_tr)
57#'
58#' filtered_te <- bake(filter_obj, biomass_te)
59#' round(abs(cor(biomass_tr[, c(3:7, 9)])), 2)
60#' round(abs(cor(filtered_te)), 2)
61#'
62#' tidy(corr_filter, number = 1)
63#' tidy(filter_obj, number = 1)
step_corr <- function(recipe,
                      ...,
                      role = NA,
                      trained = FALSE,
                      threshold = 0.9,
                      use = "pairwise.complete.obs",
                      method = "pearson",
                      removals = NULL,
                      skip = FALSE,
                      id = rand_id("corr")) {
  # Bundle the selectors and filter settings into an untrained step object.
  # The columns to drop are only computed later, in prep.step_corr().
  corr_step <- step_corr_new(
    terms = ellipse_check(...),
    role = role, trained = trained,
    threshold = threshold, use = use, method = method,
    removals = removals, skip = skip, id = id
  )
  # Hand the step to add_step(), which presumably appends it to the
  # recipe's list of steps and returns the updated recipe.
  add_step(recipe, corr_step)
}
90
step_corr_new <- function(terms, role, trained, threshold, use, method,
                          removals, skip, id) {
  # Low-level constructor: every argument is stored unchanged as a field of
  # a step with subclass "corr". Fields are passed in a fixed order so the
  # resulting object's layout is stable.
  step(
    subclass = "corr",
    terms = terms, role = role, trained = trained,
    threshold = threshold, use = use, method = method,
    removals = removals, skip = skip, id = id
  )
}
106
#' @export
prep.step_corr <- function(x, training, info = NULL, ...) {
  # Resolve the step's selectors into concrete column names, then validate
  # their type (check_type() is a recipes helper; presumably it errors on
  # non-numeric columns).
  col_names <- recipes_eval_select(x$terms, training, info)
  # `drop = FALSE` keeps a data frame even when only one column is selected.
  check_type(training[, col_names, drop = FALSE])

  # With fewer than two columns there are no pairs to compare, so nothing is
  # removed. as.character() guarantees `removals` is always a character
  # vector of column names (never an empty numeric vector), which keeps
  # downstream output such as tidy() type-stable.
  removals <- character(0)
  if (length(col_names) > 1) {
    removals <- as.character(
      corr_filter(
        x = training[, col_names, drop = FALSE],
        cutoff = x$threshold,
        use = x$use,
        method = x$method
      )
    )
  }

  step_corr_new(
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    threshold = x$threshold,
    use = x$use,
    method = x$method,
    removals = removals,
    skip = x$skip,
    id = x$id
  )
}
135
#' @export
bake.step_corr <- function(object, new_data, ...) {
  # Drop every column that prep() flagged for removal. `drop = FALSE`
  # matters for base data frames: if a single column survives, `[` would
  # otherwise collapse the result to a bare vector and lose its name.
  if (length(object$removals) > 0) {
    keep <- !(colnames(new_data) %in% object$removals)
    new_data <- new_data[, keep, drop = FALSE]
  }
  as_tibble(new_data)
}
142
print.step_corr <-
  function(x, width = max(20, options()$width - 36), ...) {
    # Before prep(): describe which selectors the filter will look at.
    if (!x$trained) {
      cat("Correlation filter on ", sep = "")
      cat(format_selectors(x$terms, width = width))
      cat("\n")
      return(invisible(x))
    }
    # After prep(): report the columns that were actually removed.
    if (length(x$removals) == 0) {
      cat("Correlation filter removed no terms")
    } else {
      cat("Correlation filter removed ")
      cat(format_ch_vec(x$removals, width = width))
    }
    cat(" [trained]\n")
    invisible(x)
  }
161
162
# Find the columns to drop so that no remaining pair of columns has an
# absolute correlation above `cutoff`.
#
# Adapted from the filter in the caret package. Each pair whose absolute
# correlation exceeds `cutoff` is examined once (upper triangle only). The
# member with the higher rank of mean absolute correlation is flagged for
# removal; on ties, the earlier (row-side) column is flagged. Pairs are not
# re-evaluated after a column is flagged, so this single pass may remove
# more columns than strictly necessary.
#
# Columns whose correlations are all NA (e.g. zero-variance columns) are
# neutralised by zeroing their correlations. If nearly every column is like
# that, the filter is skipped with a warning.
#
# Arguments:
#   x           - data passed to stats::cor() (the selected columns).
#   cutoff      - absolute-correlation threshold; values equal to it are kept.
#   use, method - forwarded to stats::cor().
#
# Returns a character vector of column names to remove. It is character(0)
# when nothing needs removing or when the filter is skipped.
corr_filter <-
  function(x,
           cutoff = .90,
           use = "pairwise.complete.obs",
           method = "pearson") {
    x <- stats::cor(x, use = use, method = method)

    if (any(!stats::complete.cases(x))) {
      # Columns whose correlations are entirely NA cannot take part.
      all_na <- colSums(!is.na(x)) == 0
      if (sum(all_na) >= nrow(x) - 1) {
        rlang::warn("Too many correlations are `NA`; skipping correlation filter.")
        return(character(0))
      }
      na_cols <- which(all_na)
      if (length(na_cols) > 0) {
        # Zeroed correlations can never exceed `cutoff`, so these columns
        # are effectively excluded from the filter.
        x[na_cols, ] <- 0
        x[, na_cols] <- 0
        rlang::warn(
          paste0(
            "The correlation matrix has missing values. ",
            length(na_cols),
            " columns were excluded from the filter."
          )
        )
      }
      if (anyNA(x)) {
        rlang::warn(
          paste0(
            "The correlation matrix has sporadic missing values. ",
            "Some columns were excluded from the filter."
          )
        )
        x[is.na(x)] <- 0
      }
      diag(x) <- 1
    }

    # Dense rank of each column's mean absolute correlation (ties share a
    # rank). It decides which member of a highly correlated pair is dropped.
    avg_rank <- as.numeric(as.factor(colMeans(abs(x))))

    # Examine each pair exactly once: keep only the strict upper triangle.
    x[lower.tri(x, diag = TRUE)] <- NA
    above <- which(abs(x) > cutoff)
    rows_to_check <- row(x)[above]
    cols_to_check <- col(x)[above]

    # Drop the column-side member when it ranks higher; otherwise (including
    # ties) drop the row-side member.
    drop_col_side <- avg_rank[cols_to_check] > avg_rank[rows_to_check]
    to_delete <- unique(c(
      cols_to_check[drop_col_side],
      rows_to_check[!drop_col_side]
    ))
    # Indexing with an empty vector yields character(0), so the return type
    # is the same whether or not anything is removed.
    colnames(x)[to_delete]
  }
218
tidy_filter <- function(x, ...) {
  # Shared tidy() implementation for filter steps. Once trained it returns
  # one row per removed column; before training the removals are unknown.
  if (is_trained(x)) {
    # as.character() keeps `terms` a character column even when nothing was
    # removed (an empty `removals` may not be a character vector).
    res <- tibble(terms = as.character(x$removals))
  } else {
    # NOTE(review): na_chr is presumably a length-one NA_character_,
    # giving a single placeholder row -- confirm.
    res <- tibble(terms = na_chr)
  }
  res$id <- x$id
  res
}
229
#' @rdname tidy.recipe
#' @export
tidy.step_corr <- function(x, ...) {
  # step_corr shares its tidy() output format with the other filter steps.
  tidy_filter(x, ...)
}
233
234
#' @rdname tunable.recipe
#' @export
tunable.step_corr <- function(x, ...) {
  # The only tunable argument is `threshold`; its parameter object comes from
  # dials::threshold().
  threshold_info <- list(pkg = "dials", fun = "threshold")
  tibble::tibble(
    name = "threshold",
    call_info = list(threshold_info),
    source = "recipe",
    component = "step_corr",
    component_id = x$id
  )
}
248
249