1 /*
2  *  R : A Computer Language for Statistical Data Analysis
3  *  Copyright (C) 2012-2019  The R Core Team
4  *
5  *  This program is free software; you can redistribute it and/or modify
6  *  it under the terms of the GNU General Public License as published by
7  *  the Free Software Foundation; either version 2 of the License, or
8  *  (at your option) any later version.
9  *
10  *  This program is distributed in the hope that it will be useful,
11  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  *  GNU General Public License for more details.
14  *
15  *  You should have received a copy of the GNU General Public License
16  *  along with this program; if not, a copy is available at
17  *  https://www.R-project.org/Licenses/
18  */
19 
20 #ifdef HAVE_CONFIG_H
21 # include <config.h>
22 #endif
23 
24 #include <math.h>
25 #include <string.h>  // memset, memcpy
26 #include <R.h>
27 #include <Rinternals.h>
28 #include <Rmath.h>
29 #include <R_ext/Lapack.h>        /* for Lapack (dpotrf, etc.) and BLAS */
30 
31 #include "stats.h" // for _()
32 #include "statsR.h"
33 
34 
/**
 * Simulate the Cholesky factor of a standardized Wishart variate with
 * dimension p and nu degrees of freedom (Bartlett decomposition).
 *
 * The diagonal entry of (0-based) column j is sqrt(chisq(nu - j)).  The
 * strictly off-diagonal entries of the requested triangle are independent
 * N(0, 1) draws, and the opposite triangle is zero.
 *
 * @param nu degrees of freedom; an error is signalled if nu < p
 * @param p dimension of the Wishart distribution; must be positive
 * @param upper if 0 the result is lower triangular, otherwise upper
 *              triangular
 * @param ans array of size p * p (column-major) to hold the result
 *
 * @return ans
 */
static double
*std_rWishart_factor(double nu, int p, int upper, double ans[])
{
    if (nu < (double) p || p <= 0)
	error(_("inconsistent degrees of freedom and dimension"));

    /* Do index arithmetic in size_t: p * p can exceed INT_MAX for large p. */
    size_t pz = (size_t) p;

    /* Zero the whole matrix once; only the non-zero triangle is written below. */
    memset(ans, 0, pz * pz * sizeof(double));
    for (size_t j = 0; j < pz; j++) {	/* jth column */
	ans[j * (pz + 1)] = sqrt(rchisq(nu - (double) j));
	for (size_t i = 0; i < j; i++) {
	    size_t uind = i + j * pz, /* upper triangle index: row i, col j */
		lind = j + i * pz;    /* lower triangle index: row j, col i */
	    ans[upper ? uind : lind] = norm_rand();
	}
    }
    return ans;
}
67 
68 /**
69  * Simulate a sample of random matrices from a Wishart distribution
70  *
71  * @param ns Number of samples to generate
72  * @param nuP Degrees of freedom
73  * @param scal Positive-definite scale matrix
74  *
75  * @return
76  */
77 SEXP
rWishart(SEXP ns,SEXP nuP,SEXP scal)78 rWishart(SEXP ns, SEXP nuP, SEXP scal)
79 {
80     SEXP ans;
81     int *dims = INTEGER(getAttrib(scal, R_DimSymbol)), info,
82 	n = asInteger(ns), psqr;
83     double *scCp, *ansp, *tmp, nu = asReal(nuP), one = 1, zero = 0;
84 
85     if (!isMatrix(scal) || !isReal(scal) || dims[0] != dims[1])
86 	error(_("'scal' must be a square, real matrix"));
87     if (n <= 0) n = 1;
88     // allocate early to avoid memory leaks in Callocs below.
89     PROTECT(ans = alloc3DArray(REALSXP, dims[0], dims[0], n));
90     psqr = dims[0] * dims[0];
91     tmp = Calloc(psqr, double);
92     scCp = Calloc(psqr, double);
93 
94     Memcpy(scCp, REAL(scal), psqr);
95     memset(tmp, 0, psqr * sizeof(double));
96     F77_CALL(dpotrf)("U", &(dims[0]), scCp, &(dims[0]), &info FCONE); // LAPACK
97     if (info)
98 	error(_("'scal' matrix is not positive-definite"));
99     ansp = REAL(ans);
100     GetRNGstate();
101     for (int j = 0; j < n; j++) {
102 	double *ansj = ansp + j * psqr;
103 	std_rWishart_factor(nu, dims[0], 1, tmp);
104 	F77_CALL(dtrmm)("R", "U", "N", "N", dims, dims,
105 			&one, scCp, dims, tmp, dims
106 			FCONE FCONE FCONE FCONE); // BLAS
107 	F77_CALL(dsyrk)("U", "T", &(dims[1]), &(dims[1]),
108 			&one, tmp, &(dims[1]),
109 			&zero, ansj, &(dims[1]) FCONE FCONE); // BLAS
110 
111 	for (int i = 1; i < dims[0]; i++)
112 	    for (int k = 0; k < i; k++)
113 		ansj[i + k * dims[0]] = ansj[k + i * dims[0]];
114     }
115 
116     PutRNGstate();
117     Free(scCp); Free(tmp);
118     UNPROTECT(1);
119     return ans;
120 }
121