1# These functions are
2# Copyright (C) 1998-2021 T.W. Yee, University of Auckland.
3# All rights reserved.
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Private environment holding the state of the current smart-prediction
# session.  The functions in this file store up to four objects in it:
# ".smart.prediction", ".smart.prediction.counter",
# ".smart.prediction.mode" and ".max.smart".
smartpredenv <- new.env()
26
27
# Query, or test, the current smart-prediction mode.
#
# Arguments:
#   mode.arg  NULL (the default), or one of "neutral", "read", "write".
#
# Value:
#   With no argument: the current mode as a character string; "neutral"
#   when no ".smart.prediction" object is stored in 'smartpredenv'.
#   With 'mode.arg': TRUE if the current mode equals 'mode.arg', else FALSE.
#
# An error is raised if 'mode.arg' is supplied but is not a single
# recognised mode name.
smart.mode.is <- function(mode.arg = NULL) {
  asking <- length(mode.arg) > 0

  if (asking) {
    # Validate first.  The length/NA guards give the intended message for
    # vector or NA input instead of an unrelated error from `&&` / `if`.
    recognised <- length(mode.arg) == 1L && !is.na(mode.arg) &&
                  mode.arg %in% c("neutral", "read", "write")
    if (!recognised)
      stop("argument \"mode.arg\" must be one of",
           " \"neutral\", \"read\" or \"write\"")
  }

  # inherits = FALSE: only objects stored directly in 'smartpredenv' count.
  # Otherwise a stray object of the same name in an enclosing environment
  # (e.g. the user's workspace) would be mistaken for an active session.
  current <- if (exists(".smart.prediction", envir = smartpredenv,
                        inherits = FALSE)) {
    get(".smart.prediction.mode", envir = smartpredenv, inherits = FALSE)
  } else {
    "neutral"
  }

  if (asking) current == mode.arg else current
}
49
50
# Start a smart-prediction session in "write" or "read" mode.
#
# Arguments:
#   mode.arg          "write" or "read"; anything else is an error.
#   smart.prediction  list of previously recorded entries; only used when
#                     mode.arg is "read".
#   max.smart         number of empty slots preallocated in "write" mode;
#                     also stored as ".max.smart".
#
# Any existing session is cleared first via wrapup.smart().  If the
# resulting list has length zero (e.g. "read" mode with no stored
# entries), nothing is stored in 'smartpredenv'.
#
# Value (invisible): the list stored as ".smart.prediction", or NULL when
# nothing was stored.
setup.smart <- function(mode.arg, smart.prediction = NULL,
                        max.smart = 30) {
  actual <- if (mode.arg == "write") {
    vector("list", max.smart)
  } else if (mode.arg == "read") {
    smart.prediction
  } else {
    stop("value of 'mode.arg' unrecognized")
  }

  wrapup.smart()  # make sure

  if (length(actual)) {
    assign(".smart.prediction.counter", 0, envir = smartpredenv)
    assign(".smart.prediction.mode", mode.arg, envir = smartpredenv)
    assign(".max.smart", max.smart, envir = smartpredenv)
    # Written once, and last: smart.mode.is() tests for this object before
    # reading ".smart.prediction.mode", so the mode must already be there.
    assign(".smart.prediction", actual, envir = smartpredenv)
  }
}
69
70
# End the current smart-prediction session by removing the session objects
# from 'smartpredenv'.  Safe to call when no session is active.
#
# Value: NULL, invisibly.
wrapup.smart <- function() {
  session.objects <- c(".smart.prediction",
                       ".smart.prediction.counter",
                       ".smart.prediction.mode",
                       ".max.smart")
  # Only look at objects stored directly in 'smartpredenv'.  Testing with
  # an inheriting exists() could report an object that lives in an
  # enclosing environment, and rm() would then warn "object not found".
  present <- intersect(session.objects,
                       ls(smartpredenv, all.names = TRUE))
  if (length(present))
    rm(list = present, envir = smartpredenv)
  invisible(NULL)
}
81
82
# Return the entries recorded so far in the current session.
#
# Value: a list holding the first ".smart.prediction.counter" elements of
# ".smart.prediction" (the unused preallocated slots are dropped), or NULL
# if the counter is zero.
get.smart.prediction <- function() {
  n.used <- get(".smart.prediction.counter", envir = smartpredenv)
  if (!(n.used > 0))
    return(NULL)

  recorded <- get(".smart.prediction", envir = smartpredenv)
  # Keep exactly the used entries with one subset operation.  Deleting the
  # trailing slots one by one up to ".max.smart" would leave behind any
  # elements stored beyond that index.
  recorded[seq_len(min(n.used, length(recorded)))]
}
98
99
# Record one entry in the current session.
#
# Arguments:
#   smart  the object (in this file, a list of parameters) to store in the
#          next free position of ".smart.prediction".
#
# The counter is advanced by one.  When the new position exceeds
# ".max.smart", the list is extended by 10 empty slots and ".max.smart"
# is increased to match.
put.smart <- function(smart) {
  capacity <- get(".max.smart", envir = smartpredenv)
  position <- get(".smart.prediction.counter", envir = smartpredenv) + 1
  entries <- get(".smart.prediction", envir = smartpredenv)

  if (position > capacity) {
    inc.smart <- 10  # can change inc.smart
    entries <- c(entries, vector("list", inc.smart))
    assign(".max.smart", capacity + inc.smart, envir = smartpredenv)
  }

  entries[[position]] <- smart
  assign(".smart.prediction", entries, envir = smartpredenv)
  assign(".smart.prediction.counter", position, envir = smartpredenv)
}
121
122
# Fetch the next recorded entry of the current session.
#
# The counter is advanced (and stored) before the entry is extracted, so
# it moves on even if the extraction itself fails.
#
# Value: element number (counter + 1) of ".smart.prediction".
get.smart <- function() {
  entries <- get(".smart.prediction", envir = smartpredenv)
  position <- get(".smart.prediction.counter", envir = smartpredenv) + 1
  assign(".smart.prediction.counter", position, envir = smartpredenv)
  entries[[position]]
}
133
134
# Expression that the smart functions in this file evaluate with
# eval(smart.expression) when the session is in "read" mode.  It is
# evaluated in the frame of the function that calls eval(), and relies on
# a variable 'x' being defined there.
smart.expression <- expression({


  # Fetch the next recorded parameter list; get.smart() advances the counter.
  smart  <- get.smart()
  # Switch to "neutral" so that the function re-invoked below takes its
  # ordinary computational branch rather than evaluating this expression
  # again.
  assign(".smart.prediction.mode", "neutral", envir = smartpredenv)

  # The first element of the character form of the stored call is used as
  # the name of the function to call.
  # NOTE(review): presumably this requires the recorded call to use a plain
  # function name (not pkg::fun) -- confirm against callers.
  .smart.match.call <- as.character(smart$match.call)
  smart$match.call <- NULL  # Kill it off for the do.call

  # Call that function on the current 'x' with the stored arguments.
  ans.smart <- do.call(.smart.match.call[1], c(list(x=x), smart))
  # Back to "read" mode for the next smart function.
  # NOTE(review): this line is not reached if the do.call() above signals
  # an error, in which case the mode is left at "neutral".
  assign(".smart.prediction.mode", "read", envir = smartpredenv)

  ans.smart
})
149
150
151
152
# Test whether an object is "smart".
#
# Arguments:
#   object  a function, an S4 object, or a list-like object.
#
# Value:
#   * function: its "smart" attribute if that is logical, else FALSE.
#   * S4 object: FALSE if it has no 'smart.prediction' slot.  Otherwise,
#     if the slot holds a single element whose 'smart.arg' component is
#     logical, that value; else TRUE.
#   * list-like object: the same rule applied to its 'smart.prediction'
#     component; FALSE when there is no component of that name.
#   * anything else (e.g. an atomic vector): FALSE.
is.smart <- function(object) {
  if (is.function(object)) {
    flag <- attr(object, "smart")
    return(if (is.logical(flag)) flag else FALSE)
  }

  if (isS4(object)) {
    # Check the slot exists before reading it; `@` on a missing slot is an
    # error, not NULL.
    if (!("smart.prediction" %in% slotNames(object)))
      return(FALSE)
    stored <- object@smart.prediction
    present <- TRUE
  } else if (is.recursive(object)) {
    stored <- object$smart.prediction
    present <- any(names(object) == "smart.prediction")
  } else {
    # `$` is invalid for atomic vectors; such objects are never smart.
    return(FALSE)
  }

  # Only apply `$` to the stored value when it is list-like.
  if ((is.list(stored) || is.environment(stored)) &&
      length(stored) == 1 &&
      is.logical(stored$smart.arg)) {
    stored$smart.arg
  } else {
    present
  }
}
170
171
172
173
174
175
176
177
178
179
180
181
# Smart B-spline basis.  The result carries class c("bs", "basis", "matrix").
#
# Arguments:
#   x               numeric predictor.  Missing values are allowed and give
#                   rows of NA in the result.
#   df              degrees of freedom; used to choose the number of
#                   interior knots when 'df' is supplied and 'knots' is not.
#   knots           interior knots.
#   degree          degree of the piecewise polynomial (coerced to integer).
#   intercept       if FALSE the first basis column is dropped.
#   Boundary.knots  boundary knots.  The default, range(x), is evaluated
#                   lazily: it is first needed after missing values have
#                   been removed from 'x'.
#
# Value: the basis matrix, with attributes 'degree', 'knots',
# 'Boundary.knots', 'intercept' and 'Aknots'.
#
# Smart prediction: in "read" mode the function is re-invoked (through
# smart.expression) with the parameters recorded earlier; in "write" mode
# the parameters actually used are recorded with put.smart().
 sm.bs <-
  function (x, df = NULL, knots = NULL, degree = 3, intercept = FALSE,
            Boundary.knots = range(x)) {
  x <- x  # Evaluate x; needed for nested calls, e.g., sm.bs(sm.scale(x)).
  if (smart.mode.is("read")) {
    return(eval(smart.expression))
  }

  # Remember the names, then work on the non-missing values only.
  nx <- names(x)
  x <- as.vector(x)
  nax <- is.na(x)
  if (nas <- any(nax))
    x <- x[!nax]
  # Points outside the boundary knots can only occur when the boundary
  # knots were supplied by the caller ('ol' = left of, 'or' = right of).
  if (!missing(Boundary.knots)) {
    Boundary.knots <- sort(Boundary.knots)
    outside <- (ol <- x < Boundary.knots[1]) | (or <- x >
        Boundary.knots[2L])
  } else outside <- FALSE
  # Spline order = degree + 1.
  ord <- 1 + (degree <- as.integer(degree))
  if (ord <= 1)
    stop("'degree' must be integer >= 1")
  # 'df' given without 'knots': place the interior knots at equally spaced
  # quantiles of the x values lying inside the boundary knots.
  if (!missing(df) && missing(knots)) {
    nIknots <- df - ord + (1 - intercept)
    if (nIknots < 0) {
      nIknots <- 0
      warning("'df' was too small; have used  ", ord - (1 - intercept))
    }
    knots <- if (nIknots > 0) {
      knots <- seq(from = 0, to = 1, length = nIknots +
          2)[-c(1, nIknots + 2)]
      stats::quantile(x[!outside], knots)
    }
  }
  # Full knot sequence: each boundary knot repeated 'ord' times plus the
  # interior knots.
  Aknots <- sort(c(rep(Boundary.knots, ord), knots))
  if (any(outside)) {
    warning("some 'x' values beyond boundary knots may ",
            "cause ill-conditioned bases")
    # Outside the boundary the basis is continued as a polynomial built
    # from the derivatives (orders 0..degree) of the basis functions at the
    # boundary knot, each divided by the matching factorial (gamma(1:ord)).
    derivs <- 0:degree
    scalef <- gamma(1L:ord)
    basis <- array(0, c(length(x), length(Aknots) - degree - 1L))
      if (any(ol)) {
        k.pivot <- Boundary.knots[1L]
        xl <- cbind(1, outer(x[ol] - k.pivot, 1L:degree, "^"))
        tt <- splines::splineDesign(Aknots, rep(k.pivot, ord), ord, derivs)
        basis[ol, ] <- xl %*% (tt/scalef)
      }
      if (any(or)) {
        k.pivot <- Boundary.knots[2L]
        xr <- cbind(1, outer(x[or] - k.pivot, 1L:degree, "^"))
        tt <- splines::splineDesign(Aknots, rep(k.pivot, ord), ord, derivs)
        basis[or, ] <- xr %*% (tt/scalef)
      }
      if (any(inside <- !outside))
        basis[inside, ] <- splines::splineDesign(Aknots, x[inside], ord)
  } else basis <- splines::splineDesign(Aknots, x, ord)
  if (!intercept)
    basis <- basis[, -1L, drop = FALSE]
  n.col <- ncol(basis)
  # Put rows of NA back where 'x' was missing.
  if (nas) {
    nmat <- matrix(NA_real_, length(nax), n.col)
    nmat[!nax, ] <- basis
    basis <- nmat
  }
  dimnames(basis) <- list(nx, 1L:n.col)
  a <- list(degree = degree,
            knots = if (is.null(knots)) numeric(0L) else knots,
            Boundary.knots = Boundary.knots,
            intercept = intercept,
            Aknots = Aknots)
  attributes(basis) <- c(attributes(basis), a)
  class(basis) <- c("bs", "basis", "matrix")

  # Record the knots and boundary knots actually used, together with the
  # call, so the same basis can be rebuilt in "read" mode.
  if (smart.mode.is("write"))
    put.smart(list(df = df,
                   knots = knots,
                   degree = degree,
                   intercept = intercept,
                   Boundary.knots = Boundary.knots,
                   match.call = match.call()))

  basis
}
# Flag read by is.smart().
attr( sm.bs, "smart") <- TRUE
265
266
267
268
269
270
# Smart cubic spline basis with second-derivative constraints at the
# boundary knots.  The result carries class c("ns", "basis", "matrix").
#
# Arguments:
#   x               numeric predictor.  Missing values are allowed and give
#                   rows of NA in the result.
#   df              degrees of freedom; used to choose the number of
#                   interior knots when 'df' is supplied and 'knots' is not.
#   knots           interior knots.
#   intercept       if FALSE the first column of the underlying B-spline
#                   basis is dropped before the constraints are applied.
#   Boundary.knots  boundary knots.  The default, range(x), is evaluated
#                   lazily: it is first needed after missing values have
#                   been removed from 'x'.
#
# Value: the basis matrix, with attributes 'degree' (always 3), 'knots',
# 'Boundary.knots', 'intercept' and 'Aknots'.
#
# Smart prediction: in "read" mode the function is re-invoked (through
# smart.expression) with the parameters recorded earlier; in "write" mode
# the parameters actually used are recorded with put.smart().
 sm.ns <-
  function (x, df = NULL, knots = NULL, intercept = FALSE,
            Boundary.knots = range(x)) {
  x <- x  # Evaluate x; needed for nested calls, e.g., sm.bs(sm.scale(x)).
  if (smart.mode.is("read")) {
    return(eval(smart.expression))
  }

  # Remember the names, then work on the non-missing values only.
  nx <- names(x)
  x <- as.vector(x)
  nax <- is.na(x)
  if (nas <- any(nax))
    x <- x[!nax]
  # Points outside the boundary knots can only occur when the boundary
  # knots were supplied by the caller ('ol' = left of, 'or' = right of).
  if (!missing(Boundary.knots)) {
    Boundary.knots <- sort(Boundary.knots)
    outside <- (ol <- x < Boundary.knots[1L]) | (or <- x >
        Boundary.knots[2L])
  } else outside <- FALSE
  # 'df' given without 'knots': place the interior knots at equally spaced
  # quantiles of the x values lying inside the boundary knots.  Otherwise
  # the number of interior knots is simply length(knots).
  if (!missing(df) && missing(knots)) {
    nIknots <- df - 1 - intercept
    if (nIknots < 0) {
      nIknots <- 0
      warning("'df' was too small; have used ", 1 + intercept)
    }
    knots <- if (nIknots > 0) {
      knots <- seq.int(0, 1, length.out = nIknots + 2L)[-c(1L, nIknots + 2L)]
      stats::quantile(x[!outside], knots)
    }
  } else nIknots <- length(knots)
  # Full knot sequence for an order-4 (cubic) B-spline basis.
  Aknots <- sort(c(rep(Boundary.knots, 4), knots))
  if (any(outside)) {
    # Outside the boundary the basis is continued linearly, using the value
    # and first derivative of each basis function at the boundary knot.
    basis <- array(0, c(length(x), nIknots + 4L))
    if (any(ol)) {
      k.pivot <- Boundary.knots[1L]
      xl <- cbind(1, x[ol] - k.pivot)
      tt <- splines::splineDesign(Aknots, rep(k.pivot, 2L), 4, c(0, 1))
      basis[ol, ] <- xl %*% tt
    }
    if (any(or)) {
        k.pivot <- Boundary.knots[2L]
        xr <- cbind(1, x[or] - k.pivot)
        tt <- splines::splineDesign(Aknots, rep(k.pivot, 2L), 4, c(0, 1))
        basis[or, ] <- xr %*% tt
      }
      if (any(inside <- !outside))
        basis[inside, ] <- splines::splineDesign(Aknots, x[inside], 4)
    } else basis <- splines::splineDesign(Aknots, x, 4)
  # Second derivatives of the B-spline basis functions at the two boundary
  # knots: these are the two constraints.
  const <- splines::splineDesign(Aknots, Boundary.knots, 4, c(2, 2))
  if (!intercept) {
    const <- const[, -1, drop = FALSE]
    basis <- basis[, -1, drop = FALSE]
  }
  # Rotate the basis with the Q factor of t(const) and drop the first two
  # columns, i.e. the directions spanned by the two constraints.
  qr.const <- qr(t(const))
  basis <- as.matrix((t(qr.qty(qr.const, t(basis))))[, -(1L:2L),
      drop = FALSE])
  n.col <- ncol(basis)
  # Put rows of NA back where 'x' was missing.
  if (nas) {
    nmat <- matrix(NA_real_, length(nax), n.col)
    nmat[!nax, ] <- basis
    basis <- nmat
  }
  dimnames(basis) <- list(nx, 1L:n.col)
  a <- list(degree = 3,
            knots = if (is.null(knots)) numeric(0) else knots,
            Boundary.knots = Boundary.knots,
            intercept = intercept,
            Aknots = Aknots)
  attributes(basis) <- c(attributes(basis), a)
  class(basis) <- c("ns", "basis", "matrix")

  # Record the knots and boundary knots actually used, together with the
  # call, so the same basis can be rebuilt in "read" mode.
  if (smart.mode.is("write"))
    put.smart(list(df = df,
                   knots = knots,
                   intercept = intercept,
                   Boundary.knots = Boundary.knots,
                   match.call = match.call()))

  basis
}
# Flag read by is.smart().
attr( sm.ns, "smart") <- TRUE
351
352
353
354
355
356
357
358
# Smart orthogonal-polynomial basis.  The result carries class
# c("poly", "matrix").
#
# Arguments:
#   x       numeric vector (or a matrix; see below).
#   ...     either a single length-one value, taken as 'degree', or
#           further variables, in which case the work is handed to polym().
#   degree  degree of the polynomial; must be at least 1.
#   coefs   NULL, or a list with components 'alpha' and 'norm2' from a
#           previous fit, used to evaluate the same basis at new 'x'.
#   raw     if TRUE, return raw powers of 'x' instead of an orthogonal
#           basis.
#
# Value: a matrix with 'degree' columns and attributes 'degree' and, for
# the orthogonal case, 'coefs'.
#
# Smart prediction: only the univariate, non-raw path is recorded (in
# "write" mode, via put.smart()) and replayed (in "read" mode, via
# get.smart()).  The paths that return early through polym() -- several
# variables in '...', or a matrix 'x' -- never record anything, so they
# must not consume a recorded entry either; doing so would shift every
# later smart function onto the wrong entry.
sm.poly <-
  function (x, ..., degree = 1, coefs = NULL, raw = FALSE) {
    x <- x  # Evaluate x; needed for nested calls, e.g., sm.bs(sm.scale(x)).

    dots <- list(...)
    nd <- length(dots)
    # A single length-one extra argument is read as the degree.
    dots.give.degree <- nd == 1L && length(dots[[1L]]) == 1L
    # TRUE when the call will return through polym() without recording.
    uses.polym <- (nd > 0L && !dots.give.degree) || is.matrix(x)

    if (!raw && !uses.polym && smart.mode.is("read")) {
      smart <- get.smart()
      degree <- smart$degree
      coefs  <- smart$coefs
      raw    <- smart$raw
    }

    if (nd) {
      if (dots.give.degree) {
        degree <- dots[[1L]]
      } else {
        return(polym(x, ..., degree = degree, raw = raw))
      }
    }
    if (is.matrix(x)) {
      m <- unclass(as.data.frame(cbind(x, ...)))
      return(do.call("polym", c(m, degree = degree, raw = raw)))
    }
    if (degree < 1)
      stop("'degree' must be at least 1")

    # Checked when fitting only; in "read" mode 'x' may legitimately have
    # fewer points than the degree.
    if (smart.mode.is("write") || smart.mode.is("neutral")) {
      if (degree >= length(x))
        stop("degree must be less than number of points")
    }

    if (anyNA(x))
      stop("missing values are not allowed in 'poly'")
    n <- degree + 1
    if (raw) {
      if (degree >= length(unique(x)))
        stop("'degree' must be less than number of unique points")
      Z <- outer(x, 1L:degree, "^")
      colnames(Z) <- 1L:degree
      attr(Z, "degree") <- 1L:degree
      class(Z) <- c("poly", "matrix")
      return(Z)
    }
    if (is.null(coefs)) {
      # Fit: orthogonalise the centred powers of x with a QR decomposition.
      if (degree >= length(unique(x)))
        stop("'degree' must be less than number of unique points")
      xbar <- mean(x)
      x <- x - xbar
      X <- outer(x, seq_len(n) - 1, "^")
      QR <- qr(X)

      if (QR$rank < degree)
        stop("'degree' must be less than number of unique points")

      z <- QR$qr
      z <- z * (row(z) == col(z))
      # From here on 'raw' holds the unnormalised orthogonal columns, not
      # the logical flag.
      raw <- qr.qy(QR, z)
      norm2 <- colSums(raw^2)
      alpha <- (colSums(x * raw^2)/norm2 + xbar)[1L:degree]
      Z <- raw/rep(sqrt(norm2), each = length(x))
      colnames(Z) <- 1L:n - 1L
      Z <- Z[, -1, drop = FALSE]
      attr(Z, "degree") <- 1:degree
      attr(Z, "coefs") <- list(alpha = alpha, norm2 = c(1, norm2))
      class(Z) <- c("poly", "matrix")
    } else {
      # Predict: rebuild the columns from the stored coefficients with the
      # recurrence in 'alpha' and 'norm2'.
      alpha <- coefs$alpha
      norm2 <- coefs$norm2
      Z <- matrix(, length(x), n)
      Z[, 1] <- 1
      Z[, 2] <- x - alpha[1L]
      if (degree > 1)
        for (i in 2:degree) Z[, i + 1] <- (x - alpha[i]) *
            Z[, i] - (norm2[i + 1]/norm2[i]) * Z[, i - 1]
      Z <- Z/rep(sqrt(norm2[-1L]), each = length(x))
      colnames(Z) <- 0:degree
      Z <- Z[, -1, drop = FALSE]
      attr(Z, "degree") <- 1L:degree
      attr(Z, "coefs") <- list(alpha = alpha, norm2 = norm2)
      class(Z) <- c("poly", "matrix")
    }

    if (smart.mode.is("write"))
      put.smart(list(degree = degree,
                     coefs = attr(Z, "coefs"),
                     raw = FALSE,  # raw is changed above
                     match.call = match.call()))

    Z
}
# Flag read by is.smart().
attr(sm.poly, "smart") <- TRUE
451
452
453
454
455
456
# Default method for sm.scale(): column-wise centring and scaling of a
# matrix, with smart-prediction support.
#
# Arguments:
#   x       object coerced with as.matrix().
#   center  TRUE (subtract column means, NAs ignored), FALSE (no centring),
#           or a numeric vector with one value per column.
#   scale   TRUE (divide each column by sqrt(sum(v^2) / max(1, n - 1)),
#           computed from its non-missing, already centred values),
#           FALSE (no scaling), or a numeric vector with one value per
#           column.
#
# Value: the transformed matrix; attributes "scaled:center" and
# "scaled:scale" hold the values used whenever they are numeric.
#
# Smart prediction: in "read" mode the function is re-invoked (through
# smart.expression) with the recorded centre and scale; in "write" mode
# those values are recorded with put.smart().
sm.scale.default <- function (x, center = TRUE, scale = TRUE) {
  x <- as.matrix(x)

  if (smart.mode.is("read")) {
    return(eval(smart.expression))
  }

  n.columns <- ncol(x)

  # Centring: user-supplied values are validated, TRUE means column means.
  if (!is.logical(center)) {
    if (!(is.numeric(center) && (length(center) == n.columns)))
      stop("length of 'center' must equal the number of columns of 'x'")
    x <- sweep(x, 2L, center, check.margin = FALSE)
  } else if (center) {
    center <- colMeans(x, na.rm = TRUE)
    x <- sweep(x, 2L, center, check.margin = FALSE)
  }

  # Scaling: same pattern, with a root-mean-square as the default.
  if (!is.logical(scale)) {
    if (!(is.numeric(scale) && length(scale) == n.columns))
      stop("length of 'scale' must equal the number of columns of 'x'")
    x <- sweep(x, 2L, scale, "/", check.margin = FALSE)
  } else if (scale) {
    root.mean.square <- function(column) {
      observed <- column[!is.na(column)]
      sqrt(sum(observed^2) / max(1, length(observed) - 1L))
    }
    scale <- apply(x, 2L, root.mean.square)
    x <- sweep(x, 2L, scale, "/", check.margin = FALSE)
  }

  # Attach whatever numeric values were applied.
  if (is.numeric(center))
    attr(x, "scaled:center") <- center
  if (is.numeric(scale))
    attr(x, "scaled:scale") <- scale

  if (smart.mode.is("write")) {
    put.smart(list(center = center, scale = scale,
                   match.call = match.call()))
  }

  x
}
# Flag read by is.smart().
attr(sm.scale.default, "smart") <- TRUE
498
499
500
501
502
503
# S3 generic for smart centring and scaling; dispatches on the class of
# 'x'.  The default method in this file is sm.scale.default().
 sm.scale <- function (x, center = TRUE, scale = TRUE)
  UseMethod("sm.scale")



# Flag read by is.smart().
attr(sm.scale, "smart") <- TRUE
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
# Smart minimum (first variant): the recorded value is fetched directly
# with get.smart() rather than by re-invoking the function.
#
# Value: min(x) in "neutral" and "write" mode (recorded as 'minx' in
# "write" mode); the recorded 'minx' in "read" mode.
sm.min1 <- function(x) {
  force(x)  # Evaluate x; needed for nested calls, e.g., sm.bs(sm.scale(x)).
  smallest <- min(x)

  if (smart.mode.is("read")) {
    # The recorded minimum replaces the one just computed.
    return(get.smart()$minx)
  }
  if (smart.mode.is("write")) {
    put.smart(list(minx = smallest))
  }
  smallest
}
# Flag read by is.smart().
attr(sm.min1, "smart") <- TRUE
536
537
538
539
540
# Smart minimum (second variant): in "read" mode the function is
# re-invoked through smart.expression with the recorded '.minx'.
#
# Arguments:
#   x      numeric vector.
#   .minx  the value to return; defaults to min(x).
#
# Value: '.minx'.  In "write" mode it is also recorded, together with the
# call.
sm.min2 <- function(x, .minx = min(x)) {
  force(x)  # Evaluate x; needed for nested calls, e.g., sm.bs(sm.scale(x)).

  if (smart.mode.is("read")) {  # Use recursion
    return(eval(smart.expression))
  }

  if (smart.mode.is("write")) {
    record <- list(.minx = .minx, match.call = match.call())
    put.smart(record)
  }
  .minx
}
# Flag read by is.smart().
attr(sm.min2, "smart") <- TRUE
551
552
553
554
555
556
557
558
# Smart centring and scaling of a vector (first variant): the recorded
# values are fetched directly with get.smart().
#
# Arguments:
#   x       a vector (anything else is an error).
#   center  TRUE (use mean(x)), FALSE (use 0), or a value to subtract.
#   scale   TRUE (use sqrt(var(x))), FALSE (use 1), or a value to divide by.
#
# Value: (x - center) / scale.  In "write" mode the centre and scale used
# are recorded as 'Center' and 'Scale'; in "read" mode the recorded values
# are applied to 'x' instead.
sm.scale1 <- function(x, center = TRUE, scale = TRUE) {
  force(x)  # Evaluate x; needed for nested calls, e.g., sm.bs(sm.scale(x)).
  if (!is.vector(x)) {
    stop("argument 'x' must be a vector")
  }

  if (smart.mode.is("read")) {
    recorded <- get.smart()
    return((x - recorded$Center) / recorded$Scale)
  }

  # Turn logical flags into the numbers actually applied.
  if (is.logical(center)) {
    center <- if (center) mean(x) else 0
  }
  if (is.logical(scale)) {
    scale <- if (scale) sqrt(var(x)) else 1
  }

  if (smart.mode.is("write")) {
    put.smart(list(Center = center, Scale = scale))
  }
  (x - center) / scale
}
# Flag read by is.smart().
attr(sm.scale1, "smart") <- TRUE
577
578
579
# Smart centring and scaling of a vector (second variant): in "read" mode
# the function is re-invoked through smart.expression with the recorded
# centre and scale.
#
# Arguments:
#   x       a vector (anything else is an error).
#   center  TRUE (use mean(x)), FALSE (use 0), or a value to subtract.
#   scale   TRUE (use sqrt(var(x))), FALSE (use 1), or a value to divide by.
#
# Value: (x - center) / scale.  In "write" mode the centre and scale used
# are recorded, together with the call.
sm.scale2 <- function(x, center = TRUE, scale = TRUE) {
  force(x)  # Evaluate x; needed for nested calls, e.g., sm.bs(sm.scale(x)).
  if (!is.vector(x)) {
    stop("argument 'x' must be a vector")
  }

  if (smart.mode.is("read")) {
    return(eval(smart.expression))  # Recursion used
  }

  # Turn logical flags into the numbers actually applied.
  if (is.logical(center)) {
    center <- if (center) mean(x) else 0
  }
  if (is.logical(scale)) {
    scale <- if (scale) sqrt(var(x)) else 1
  }

  if (smart.mode.is("write")) {
    put.smart(list(center = center,
                   scale  = scale,
                   match.call = match.call()))
  }

  (x - center) / scale
}
# Flag read by is.smart().
attr(sm.scale2, "smart") <- TRUE
598
599
600
601
602