1## Calculate Kullback-Leibler projection from ssllrm objects
## Kullback-Leibler projection of a fitted ssllrm model onto a submodel.
##
## Arguments:
##   object  - fitted "ssllrm" object.  Components read here: mf, term,
##             id.basis, qd.pt, qd.wt, xx.wt, x.dup.ind, xnames, ynames,
##             fit, c, d, theta, b, Random, cnt, skip.iter.
##   include - labels of the model terms spanning the submodel; the
##             response names (object$ynames) are always added to it.
##   ...     - ignored.
##
## Value: list(ratio, kl) where
##   kl    - criterion returned by the compiled "llrmrkl" routine (first
##           element of its wt output) at the projected fit, treated here
##           as the KL distance from the full fit to the submodel;
##   ratio - kl / kl0, with
##           kl0 = sum_i xx.wt[i] * sum_m fit0[i,m]*log(fit0[i,m]/cfit[i,m])
##           and cfit the "constant" fit built from the response marginals
##           (adjusted for random effects when object$b is present).
##
## Errors: stops if the Newton iteration in llrmrkl diverges; warns if it
## fails to converge.
project.ssllrm <- function(object,include,...)
{
    mf <- object$mf
    term <- object$term
    id.basis <- object$id.basis
    qd.pt <- object$qd.pt
    xx.wt <- object$xx.wt
    qd.wt <- object$qd.wt
    ## full model: distinct x rows and fitted values (row i of fit0 pairs
    ## with xx.wt[i])
    x <- mf[!object$x.dup.ind,object$xnames,drop=FALSE]
    fit0 <- object$fit
    ## the response terms are always part of the submodel
    include <- union(object$ynames,include)
    nmesh <- dim(qd.pt)[1]
    nbasis <- length(id.basis)
    nx <- length(xx.wt)
    ## qd.s: retained null-space terms on a (mesh, x, term) grid
    ## qd.r: retained RK terms, each (mesh, basis) or (mesh, basis, x)
    ## q:    retained RK terms at the basis points, one column per term
    qd.s <- NULL
    qd.r <- list()
    theta <- d <- q <- NULL
    ## nu.wk/nq.wk index object$d/object$theta across all terms;
    ## nu/nq count the terms retained in the submodel
    nu.wk <- nu <- nq.wk <- nq <- 0
    for (label in term$labels) {
        nphi <- term[[label]]$nphi
        nrk <- term[[label]]$nrk
        if (!any(label==include)) {
            ## dropped term: only advance the indices into object$d/theta
            nu.wk <- nu.wk+nphi
            nq.wk <- nq.wk+nrk
            next
        }
        vlist <- term[[label]]$vlist
        x.list <- object$xnames[object$xnames%in%vlist]
        y.list <- object$ynames[object$ynames%in%vlist]
        xy.basis <- mf[id.basis,vlist]
        qd.xy <- data.frame(matrix(0,nmesh,length(vlist)))
        names(qd.xy) <- vlist
        qd.xy[,y.list] <- qd.pt[,y.list]
        xx <- if (length(x.list)) x[,x.list,drop=FALSE] else NULL
        if (nphi) {
            phi <- term[[label]]$phi
            for (i in 1:nphi) {
                nu.wk <- nu.wk+1
                nu <- nu+1
                d <- c(d,object$d[nu.wk])
                if (is.null(xx)) {
                    ## term free of x: the same column for every x row
                    s.wk <- phi$fun(qd.xy[,,drop=TRUE],nu=i,env=phi$env)
                    wk <- matrix(s.wk,nmesh,nx)
                } else {
                    wk <- matrix(0,nmesh,nx)
                    for (j in seq_len(nx)) {
                        qd.xy[,x.list] <- xx[rep(j,nmesh),]
                        wk[,j] <- phi$fun(qd.xy,i,phi$env)
                    }
                }
                qd.s <- array(c(qd.s,wk),c(nmesh,nx,nu))
            }
        }
        if (nrk) {
            rk <- term[[label]]$rk
            for (i in 1:nrk) {
                nq.wk <- nq.wk+1
                nq <- nq+1
                theta <- c(theta,object$theta[nq.wk])
                if (is.null(xx)) {
                    qd.r[[nq]] <- rk$fun(qd.xy[,,drop=TRUE],xy.basis,nu=i,
                                         env=rk$env,out=TRUE)
                } else {
                    qd.wk <- array(0,c(nmesh,nbasis,nx))
                    for (j in seq_len(nx)) {
                        qd.xy[,x.list] <- xx[rep(j,nmesh),]
                        qd.wk[,,j] <- rk$fun(qd.xy,xy.basis,i,rk$env,TRUE)
                    }
                    qd.r[[nq]] <- qd.wk
                }
                q <- cbind(q,rk$fun(xy.basis,xy.basis,i,rk$env,out=FALSE))
            }
        }
    }
    nnull <- length(d)
    nxis <- nbasis+nnull
    ## random-effect offset on the (mesh, x) grid
    if (!is.null(object$b)) {
        offset <- apply(object$Random$qd.z,c(1,2),function(x,y)sum(x*y),object$b)
    } else offset <- matrix(0,nmesh,nx)
    fit0.t <- t(fit0)
    ## Run the compiled llrmrkl routine from coefficients cd on the design
    ## qd.rs (nxis columns); stop/warn on its info code, return its output.
    ## NOTE(review): 1e-6 and 30 are presumably the convergence tolerance
    ## and iteration cap of the Newton iteration -- confirm in llrmrkl.
    run.llrmrkl <- function(cd,nxis,qd.rs) {
        z <- .Fortran("llrmrkl",
                      cd=as.double(cd), as.integer(nxis),
                      as.double(qd.rs), as.integer(nmesh), as.integer(nx),
                      as.double(xx.wt), as.double(qd.wt), as.double(fit0.t),
                      as.double(offset), as.double(.Machine$double.eps),
                      wt=double(nmesh*nx), double(nmesh*nx), double(nxis),
                      double(nxis), double(nxis*nxis), double(nxis*nxis),
                      integer(nxis), double(nxis), as.double(1e-6), as.integer(30),
                      info=integer(1), PACKAGE="gss")
        if (z$info==1)
            stop("gss error in project.ssllrm: Newton iteration diverges")
        if (z$info==2)
            warning("gss warning in project.ssllrm: Newton iteration fails to converge")
        z
    }
    ## sum_i 10^th[i] * qd.r[[i]] on the (mesh, basis, x) grid; entries
    ## without an x dimension are recycled over x
    sum.rk <- function(th) {
        out <- array(0,c(nmesh,nbasis,nx))
        for (i in seq_len(nq)) {
            if (length(dim(qd.r[[i]]))==3) out <- out + 10^th[i]*qd.r[[i]]
            else out <- out + as.vector(10^th[i]*qd.r[[i]])
        }
        out
    }
    ## Projection criterion for the free log10 RK weights theta1 (all but
    ## theta[fix], which stays fixed); updates cd in the enclosing frame so
    ## that the next call starts from the latest coefficients.
    rkl <- function(theta1=NULL) {
        theta.all <- theta
        if (nq-1) theta.all[-fix] <- theta1
        ## the RK part must use the candidate weights, not the fixed theta
        qd.rs <- aperm(sum.rk(theta.all),c(1,3,2))
        qd.rs <- array(c(qd.rs,qd.s),c(nmesh,nx,nxis))
        qd.rs <- aperm(qd.rs,c(1,3,2))
        z <- run.llrmrkl(cd,nxis,qd.rs)
        assign("cd",z$cd,inherits=TRUE)
        z$wt[1]
    }
    cv.wk <- function(th) cv.scale*rkl(th)+cv.shift
    if (nq) {
        ## initial shift of theta balancing the variance of the RK part
        ## against that of the null-space part under the full fit
        if (!nnull) theta.wk <- 0
        else {
            qd.r.wk <- sum.rk(theta)
            v.s <- v.r <- 0
            for (i in seq_len(nx)) {
                ## per-component moments: qd.s[,i,,drop=FALSE] is
                ## (mesh, 1, nnull), so the component margin is 3 (this
                ## matches margin 2 = basis for qd.r.wk and the /nnull below)
                s.i <- qd.s[,i,,drop=FALSE]
                mu.s <- apply(fit0[i,]*s.i,3,sum)
                v.s.wk <- apply(fit0[i,]*s.i^2,3,sum)-mu.s^2
                r.i <- qd.r.wk[,,i,drop=FALSE]
                mu.r <- apply(fit0[i,]*r.i,2,sum)
                v.r.wk <- apply(fit0[i,]*r.i^2,2,sum)-mu.r^2
                v.s <- v.s + xx.wt[i]*v.s.wk
                v.r <- v.r + xx.wt[i]*v.r.wk
            }
            theta.wk <- log10(sum(v.s)/nnull/sum(v.r)*nbasis) / 2
        }
        theta <- theta + theta.wk
        ## keep fixed the RK term with the largest 10^theta[i]*sum(q[,i])
        tmp <- 10^theta*colSums(q)
        fix <- rev(order(tmp))[1]
        ## projection, warm-started from the full-model coefficients
        cd <- c(10^(-theta.wk)*object$c,d)
        if (nq-1) {
            if (object$skip.iter) kl <- rkl(theta[-fix])
            else {
                if (nq-2) {
                    ## scale and shift the criterion before handing it to nlm
                    kl.init <- abs(rkl(theta[-fix]))
                    cv.scale <- 1
                    cv.shift <- 0
                    if (kl.init<1&&kl.init>10^(-4)) {
                        cv.scale <- 10/kl.init
                        cv.shift <- 0
                    }
                    if (kl.init<10^(-4)) {
                        cv.scale <- 10^2
                        cv.shift <- 10
                    }
                    zz <- nlm(cv.wk,theta[-fix],stepmax=.5,ndigit=7)
                } else {
                    ## single free weight: bounded search on [est-1, est+1],
                    ## re-centred until the optimum is off the bounds
                    the.wk <- theta[-fix]
                    repeat {
                        mn <- the.wk-1
                        mx <- the.wk+1
                        zz <- nlm0(rkl,c(mn,mx))
                        if (min(zz$est-mn,mx-zz$est)>=1e-3) break
                        else the.wk <- zz$est
                    }
                }
                kl <- rkl(zz$est)
            }
        } else kl <- rkl()
    } else {
        ## no RK terms retained: project onto the null-space terms only
        kl <- run.llrmrkl(d,nnull,aperm(qd.s,c(1,3,2)))$wt[1]
    }
    ## cfit: "constant" fit built from the response marginals
    cfit <- matrix(1,nx,nmesh)
    if (!is.null(object$b)) {
        qd.z <- object$Random$qd.z
        nz <- object$Random$sigma$env$nz
        id.wk <- 0
    }
    for (ylab in object$ynames) {
        y <- mf[,ylab]
        lvl <- levels(y)
        ## marginal level proportions (weighted by cnt when present); note
        ## that lvl must not be reused as a loop variable here
        if (is.null(object$cnt)) wk <- as.vector(table(y))
        else wk <- vapply(lvl,function(l)sum(object$cnt[y==l]),numeric(1),
                          USE.NAMES=FALSE)
        wk <- wk/sum(wk)
        nlvl <- length(wk)
        if (is.null(object$b)) {
            for (j in seq_len(nlvl)) {
                id <- which(qd.pt[,ylab]==lvl[j])
                cfit[,id] <- cfit[,id]*wk[j]
            }
        } else {
            ## first mesh point of each level carries that level's offset
            id <- match(lvl,as.character(qd.pt[,ylab]))
            ## this response's block of nz*(nlvl-1) random coefficients
            id.b <- id.wk+seq_len(nz*(nlvl-1))
            offset.y <- apply(qd.z[id,,id.b,drop=FALSE],c(1,2),
                              function(x,y)sum(x*y),object$b[id.b])
            id.wk <- id.wk + nz*(nlvl-1)
            ## Newton iteration for level intercepts eta making the
            ## xx.wt-weighted average of p match the marginals wk
            eta <- log(wk[-nlvl]/wk[nlvl])
            repeat {
                p <- exp(c(eta,0)+offset.y)
                p <- t(p)/apply(p,2,sum)
                u <- (apply(p*xx.wt,2,sum)-wk)[-nlvl]
                w <- 0
                for (i in seq_len(nx)) {
                    w <- w + xx.wt[i]*(diag(p[i,])-outer(p[i,],p[i,]))[-nlvl,-nlvl]
                }
                eta.new <- eta-solve(w,u)
                if (max(abs(eta-eta.new)/(1+abs(eta)))<1e-7) break
                eta <- eta.new
            }
            p <- exp(c(eta,0)+offset.y)
            p <- t(p)/apply(p,2,sum)
            for (j in seq_len(nlvl)) {
                id <- which(qd.pt[,ylab]==lvl[j])
                cfit[,id] <- cfit[,id]*p[,j]
            }
        }
    }
    ## KL distance from the full fit to the constant fit
    kl0 <- 0
    for (i in seq_len(nx)) {
        kl0 <- kl0 + xx.wt[i]*sum(log(fit0[i,]/cfit[i,])*fit0[i,])
    }
    list(ratio=kl/kl0,kl=kl)
}
263