#' Functions for the Dirichlet Distribution
#'
#' Functions to compute the density of, or generate random deviates from, the
#' Dirichlet distribution.
#'
#' The Dirichlet distribution is the multivariate generalization of the
#' beta distribution.  It is the conjugate prior for the probability
#' parameters of a multinomial distribution.
#'
#' @aliases rdirichlet ddirichlet
#' @param x A vector containing a single point, or a matrix (or data frame)
#' containing one point per row, at which to evaluate the density.
#' @param alpha Vector or (for \code{ddirichlet}) matrix containing shape
#' parameters.  A vector is reused for every row of \code{x}; a matrix must
#' have the same dimensions as \code{x}.
#' @return \code{ddirichlet} returns a vector containing the Dirichlet density
#' for the corresponding rows of \code{x}.  Rows with any component outside
#' \eqn{[0, 1]}, or whose components do not sum to 1 (within the tolerance of
#' \code{\link{all.equal}}), have density 0.
#'
#' \code{rdirichlet} returns a matrix with \code{n} rows, each containing a
#' single Dirichlet random deviate.
#' @author Code originally posted by Ben Bolker to the R-help mailing list on
#' Fri Dec 15 2000.  See
#' \url{https://stat.ethz.ch/pipermail/r-help/2000-December/009561.html}.  Ben
#' attributed the code to Ian Wilson \email{i.wilson@@maths.abdn.ac.uk}.
#' Subsequent modifications by Gregory R. Warnes \email{greg@@warnes.net}.
#' @seealso \code{\link{dbeta}}, \code{\link{rbeta}}
#' @keywords distribution
#' @examples
#'
#'
#' x <- rdirichlet(20, c(1, 1, 1))
#'
#' ddirichlet(x, c(1, 1, 1))
#' @name dirichlet
NULL
34
#' @describeIn dirichlet Dirichlet density function.
#' @export
ddirichlet <- function(x, alpha) {
  # Coerce `x` to a matrix holding one point per row: data frames are
  # converted directly, and a bare vector becomes a single-row matrix.
  if (!is.matrix(x)) {
    if (is.data.frame(x)) {
      x <- as.matrix(x)
    } else {
      x <- t(x)
    }
  }

  # Reuse a shape vector for every row of `x`.  Filling column-major with
  # rep(each = ) is equivalent to matrix(alpha, byrow = TRUE), but it also
  # works cleanly when `x` has zero rows.
  if (!is.matrix(alpha)) {
    alpha <- matrix(
      rep(alpha, each = nrow(x)),
      nrow = nrow(x),
      ncol = length(alpha)
    )
  }

  if (any(dim(x) != dim(alpha))) {
    stop("Mismatch between dimensions of 'x' and 'alpha'.")
  }

  # Log of the multivariate beta function B(alpha), one value per row.
  log_beta <- rowSums(lgamma(alpha)) - lgamma(rowSums(alpha))

  # Log-kernel terms (alpha - 1) * log(x).  For alpha == 1 and x == 0 this
  # evaluates to 0 * -Inf = NaN.  The factor x^0 is 1, so the log-term is 0.
  # This agrees with dbeta(0, 1, b).
  terms <- (alpha - 1) * log(x)
  terms[alpha == 1 & x == 0] <- 0

  # Vectorized over rows; names are dropped so the result stays a plain
  # vector even when `x` carried row names.
  pd <- unname(exp(rowSums(terms) - log_beta))

  # The density is 0 off the simplex.  That covers any component outside
  # [0, 1], and components that do not sum to 1 up to all.equal()'s
  # tolerance.
  off_simplex <- apply(x, 1, function(z) {
    any(z < 0 | z > 1) || !isTRUE(all.equal(sum(z), 1))
  })
  pd[off_simplex] <- 0
  pd
}
72
#' @describeIn dirichlet Generate dirichlet random deviates.
#' @param n Number of random vectors to generate.
#' @importFrom stats rgamma
#' @export
rdirichlet <- function(n, alpha) {
  k <- length(alpha)

  # Draw all gamma variates in one call.  rgamma() recycles `alpha`, so
  # filling the matrix by row puts the draw with shape alpha[j] in column j
  # of every row.
  gammas <- matrix(rgamma(k * n, alpha), ncol = k, byrow = TRUE)

  # Row totals as a matrix-vector product with a vector of ones.  Dividing
  # each row by its total projects the draws onto the simplex.
  row_totals <- c(gammas %*% rep(1, k))
  gammas / row_totals
}
83