## Diagonal (Linear or Quadratic) Discriminant Analysis
## ----------------------------------------------------------------------
## Authors: Sandrine Dudoit, sandrine@stat.berkeley.edu
##          Jane Fridlyand, janef@stat.berkeley.edu
##          as function stat.diag.da() in package "sma"
##
## Modification (API and speed): Martin Maechler, Date: 19 Nov 2003, 15:34
##
## Structure: .dDAfit() estimates the model, .dDApredict() classifies new
## data; diagDA() does both in one call, dDA() + predict.dDA() in two steps.

## Fit step: per-class feature means and variances.
##   x   : learning set, coerced via data.matrix(); rows = observations.
##   cll : class labels, one per row of x; after shifting by cl0 they must
##         be the consecutive integers 1:K.  NA labels are ignored.
## Returns list(cl0, n, p, K, means (p x K), vars (p x K), nk (length K)).
.dDAfit <- function(x, cll)
{
    x <- data.matrix(x)
    p <- ncol(x)
    if(length(cll) != nrow(x))
        stop(sQuote("cll"), " must have one entry per row of the learning set")
    if(all(is.na(cll)))
        stop(sQuote("cll"), " contains no non-missing class labels")

    cl0 <- as.integer(min(cll, na.rm = TRUE) - 1)
    cll <- as.integer(cll) - cl0 # now in 1:K (or NA)
    clL <- cll[!is.na(cll)]
    K <- max(clL)
    if(K != length(unique(clL)))
        stop(sQuote("cll"), " did not contain *consecutive* integers")

    nk <- integer(K)
    m <- v <- matrix(0, p, K)
    for(k in seq_len(K)) {
        ## which() drops NA labels; a logical index containing NA would
        ## instead add an all-NA row per unlabeled observation.
        lsk <- x[which(cll == k), , drop = FALSE]
        nk[k] <- nrow(lsk)
        m[, k] <- colMeans(lsk, na.rm = TRUE)
        dev <- sweep(lsk, 2, m[, k])
        ## Per-column count of non-missing values, as var(*, na.rm=TRUE)
        ## would use; columns with < 2 observations get variance 0.
        nObs <- colSums(!is.na(lsk))
        v[, k] <- colSums(dev * dev, na.rm = TRUE) / pmax(nObs - 1, 1)
    }
    list(cl0 = cl0, n = nrow(x), p = p, K = K,
         means = m, vars = v, nk = nk)
}

## Replace zero variances by 1e-7 times the smallest positive variance,
## so that they can be divided by (and logged); error if all are zero.
.dDAfixZeroVar <- function(v)
{
    i0 <- !is.na(v) & v == 0
    if(any(i0)) {
        vPos <- v[!is.na(v) & v != 0]
        if(!length(vPos))
            stop("all variances are 0 -- cannot predict")
        v[i0] <- 1e-7 * min(vPos)
    }
    v
}

## Predict step: classify the rows of 'newdata' using a fit from .dDAfit().
## pool = TRUE : linear DA with pooled (across classes) feature variances;
## pool = FALSE: quadratic DA with per-class feature variances.
## Returns integer labels on the original scale of 'cll'; rows of newdata
## containing any NA give NA (via na.exclude() / napredict()).
.dDApredict <- function(fit, newdata, pool)
{
    newdata <- data.matrix(newdata)
    p <- fit$p
    K <- fit$K
    mu <- fit$means # p x K
    if(p != ncol(newdata))
        stop("test set matrix must have same columns as learning one")
    ## any NA's in test set currently must give NA predictions
    newdata <- na.exclude(newdata)
    nt <- nrow(newdata)
    disc <- matrix(0, nt, K)

    if(pool) { ## LDA
        nk <- fit$nk
        ## pooled variances; degrees of freedom from the labeled rows only
        vp <- rowSums(fit$vars * rep(nk - 1, each = p)) / max(sum(nk) - K, 1)
        ivp <- rep(1 / .dDAfixZeroVar(vp), each = nt)
        for(k in seq_len(K)) {
            y <- newdata - rep(mu[, k], each = nt)
            disc[, k] <- rowSums(y * y * ivp)
        }
    } else { ## QDA
        Vr <- .dDAfixZeroVar(fit$vars)
        for(k in seq_len(K)) {
            ## separate 'y': newdata itself must not be modified in the loop
            y <- newdata - rep(mu[, k], each = nt)
            disc[, k] <- rowSums(y * y / rep(Vr[, k], each = nt)) +
                sum(log(Vr[, k]), na.rm = TRUE)
        }
    }

    ## vapply() keeps the result an integer vector even when a row of
    ## 'disc' is all NA/NaN (which.min() then gives integer(0)).
    best <- vapply(seq_len(nt), function(i) {
        j <- which.min(disc[i, ])
        if(length(j)) j else NA_integer_
    }, integer(1))
    pred <- fit$cl0 + best

    naAct <- attr(newdata, "na.action")
    if(inherits(naAct, "exclude")) ## had missings in 'newdata'
        pred <- napredict(omit = naAct, pred) # typically napredict.exclude()
    pred
}

## One-step interface: fit on (ls, cll), then predict the rows of ts.
##   ls   : learning set matrix (rows = observations)
##   cll  : class labels for the rows of ls (consecutive integers; NA ignored)
##   ts   : test set matrix with the same columns as ls
##   pool : TRUE for linear (pooled variances), FALSE for quadratic DA
## Returns the predicted labels for the rows of ts (NA where ts has NAs).
diagDA <- function(ls, cll, ts, pool = TRUE)
{
    .dDApredict(.dDAfit(ls, cll), ts, pool = pool)
}

## Two-step interface: dDA() estimates and returns an object of class "dDA"
## (components call, cl0, n, p, K, means, vars, nk, pool);
## predict() on it classifies new data.

dDA <- function(x, cll, pool = TRUE)
{
    ## Purpose: Diagonal (Linear or Quadratic) Discriminant Analysis
    structure(c(list(call = match.call()), .dDAfit(x, cll), list(pool = pool)),
              class = "dDA")
}

## Print a short summary of a "dDA" fit; returns x invisibly.
print.dDA <- function(x, ...)
{
    cat(if(x$pool) "Linear (pooled var)" else "Quadratic (no pooling)",
        "Diagonal Discriminant Analysis,\n ", deparse(x$call), "\n")
    with(x,
         cat(" (n= ", n, ") x (p= ", p, ") data in K=", K, " classes of [",
             paste(nk, collapse = ", "), "] observations each\n", sep = ""))
    cat("\n")
    invisible(x)
}

## Predict class labels of 'newdata' from a "dDA" fit; 'pool' defaults to
## the choice made when fitting but may be overridden.
predict.dDA <- function(object, newdata, pool = object$pool, ...)
{
    .dDApredict(object, newdata, pool = pool)
}