diagDA <- function(ls, cll, ts, pool= TRUE)
{
    ## Purpose: Diagonal (Linear or Quadratic) Discriminant Analysis;
    ##          fits on the learning set and predicts the test set in one call
    ## ----------------------------------------------------------------------
    ## Arguments: --> ?diagDA  (i.e. ../man/diagDA.Rd )
    ##   ls  : learning set, coerced via data.matrix()  (n x p)
    ##   cll : class labels for the rows of 'ls'; must be consecutive
    ##         integers (any offset); NA labels are left out of the fit
    ##   ts  : test set, coerced via data.matrix(); needs p columns
    ##   pool: TRUE  = pooled variances (linear rule),
    ##         FALSE = class-wise variances (quadratic rule)
    ## Value: predicted class label for every row of 'ts';
    ##        rows of 'ts' containing NA get an NA prediction
    ## ----------------------------------------------------------------------
    ## Authors: Sandrine Dudoit, sandrine@stat.berkeley.edu
    ##	        Jane Fridlyand, janef@stat.berkeley.edu
    ## as function  stat.diag.da() in package "sma"
    ##
    ## Modification (API and speed): Martin Maechler, Date: 19 Nov 2003, 15:34

### ---------------------- Fit Model ------------------------------
    ls <- data.matrix(ls)
    p <- ncol(ls)

    cl0 <- as.integer(min(cll, na.rm=TRUE) - 1)
    cll <- as.integer(cll) - cl0 ## cll now in 1:K
    clL <- cll[!is.na(cll)]
    K <- max(clL)
    if(K != length(unique(clL)))
        stop(sQuote("cll")," did not contain *consecutive* integers")

    nk <- integer(K)
    m <- v <- matrix(0,p,K)

    ## column variances around given means; note the divisor is nrow(x) - 1
    ## even when NA cells were dropped by na.rm = TRUE
    colVars <- function(x, means = colMeans(x, na.rm = na.rm), na.rm=FALSE) {
        x <- sweep(x, 2, means)
        colSums(x*x, na.rm = na.rm) / (nrow(x) - 1)
    }
    sum.na <- function(x) sum(x, na.rm=TRUE)

    ## Class means and variances
    for(k in seq_len(K)) {
        ## which() drops NA labels; a logical index containing NA would
        ## add all-NA rows to 'lsk' and inflate the variance divisor
        ik <- which(cll == k)
        nk[k] <- length(ik)
        lsk <- ls[ik, , drop = FALSE]
        m[,k] <- colMeans(lsk, na.rm = TRUE)
        if(nk[k] > 1)
            v[,k] <- colVars (lsk, na.rm = TRUE, means = m[,k]) ## else 0
    }

### ---------------------- Predict from Model -----------------------------

    ts <- data.matrix(ts)
    if(p != ncol(ts))
        stop("test set matrix must have same columns as learning one")
    ## any NA's in test set currently must give NA predictions
    ts <- na.exclude(ts)
    nt <- nrow(ts)
    disc <- matrix(0, nt,K)

    if(pool) { ## LDA
        ## Pooled estimates of variances; only labelled rows count
        vp <- rowSums(rep(nk - 1, each=p) * v) / (sum(nk) - K)
        ## == apply(v, 1, function(z) sum.na((nk-1)*z))/(sum(nk)-K)
        ## zero variances are not acceptable as divisors
        if(any(i0 <- vp == 0)) {
            if(all(i0))
                stop("all variances are 0 -- cannot predict")
            vp[i0] <- 1e-7 * min(vp[!i0])
        }

        ivp <- rep(1/vp, each = nt) # to use in loop

        for(k in seq_len(K)) {
            y <- ts - rep(m[,k], each=nt)
            disc[,k] <- rowSums(y*y * ivp)
            ## == apply(ts, 1, function(z) sum.na((z-m[,k])^2/vp))
        }
    }
    else { ## QDA
        ## zero variances would give NaN / -Inf scores below
        if(any(i0 <- v == 0)) {
            if(all(i0))
                stop("all variances are 0 -- cannot predict")
            v[i0] <- 1e-7 * min(v[!i0])
        }
        ## A rowSums()-based variant was tried here before but was marked
        ## as failing ../tests/dDA.R, so the apply() form is kept.
        for(k in seq_len(K)) {
            disc[,k] <-
                apply(ts,1, function(z) sum((z-m[,k])^2/v[,k])) +
                    sum.na(log(v[,k]))
        }
    }

    ## predictions: class with the smallest discriminant score,
    ## shifted back to the original label offset

    pred <- cl0 + apply(disc, 1, which.min)
    if(inherits(attr(ts,"na.action"), "exclude")) # had missings in `ts'
        pred <- napredict(omit = attr(ts,"na.action"), pred)
    pred
}
91
## Cleaner: one function to estimate, one to predict:
## ------- (my tests give a time penalty of 5% for doing things in two steps)
94
dDA <- function(x, cll, pool= TRUE)
{
    ## Purpose: Diagonal (Linear or Quadratic) Discriminant Analysis,
    ##          estimation step only
    ## ----------------------------------------------------------------------
    ## Arguments:
    ##   x   : learning set, coerced via data.matrix()  (n x p)
    ##   cll : class labels for the rows of 'x'; must be consecutive
    ##         integers (any offset); NA labels are left out of the fit
    ##   pool: stored in the result; TRUE = pooled variances (linear),
    ##         FALSE = class-wise variances (quadratic)
    ## Value: list of class "dDA" with components
    ##   call, cl0 (label offset), n = nrow(x), p = ncol(x), K (no. classes),
    ##   means, vars (both p x K), nk (class sizes), pool
    ## ----------------------------------------------------------------------

    x <- data.matrix(x)
    n <- nrow(x)
    p <- ncol(x)

    cl0 <- as.integer(min(cll, na.rm=TRUE) - 1)
    cll <- as.integer(cll) - cl0 ## cll now in 1:K
    clL <- cll[!is.na(cll)]
    K <- max(clL)
    if(K != length(unique(clL)))
        stop(sQuote("cll")," did not contain *consecutive* integers")

    nk <- integer(K)
    m <- v <- matrix(0,p,K)

    ## column variances around given means; note the divisor is nrow(x) - 1
    ## even when NA cells were dropped by na.rm = TRUE
    colVars <- function(x, means = colMeans(x, na.rm = na.rm), na.rm=FALSE) {
        x <- sweep(x, 2, means)
        colSums(x*x, na.rm = na.rm) / (nrow(x) - 1)
    }

    ## Class means and variances
    for(k in seq_len(K)) {
        ## which() drops NA labels; a logical index containing NA would
        ## add all-NA rows to 'lsk' and inflate the variance divisor
        ik <- which(cll == k)
        nk[k] <- length(ik)
        lsk <- x[ik, , drop = FALSE]
        m[,k] <- colMeans(lsk, na.rm = TRUE)
        if(nk[k] > 1)
            v[,k] <- colVars (lsk, na.rm = TRUE, means = m[,k]) ## else 0
    }
    structure(list(call = match.call(), cl0 = cl0, n=n, p=p, K=K,
                   means=m, vars=v, nk=nk, pool=pool),
              class = "dDA")
}
133
print.dDA <- function(x, ...)
{
    ## Purpose: print method for "dDA" fits: a header with the kind of
    ##          analysis and the call, then dimensions and class sizes.
    ## Value:   'x', invisibly

    ## header line(s)
    kind <- if(x[["pool"]]) "Linear (pooled var)" else "Quadratic (no pooling)"
    cat(kind, "Diagonal Discriminant Analysis,\n ", deparse(x[["call"]]), "\n")

    ## dimensions and number of observations per class
    sizes <- paste(x[["nk"]], collapse=", ")
    cat("  (n= ", x[["n"]], ") x (p= ", x[["p"]], ") data in K=", x[["K"]],
        " classes of [", sizes, "] observations each\n",
        sep="")

    cat("\n")
    invisible(x)
}
144
predict.dDA <- function(object, newdata, pool = object$pool, ...)
{
    ## Purpose: predict class labels from a "dDA" fit
    ## ----------------------------------------------------------------------
    ## Arguments:
    ##   object : list of class "dDA"; uses its components
    ##            p, K, means, vars, nk, cl0 (and pool as default)
    ##   newdata: test set, coerced via data.matrix(); needs object$p columns
    ##   pool   : TRUE  = pooled variances (linear rule),
    ##            FALSE = class-wise variances (quadratic rule)
    ## Value: predicted class label for every row of 'newdata';
    ##        rows of 'newdata' containing NA get an NA prediction
    ## ----------------------------------------------------------------------
    newdata <- data.matrix(newdata)
    p <- object$p
    K <- object$K
    nk <- object$nk
    ## means and vars  are  (p x K) matrices:
    mu <- object$means
    Vr <- object$vars
    if(p != ncol(newdata))
        stop("test set matrix must have same columns as learning one")
    ## any NA's in test set currently must give NA predictions
    newdata <- na.exclude(newdata)
    nt <- nrow(newdata)
    disc <- matrix(0, nt,K)

    if(pool) { ## LDA
        ## Pooled estimates of variances; sum(nk) counts only the rows
        ## that entered the class-wise estimates
        vp <- rowSums(Vr * rep(nk - 1, each=p)) / (sum(nk) - K)
        ## == apply(Vr, 1, function(z) sum.na((nk-1)*z))/(sum(nk)-K)
        ## zero variances are not acceptable as divisors
        if(any(i0 <- vp == 0)) {
            if(all(i0))
                stop("all variances are 0 -- cannot predict")
            vp[i0] <- 1e-7 * min(vp[!i0])
        }

        ivp <- rep(1/vp, each = nt) # to use in loop

        for(k in seq_len(K)) {
            y <- newdata - rep(mu[,k], each=nt)
            disc[,k] <- rowSums(y*y * ivp)
            ## == apply(newdata, 1, function(z) sum.na((z-mu[,k])^2/vp))
        }
    }
    else { ## QDA
        sum.na <- function(x) sum(x, na.rm=TRUE)
        ## zero - variances are not acceptable later
        if(any(i0 <- Vr == 0)) {
            if(all(i0))
                stop("all variances are 0 -- cannot predict")
            Vr[i0] <- 1e-7 * min(Vr[!i0])
        }

        ## A rowSums()-based variant was tried here before but was marked
        ## as failing ../tests/dDA.R, so the apply() form is kept.
        for(k in seq_len(K)) {
            disc[,k] <-
                apply(newdata,1, function(z) sum((z-mu[,k])^2/Vr[,k])) +
                    sum.na(log(Vr[,k]))
        }
    }

    ## predictions: class with the smallest discriminant score,
    ## shifted back to the original label offset

    pred <- object$cl0 + apply(disc, 1, which.min)
    if(inherits(attr(newdata,"na.action"), "exclude")) {
        ## had missings in `newdata'
        pred <- napredict(omit = attr(newdata,"na.action"), pred)
    } ##        ^^^^^^^^^ typically stats:::napredict.exclude()
    pred
}
207
208