#' Functions for the Dirichlet Distribution
#'
#' Functions to compute the density of, or generate random deviates from, the
#' Dirichlet distribution.
#'
#' The Dirichlet distribution is the multidimensional generalization of the
#' beta distribution. It is the canonical Bayesian (conjugate prior)
#' distribution for the parameters of a multinomial distribution.
#'
#' @aliases rdirichlet ddirichlet
#' @param x A vector containing a single random deviate, or a matrix (or data
#' frame) containing one random deviate per row.
#' @param alpha Vector or (for \code{ddirichlet}) matrix containing shape
#' parameters. A vector is applied to every row of \code{x}; a matrix must
#' have the same dimensions as \code{x}.
#' @return \code{ddirichlet} returns a vector containing the Dirichlet density
#' for the corresponding rows of \code{x}. Rows that do not lie on the
#' probability simplex (a component below 0 or above 1, components that do not
#' sum to 1, or missing values) are given density 0.
#'
#' \code{rdirichlet} returns a matrix with \code{n} rows, each containing a
#' single Dirichlet random deviate.
#' @author Code originally posted by Ben Bolker to R-help on Fri Dec 15 2000.
#' See \url{https://stat.ethz.ch/pipermail/r-help/2000-December/009561.html}.
#' Ben attributed the code to Ian Wilson \email{i.wilson@@maths.abdn.ac.uk}.
#' Subsequent modifications by Gregory R. Warnes \email{greg@@warnes.net}.
#' @seealso \code{\link{dbeta}}, \code{\link{rbeta}}
#' @keywords distribution
#' @examples
#'
#' x <- rdirichlet(20, c(1, 1, 1))
#'
#' ddirichlet(x, c(1, 1, 1))
#' @name dirichlet
NULL

#' @describeIn dirichlet Dirichlet density function.
#' @export
ddirichlet <- function(x, alpha) {
  # Row-wise Dirichlet density: each row of `x` is one point and the matching
  # row of `alpha` holds its shape parameters. Rows that are not on the
  # probability simplex (or contain missing values) get density 0.

  # Normalise `x` to a matrix with one observation per row; a plain vector is
  # treated as a single observation (one row).
  if (!is.matrix(x)) {
    if (is.data.frame(x)) {
      x <- as.matrix(x)
    } else {
      x <- t(x)
    }
  }

  # A vector of shape parameters is recycled across every row of `x`.
  # rep(each = ) plus column-major fill puts alpha[j] in every cell of column
  # j, and also gives a correctly shaped 0-row matrix when `x` has no rows.
  if (!is.matrix(alpha)) {
    alpha <- matrix(
      rep(alpha, each = nrow(x)),
      nrow = nrow(x),
      ncol = length(alpha)
    )
  }

  if (any(dim(x) != dim(alpha))) {
    stop("Mismatch between dimensions of 'x' and 'alpha'.")
  }

  # Support check: no missing values, every component in [0, 1], and the
  # components sum to 1 up to all.equal()'s numeric tolerance.
  on_simplex <- vapply(
    seq_len(nrow(x)),
    function(i) {
      z <- x[i, ]
      !anyNA(z) && all(z >= 0 & z <= 1) && isTRUE(all.equal(sum(z), 1))
    },
    logical(1)
  )

  # Evaluate the density only on supported rows, so out-of-support values
  # never reach log() (a negative component would otherwise trigger a
  # "NaNs produced" warning for a row whose answer is 0 anyway).
  pd <- numeric(nrow(x))
  if (any(on_simplex)) {
    x_ok <- x[on_simplex, , drop = FALSE]
    a_ok <- alpha[on_simplex, , drop = FALSE]

    # Log of the normalising constant: sum(lgamma(alpha)) - lgamma(sum(alpha)).
    log_norm <- rowSums(lgamma(a_ok)) - lgamma(rowSums(a_ok))
    log_terms <- (a_ok - 1) * log(x_ok)
    # alpha == 1 with x == 0 gives 0 * -Inf = NaN; this implementation scores
    # that term as -Inf, i.e. zero density at such a boundary point.
    log_terms[a_ok == 1 & x_ok == 0] <- -Inf
    pd[on_simplex] <- exp(rowSums(log_terms) - log_norm)
  }
  pd
}

#' @describeIn dirichlet Generate Dirichlet random deviates.
#' @param n Number of random vectors to generate.
#' @importFrom stats rgamma
#' @export
rdirichlet <- function(n, alpha) {
  # Draw independent Gamma(alpha[j], 1) variates and normalise each row to sum
  # to 1. The n * k draws are laid out row by row so that column j always uses
  # shape alpha[j]; the draw order is unchanged, keeping results reproducible
  # under set.seed().
  # NOTE: for very small shapes rgamma() can return values represented as 0;
  # if a whole row underflows, that row becomes 0 / 0 = NaN.
  k <- length(alpha)
  draws <- matrix(rgamma(k * n, alpha), ncol = k, byrow = TRUE)
  draws / rowSums(draws)
}