1 /* Copyright (c) 2015 Gerald Knizia
2 *
3 * This file is part of the IboView program (see: http://www.iboview.org)
4 *
5 * IboView is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, version 3.
8 *
9 * IboView is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with bfint (LICENSE). If not, see http://www.gnu.org/licenses/
16 *
17 * Please see IboView documentation in README.txt for:
18 * -- A list of included external software and their licenses. The included
19 * external software's copyright is not touched by this agreement.
20 * -- Notes on re-distribution and contributions to/further development of
21 * the IboView software
22 */
23
#include <algorithm> // for std::min
#include <stdexcept>
#include <stdlib.h>
#include <vector> // for RAII LAPACK workspaces

#include "CxAlgebra.h"
#include "CxDefs.h" // for assert.
30
31 namespace ct {
32
33 // Out = f * A * B
Mxm(double * pOut,ptrdiff_t iRowStO,ptrdiff_t iColStO,double const * pA,ptrdiff_t iRowStA,ptrdiff_t iColStA,double const * pB,ptrdiff_t iRowStB,ptrdiff_t iColStB,size_t nRows,size_t nLink,size_t nCols,bool AddToDest,double fFactor)34 void Mxm(double *pOut, ptrdiff_t iRowStO, ptrdiff_t iColStO,
35 double const *pA, ptrdiff_t iRowStA, ptrdiff_t iColStA,
36 double const *pB, ptrdiff_t iRowStB, ptrdiff_t iColStB,
37 size_t nRows, size_t nLink, size_t nCols, bool AddToDest, double fFactor )
38 {
39 assert( iRowStO == 1 || iColStO == 1 );
40 assert( iRowStA == 1 || iColStA == 1 );
41 assert( iRowStB == 1 || iColStB == 1 );
42 // ^- otherwise dgemm directly not applicable. Would need local copy
43 // of matrix/matrices with compressed strides.
44
45 // if ( nRows == 1 || nLink == 1 || nCols == 1 ) {
46 // if ( !AddToDest )
47 // for ( uint ic = 0; ic < nCols; ++ ic )
48 // for ( uint ir = 0; ir < nRows; ++ ir )
49 // pOut[ir*iRowStO + ic*iColStO] = 0;
50 //
51 // for ( uint ic = 0; ic < nCols; ++ ic )
52 // for ( uint ir = 0; ir < nRows; ++ ir )
53 // for ( uint il = 0; il < nLink; ++ il )
54 // pOut[ir*iRowStO + ic*iColStO] += fFactor * pA[ir*iRowStA + il*iColStA] * pB[il*iRowStB + ic*iColStB];
55 // return;
56 // }
57
58 double
59 Beta = AddToDest? 1.0 : 0.0;
60 char
61 TransA, TransB;
62 FORTINT
63 lda, ldb,
64 ldc = (iRowStO == 1)? iColStO : iRowStO;
65
66 if ( iRowStA == 1 ) {
67 TransA = 'N'; lda = iColStA;
68 } else {
69 TransA = 'T'; lda = iRowStA;
70 }
71 if ( iRowStB == 1 ) {
72 TransB = 'N'; ldb = iColStB;
73 } else {
74 TransB = 'T'; ldb = iRowStB;
75 }
76
77 DGEMM( TransA, TransB, nRows, nCols, nLink,
78 fFactor, pA, lda, pB, ldb, Beta, pOut, ldc );
79 }
80
81 //// this one is used if neither the column nor the row stride in pMat is unity.
82 //void MxvLame(double *RESTRICT pOut, ptrdiff_t iStO, double const *RESTRICT pMat, ptrdiff_t iRowStM, ptrdiff_t iColStM,
83 // double const *RESTRICT pIn, ptrdiff_t iStI, size_t nRows, size_t nLink, bool AddToDest, double fFactor)
84 //{
85 // for (size_t iRow = 0; iRow < nRows; ++ iRow) {
86 // double
87 // d = 0;
88 // double const
89 // *RESTRICT pM = &pMat[iRowStM * iRow];
90 // for (size_t iLink = 0; iLink < nLink; ++ iLink) {
91 // d += pIn[iStI * iLink] * pM[iColStM * iLink];
92 // }
93 // d *= fFactor;
94 //
95 // double *RESTRICT r = &pOut[iStO * iRow];
96 // if (AddToDest)
97 // *r += d;
98 // else
99 // *r = d;
100 // }
101 //}
102
// General-stride matrix-vector product with plain loops:
//    Out[r] = f * sum_k Mat[r,k] * In[k]     (AddToDest == false)
//    Out[r] += f * sum_k Mat[r,k] * In[k]    (AddToDest == true)
// Arbitrary strides are accepted for the output vector (iStO), the input
// vector (iStI), and both matrix dimensions (iRowStM, iColStM). MxvLame
// dispatches here whenever one of the vector strides is not unity.
void MxvLameG(double *RESTRICT pOut, ptrdiff_t iStO, double const *RESTRICT pMat, ptrdiff_t iRowStM, ptrdiff_t iColStM,
    double const *RESTRICT pIn, ptrdiff_t iStI, size_t nRows, size_t nLink, bool AddToDest, double fFactor)
{
    for (size_t r = 0; r != nRows; ++r) {
        // start of matrix row r.
        double const *RESTRICT pRow = pMat + iRowStM * r;

        // dot product of matrix row r with the input vector.
        double Acc = 0;
        for (size_t k = 0; k != nLink; ++k)
            Acc += pIn[iStI * k] * pRow[iColStM * k];
        Acc *= fFactor;

        // the destination is only read if we are asked to accumulate.
        double &Dest = pOut[iStO * r];
        if (!AddToDest)
            Dest = Acc;
        else
            Dest += Acc;
    }
}
124
125 // this one is used if neither the column nor the row stride in pMat is unity.
MxvLame(double * RESTRICT pOut,ptrdiff_t iStO,double const * RESTRICT pMat,ptrdiff_t iRowStM,ptrdiff_t iColStM,double const * RESTRICT pIn,ptrdiff_t iStI,size_t nRows,size_t nLink,bool AddToDest,double fFactor)126 void MxvLame(double *RESTRICT pOut, ptrdiff_t iStO, double const *RESTRICT pMat, ptrdiff_t iRowStM, ptrdiff_t iColStM,
127 double const *RESTRICT pIn, ptrdiff_t iStI, size_t nRows, size_t nLink, bool AddToDest, double fFactor)
128 {
129 if (iStO != 1 || iStI != 1)
130 return MxvLameG(pOut, iStO, pMat, iRowStM, iColStM, pIn, iStI, nRows, nLink, AddToDest, fFactor);
131 for (size_t iRow = 0; iRow < nRows; ++iRow) {
132 double
133 d = 0;
134 double const
135 *RESTRICT pM = &pMat[iRowStM * iRow];
136 if (iColStM == 1) {
137 for (size_t iLink = 0; iLink < nLink; ++iLink) {
138 d += pIn[iLink] * pM[iLink];
139 }
140 } else {
141 for (size_t iLink = 0; iLink < nLink; ++iLink) {
142 d += pIn[iLink] * pM[iColStM * iLink];
143 }
144 }
145 d *= fFactor;
146
147 double *RESTRICT r = &pOut[iRow];
148 if (AddToDest)
149 *r += d;
150 else
151 *r = d;
152 }
153 }
154
155
156
157 // note: both H and S are overwritten. Eigenvectors go into H.
DiagonalizeGen(double * pEw,double * pH,uint ldH,double * pS,uint ldS,uint N)158 void DiagonalizeGen(double *pEw, double *pH, uint ldH, double *pS, uint ldS, uint N)
159 {
160 FORTINT info = 0, nWork = 128*N;
161 double *pWork = (double*)::malloc(sizeof(double)*nWork);
162 DSYGV(1, 'V', 'L', N, pH, ldH, pS, ldS, pEw, pWork, nWork, info );
163 ::free(pWork);
164 if ( info != 0 ) throw std::runtime_error("dsygv failed.");
165 }
166
167 // void Diagonalize(double *pEw, double *pH, uint ldH, uint N)
168 // {
169 // FORTINT info = 0, nWork = 128*N;
170 // double *pWork = (double*)::malloc(sizeof(double)*nWork);
171 // DSYEV('V', 'L', N, pH, ldH, pEw, pWork, nWork, info );
172 // ::free(pWork);
173 // if ( info != 0 ) throw std::runtime_error("dsyev failed.");
174 // }
175
Diagonalize(double * pEw,double * pH,uint ldH,uint N)176 void Diagonalize(double *pEw, double *pH, uint ldH, uint N)
177 {
178 if (N == 0)
179 return;
180 FORTINT info = 0;
181 // workspace query.
182 double fWork = 0;
183 FORTINT nWork, niWork[2] = {0};
184 DSYEVD('V', 'L', N, pH, ldH, pEw, &fWork, -1, &niWork[0], -1, info);
185 if ( info != 0 ) throw std::runtime_error("dsyevd workspace query failed.");
186 nWork = FORTINT(fWork);
187
188 double *pWork = (double*)::malloc(sizeof(double)*nWork);
189 FORTINT *piWork = (FORTINT*)::malloc(sizeof(FORTINT)*niWork[0]);
190 DSYEVD('V', 'L', N, pH, ldH, pEw, pWork, nWork, piWork, niWork[0], info);
191 ::free(piWork);
192 ::free(pWork);
193 if ( info != 0 ) throw std::runtime_error("dsyevd failed.");
194
195 // if ( info != 0 ) {
196 // std::stringstream str;
197 // str << "Something went wrong when trying to diagonalize a " << InOut.nRows << "x" << InOut.nCols << " matrix. "
198 // << "DSYEVD returned error code " << info << ".";
199 // throw std::runtime_error(str.str());
200 // }
201 }
202
203
204 // U: nRows x nSig, Vt: nCols x nSig,
205 // where nSig = min(nRows, nCols)
ComputeSvd(double * pU,size_t ldU,double * pSigma,double * pVt,size_t ldVt,double * pInAndTmp,size_t ldIn,size_t nRows,size_t nCols)206 void ComputeSvd(double *pU, size_t ldU, double *pSigma, double *pVt, size_t ldVt, double *pInAndTmp, size_t ldIn, size_t nRows, size_t nCols)
207 {
208 size_t
209 nSig = std::min(nRows, nCols);
210 assert(ldU >= nRows && ldVt >= nSig && ldIn >= nRows);
211 FORTINT
212 lWork = 0,
213 info = 0;
214 FORTINT
215 *piWork = (FORTINT*)::malloc(sizeof(FORTINT)*8*nSig);
216 // workspace query.
217 double
218 flWork = 0;
219 DGESDD('S', nRows, nCols, pInAndTmp, ldIn, pSigma, pU, ldU, pVt, ldVt, &flWork, -1, piWork, &info);
220 if ( info != 0 ) {
221 ::free(piWork); // my understanding of the docs is that piWork needs to be valid for the workspace query...
222 throw std::runtime_error("dgesdd workspace query failed.");
223 }
224 lWork = FORTINT(flWork);
225 double
226 *pWork = (double*)::malloc(sizeof(double)*lWork);
227 DGESDD('S', nRows, nCols, pInAndTmp, ldIn, pSigma, pU, ldU, pVt, ldVt, pWork, lWork, piWork, &info);
228 ::free(pWork);
229 ::free(piWork);
230 if ( info != 0 )
231 throw std::runtime_error("dgesdd failed.");
232 }
233
234
235
236 }
237
238 // kate: indent-width 4
239