1 /*
2  * Copyright © 2007-2019 Dynare Team
3  *
4  * This file is part of Dynare.
5  *
6  * Dynare is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * Dynare is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with Dynare.  If not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 /*
21  * This mex file computes A·(B⊗C) or A·(B⊗B) without explicitly building B⊗C or B⊗B, so that
22  * one can consider large matrices B and/or C.
23  */
24 
25 #include <dynmex.h>
26 #include <dynblas.h>
27 
28 void
full_A_times_kronecker_B_C(const double * A,const double * B,const double * C,double * D,blas_int mA,blas_int nA,blas_int mB,blas_int nB,blas_int mC,blas_int nC)29 full_A_times_kronecker_B_C(const double *A, const double *B, const double *C, double *D,
30                            blas_int mA, blas_int nA, blas_int mB, blas_int nB, blas_int mC, blas_int nC)
31 {
32   const blas_int shiftA = mA*mC;
33   const blas_int shiftD = mA*nC;
34   blas_int kd = 0, ka = 0;
35   double one = 1.0;
36   for (blas_int col = 0; col < nB; col++)
37     {
38       ka = 0;
39       for (blas_int row = 0; row < mB; row++)
40         {
41           dgemm("N", "N", &mA, &nC, &mC, &B[mB*col+row], &A[ka], &mA, C, &mC, &one, &D[kd], &mA);
42           ka += shiftA;
43         }
44       kd += shiftD;
45     }
46 }
47 
48 void
full_A_times_kronecker_B_B(const double * A,const double * B,double * D,blas_int mA,blas_int nA,blas_int mB,blas_int nB)49 full_A_times_kronecker_B_B(const double *A, const double *B, double *D, blas_int mA, blas_int nA, blas_int mB, blas_int nB)
50 {
51   const blas_int shiftA = mA*mB;
52   const blas_int shiftD = mA*nB;
53   blas_int kd = 0, ka = 0;
54   double one = 1.0;
55   for (blas_int col = 0; col < nB; col++)
56     {
57       ka = 0;
58       for (blas_int row = 0; row < mB; row++)
59         {
60           dgemm("N", "N", &mA, &nB, &mB, &B[mB*col+row], &A[ka], &mA, B, &mB, &one, &D[kd], &mA);
61           ka += shiftA;
62         }
63       kd += shiftD;
64     }
65 }
66 
67 void
mexFunction(int nlhs,mxArray * plhs[],int nrhs,const mxArray * prhs[])68 mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
69 {
70   // Check input and output:
71   if (nrhs > 3 || nrhs < 2)
72     DYN_MEX_FUNC_ERR_MSG_TXT("A_times_B_kronecker_C takes 2 or 3 input arguments and provides 2 output arguments.");
73 
74   // Get & Check dimensions (columns and rows):
75   size_t mA = mxGetM(prhs[0]);
76   size_t nA = mxGetN(prhs[0]);
77   size_t mB = mxGetM(prhs[1]);
78   size_t nB = mxGetN(prhs[1]);
79   size_t mC, nC;
80   if (nrhs == 3) // A·(B⊗C) is to be computed.
81     {
82       mC = mxGetM(prhs[2]);
83       nC = mxGetN(prhs[2]);
84       if (mB*mC != nA)
85         DYN_MEX_FUNC_ERR_MSG_TXT("Input dimension error!");
86     }
87   else // A·(B⊗B) is to be computed.
88     {
89       if (mB*mB != nA)
90         DYN_MEX_FUNC_ERR_MSG_TXT("Input dimension error!");
91     }
92   // Get input matrices:
93   const double *A = mxGetPr(prhs[0]);
94   const double *B = mxGetPr(prhs[1]);
95   const double *C{nullptr};
96   if (nrhs == 3)
97     C = mxGetPr(prhs[2]);
98 
99   // Initialization of the ouput:
100   if (nrhs == 3)
101     plhs[0] = mxCreateDoubleMatrix(mA, nB*nC, mxREAL);
102   else
103     plhs[0] = mxCreateDoubleMatrix(mA, nB*nB, mxREAL);
104   double *D = mxGetPr(plhs[0]);
105 
106   // Computational part:
107   if (nrhs == 2)
108     full_A_times_kronecker_B_B(A, B, D, mA, nA, mB, nB);
109   else
110     full_A_times_kronecker_B_C(A, B, C, D, mA, nA, mB, nB, mC, nC);
111 
112   plhs[1] = mxCreateDoubleScalar(0);
113 }
114