## Calculate Kullback-Leibler projection from sscox objects
##
## Projects a fitted sscox model onto the subspace spanned by the terms
## named in `include` and measures the Kullback-Leibler (KL) divergence
## lost by the reduced model.
##
## Arguments:
##   object   a fitted sscox object; this function reads its model frame
##            `mf`, term structure `terms`, coefficients `c`, `d`, `theta`,
##            bias structure `bias`, `id.basis`, `skip.iter`, and the
##            optional `cnt`, `random`, `partial`/`lab.p` components.
##   include  character vector of term labels (drawn from
##            object$terms$labels and object$lab.p) to keep.
##   ...      ignored.
##
## Value: a list with
##   ratio    kl/kl0.
##   kl       KL divergence of the projection from the full fit.
##   check    (kl+kl1)/kl0. kl0 is the divergence of a flat fit from the
##            full fit; kl1 is the divergence of a flat fit from the
##            projection. NOTE(review): presumably close to 1 for an exact
##            projection -- confirm against the gss documentation.
project.sscox <- function(object,include,...)
{
    ## quadrature = the observed data points, weighted by counts
    qd.pt <- object$mf
    if (is.null(object$cnt)) qd.wt <- rep(1,nrow(qd.pt))
    else qd.wt <- object$cnt
    ## %*% yields an n x 1 matrix; drop it to a vector so it recycles down
    ## the columns of bias$qd.wt below instead of being non-conformable
    if (!is.null(object$random))
        qd.wt <- qd.wt*as.vector(exp(object$random$qd.z%*%object$b))
    bias <- object$bias
    ## evaluate full model
    mesh0 <- predict(object,qd.pt)
    qd.wt <- qd.wt*bias$qd.wt
    ## normalize each bias column so that sum(qd.wt[,i]*mesh0) == 1
    qd.wt <- t(t(qd.wt)/apply(qd.wt*mesh0,2,sum))
    ## extract terms in subspace
    nqd <- nrow(qd.wt)
    nxi <- length(object$id.basis)
    qd.s <- qd.r <- q <- NULL
    theta <- d <- NULL
    ## n0.wk / nq.wk advance for every term, kept or not, so they index
    ## object$d / object$theta correctly; nq counts kept kernels only
    n0.wk <- nq.wk <- nq <- 0
    for (label in object$terms$labels) {
        term.wk <- object$terms[[label]]
        keep <- label %in% include
        x.basis <- object$mf[object$id.basis,term.wk$vlist]
        qd.x <- qd.pt[,term.wk$vlist]
        nphi <- term.wk$nphi
        nrk <- term.wk$nrk
        if (nphi) {
            phi <- term.wk$phi
            for (i in 1:nphi) {
                n0.wk <- n0.wk + 1
                if (!keep) next
                d <- c(d,object$d[n0.wk])
                qd.s <- cbind(qd.s,phi$fun(qd.x,nu=i,env=phi$env))
            }
        }
        if (nrk) {
            rk <- term.wk$rk
            for (i in 1:nrk) {
                nq.wk <- nq.wk + 1
                if (!keep) next
                nq <- nq + 1
                theta <- c(theta,object$theta[nq.wk])
                ## qd.r: nxi x nqd x nq, basis points by quadrature points
                qd.r <- array(c(qd.r,rk$fun(x.basis,qd.x,nu=i,env=rk$env,out=TRUE)),
                              c(nxi,nqd,nq))
                ## q: one column per kept kernel; presumably the kernel at
                ## (basis, same basis) pairs since out=FALSE -- confirm
                q <- cbind(q,rk$fun(x.basis,x.basis,nu=i,env=rk$env,out=FALSE))
            }
        }
    }
    if (!is.null(object$partial)) {
        matx.p <- model.matrix(object$partial$mt,object$mf)[,-1,drop=FALSE]
        matx.p <- scale(matx.p)
        for (label in object$lab.p) {
            n0.wk <- n0.wk + 1
            if (!(label %in% include)) next
            d <- c(d,object$d[n0.wk])
            qd.s <- cbind(qd.s,matx.p[,label])
        }
    }
    ## nothing to project onto: fail clearly instead of handing a
    ## zero-length dimension to the Fortran solver
    if (!nq && is.null(qd.s))
        stop("gss error in project.sscox: include selects no terms")
    ## from here on qd.s is (fixed terms) x (quadrature points)
    if (!is.null(qd.s)) {
        nn <- nxi + ncol(qd.s)
        qd.s <- t(qd.s)
    }
    else nn <- nxi
    ## Run the Fortran Newton solver drkl for the projection onto the rows
    ## of qd.rs (basis functions x quadrature points), starting from
    ## cd.init. drkl receives the design transposed (quadrature points
    ## varying fastest); every caller goes through here so the layout
    ## is always the same.
    fit.drkl <- function(cd.init,qd.rs) {
        nn.wk <- nrow(qd.rs)
        z <- .Fortran("drkl",
                      cd=as.double(cd.init), as.integer(nn.wk),
                      as.double(t(qd.rs)), as.integer(nqd), as.integer(bias$nt),
                      as.double(bias$wt), as.double(t(qd.wt)),
                      mesh=as.double(mesh0), as.double(.Machine$double.eps),
                      as.double(1e-6), as.integer(30), double(nn.wk),
                      double(2*bias$nt*(nqd+1)+nn.wk*(2*nn.wk+4)), info=integer(1),
                      PACKAGE="gss")
        if (z$info==1)
            stop("gss error in project.sscox: Newton iteration diverges")
        if (z$info==2)
            warning("gss warning in project.sscox: Newton iteration fails to converge")
        z
    }
    ## KL divergence of fitted values `mesh` from the full fit mesh0,
    ## combined over bias groups with weights bias$wt
    kl.of <- function(mesh)
        sum(bias$wt*(apply(qd.wt*log(mesh0/mesh)*mesh0,2,sum)+
                     log(apply(qd.wt*mesh,2,sum))))
    ## calculate projection
    ## KL loss as a function of the free log10 kernel weights theta1 (all
    ## but theta[fix], which is held fixed). Writes cd and mesh1 back to
    ## the enclosing frame, so each call starts from the previous solution.
    rkl <- function(theta1=NULL) {
        theta.wk <- 1:nq
        theta.wk[fix] <- theta[fix]
        if (nq > 1) theta.wk[-fix] <- theta1
        qd.rs <- 0
        for (i in 1:nq) qd.rs <- qd.rs + 10^theta.wk[i]*qd.r[,,i]
        z <- fit.drkl(cd,rbind(qd.rs,qd.s))
        assign("cd",z$cd,inherits=TRUE)
        assign("mesh1",z$mesh,inherits=TRUE)
        kl.of(z$mesh)
    }
    cv.wk <- function(theta) cv.scale*rkl(theta)+cv.shift
    if (nq) {
        ## initialization: shift theta so that the kernel part's average
        ## weighted variance over the quadrature matches the fixed part's
        if (is.null(qd.s)) theta.wk <- 0
        else {
            qd.r.wk <- 0
            for (i in 1:nq) qd.r.wk <- qd.r.wk + 10^theta[i]*qd.r[,,i]
            vv.s <- vv.r <- 0
            for (i in 1:bias$nt) {
                ## qd.s and qd.r.wk are basis x quadrature: transpose so
                ## the weights wt.i (one per quadrature point) run down
                ## the rows, giving one weighted moment per basis column
                wt.i <- qd.wt[,i]
                mu.s <- apply(wt.i*t(qd.s),2,sum)/sum(wt.i)
                v.s <- apply(wt.i*t(qd.s^2),2,sum)/sum(wt.i)
                v.s <- v.s - mu.s^2
                mu.r <- apply(wt.i*t(qd.r.wk),2,sum)/sum(wt.i)
                v.r <- apply(wt.i*t(qd.r.wk^2),2,sum)/sum(wt.i)
                v.r <- v.r - mu.r^2
                vv.s <- vv.s + bias$wt[i]*v.s
                vv.r <- vv.r + bias$wt[i]*v.r
            }
            theta.wk <- log10(sum(vv.s)/(nn-nxi)/sum(vv.r)*nxi) / 2
        }
        theta <- theta + theta.wk
        ## hold fixed the kernel with the largest weighted column sum of q
        tmp <- 10^theta*colSums(q)
        fix <- rev(order(tmp))[1]
        ## projection: start from the full fit, with c rescaled to undo
        ## the theta.wk shift applied above
        cd <- c(10^(-theta.wk)*object$c,d)
        mesh1 <- NULL
        if (nq > 1) {
            if (object$skip.iter) kl <- rkl(theta[-fix])
            else {
                if (nq > 2) {
                    ## scale and shift cv
                    tmp <- abs(rkl(theta[-fix]))
                    cv.scale <- 1
                    cv.shift <- 0
                    if (tmp<1 && tmp>10^(-4)) {
                        cv.scale <- 10/tmp
                        cv.shift <- 0
                    }
                    if (tmp<10^(-4)) {
                        cv.scale <- 10^2
                        cv.shift <- 10
                    }
                    zz <- nlm(cv.wk,theta[-fix],stepmax=.5,ndigit=7)
                }
                else {
                    ## one free parameter: search a +-1 window with nlm0
                    ## (presumably a bounded 1-d minimizer), re-centring
                    ## until the optimum lies inside the window
                    the.wk <- theta[-fix]
                    repeat {
                        mn <- the.wk-1
                        mx <- the.wk+1
                        zz <- nlm0(rkl,c(mn,mx))
                        if (min(zz$est-mn,mx-zz$est)>=1e-3) break
                        else the.wk <- zz$est
                    }
                }
                kl <- rkl(zz$est)
            }
        }
        else kl <- rkl()
    }
    else {
        ## no kernel terms kept: project onto the fixed terms (qd.s) alone
        z <- fit.drkl(d,qd.s)
        mesh1 <- z$mesh
        kl <- kl.of(mesh1)
    }
    ## kl0: divergence of a flat fit (mesh == 1) from the full fit
    kl0 <- kl.of(1)
    ## kl1: the same quantity with the projection mesh1 as the reference
    wt.wk <- t(t(qd.wt)/apply(qd.wt*mesh1,2,sum))
    kl1 <- sum(bias$wt*(apply(wt.wk*log(mesh1)*mesh1,2,sum)+
                        log(apply(wt.wk,2,sum))))
    list(ratio=kl/kl0,kl=kl,check=(kl+kl1)/kl0)
}
173