1 /* ========================================================================== */
2 /* == ssmult_dot ============================================================ */
3 /* ========================================================================== */
4
5 /*
6 C = A'*B using the sparse-dot-product method. Computes C as a full matrix
7 first and then converts the result to sparse format. It is thus useful only
8 if C is small compared to A and/or B. It is very fast if A and B are long
9 column vectors, because in that case the saxpy method would first have to compute A', which takes a long time.
10
11 A is m-by-n, B is m-by-k, and thus C is n-by-k.
12
13 The time taken by this function is at least proportional to n*k + flops(A'*B),
14 but it can be higher because a sparse dot product of x'*y where x and y are
15 column vectors can take up to O (nnz(x) + nnz(y)). The sparse dot product
16 can terminate early. In particular, the time is precisely
17 O (nnz (x (1:t)) + nnz (y (1:t))) where t = min (max (find (x)), max (find (y))).
18 This sparse dot is used for each pair of columns of A and B. The workspace
19 required by this method is n*k*sizeof(double) or twice that if A or B are
20 complex.
21
22 By comparison, the saxpy method to compute C=A*B takes O (m+n+k+flops(A*B))
23 time and uses only O(m) workspace. However, C=A'*B using that method must
24 transpose A first, taking another O(m+n+nnz(A)) time and adding O(m+nnz(A))
25 workspace.
26
27 Note that m does NOT appear in the time or memory complexity of ssmult_dot
28 when C=A'*B is computed. Thus, if m is huge compared to n, k, nnz(A), and
29 so on, then it can be far faster and use far less memory. For example, if
30 A and B are very long and very sparse column vectors, the dot product method
31 is much faster than the saxpy method.
32
33 Comparing flop counts of the two methods is not trivial. Thus, when
34 computing C=A'*B, ssmult uses whichever method requires less workspace.
35
36 Sparse dot product based matrix multiplication algorithm in MATLAB notation:
37
38 function C = ssmult_dot (A,B)
39 % C = A'*B A is m-by-n, B is m-by-k, C is n-by-k
40 C = zeros (n,k) ;
41 for i = 1:n
42 for j = 1:k
43 C(i,j) = A (:,i)'*B(:,j) ;
44 end
45 end
46 C = sparse (C) ;
47 */
48
49 #include "ssmult.h"
50
51 /* -------------------------------------------------------------------------- */
52 /* ssmult_dot */
53 /* -------------------------------------------------------------------------- */
54
ssmult_dot(const mxArray * A,const mxArray * B,int ac,int bc,int cc)55 mxArray *ssmult_dot /* returns C = A'*B */
56 (
57 const mxArray *A,
58 const mxArray *B,
59 int ac, /* if true: conj(A) if false: A. ignored if A real */
60 int bc, /* if true: conj(B) if false: B. ignored if B real */
61 int cc /* if true: conj(C) if false: C. ignored if C real */
62 )
63 {
64 double cx, cz, ax, az, bx, bz ;
65 mxArray *C ;
66 double *Ax, *Az, *Bx, *Bz, *Cx, *Cz ;
67 Int *Ap, *Ai, *Bp, *Bi, *Cp, *Ci ;
68 Int m, n, k, cnzmax, i, j, p, paend, pbend, ai, bi, cnz, pa, pb, zallzero,
69 A_is_complex, B_is_complex, C_is_complex ;
70
71 /* ---------------------------------------------------------------------- */
72 /* get inputs */
73 /* ---------------------------------------------------------------------- */
74
75 m = mxGetM (A) ;
76 n = mxGetN (A) ;
77 k = mxGetN (B) ;
78
79 if (m != mxGetM (B)) ssmult_invalid (ERROR_DIMENSIONS) ;
80
81 Ap = (Int *) mxGetJc (A) ;
82 Ai = (Int *) mxGetIr (A) ;
83 Ax = mxGetPr (A) ;
84 Az = mxGetPi (A) ;
85 A_is_complex = mxIsComplex (A) ;
86
87 Bp = (Int *) mxGetJc (B) ;
88 Bi = (Int *) mxGetIr (B) ;
89 Bx = mxGetPr (B) ;
90 Bz = mxGetPi (B) ;
91 B_is_complex = mxIsComplex (B) ;
92
93 /* ---------------------------------------------------------------------- */
94 /* allocate C as an n-by-k full matrix but do not initialize it */
95 /* ---------------------------------------------------------------------- */
96
97 /* NOTE: integer overflow cannot occur here, because this function is not
98 called unless O(n*k) is less than O(m+nnz(A)). The test is done
99 in the caller, not here.
100 */
101
102 cnzmax = n*k ;
103 cnzmax = MAX (cnzmax, 1) ;
104 Cx = mxMalloc (cnzmax * sizeof (double)) ;
105 C_is_complex = A_is_complex || B_is_complex ;
106 Cz = C_is_complex ? mxMalloc (cnzmax * sizeof (double)) : NULL ;
107
108 /* ---------------------------------------------------------------------- */
109 /* C = A'*B using sparse dot products */
110 /* ---------------------------------------------------------------------- */
111
112 /*
113 NOTE: this method REQUIRES the columns of A and B to be sorted on input.
114 That is, the row indices in each column must appear in ascending order.
115 This is the standard in all versions of MATLAB to date, and likely will
116 be for some time. However, if MATLAB were to use unsorted sparse
117 matrices in the future (a lazy sort) then a test should be included in
118 ssmult to not use ssmult_dot if A or B are unsorted, or they should be
119 sorted on input.
120 */
121
122 cnz = 0 ;
123 for (j = 0 ; j < k ; j++)
124 {
125 for (i = 0 ; i < n ; i++)
126 {
127 /* compute C (i,j) = A (:,i)' * B (:,j) */
128 pa = Ap [i] ;
129 paend = Ap [i+1] ;
130 pb = Bp [j] ;
131 pbend = Bp [j+1] ;
132
133 if (pa == paend /* nnz (A (:,i)) == 0 */
134 || pb == pbend /* nnz (B (:,j)) == 0 */
135 || Ai [paend-1] < Bi [pb] /* max(find(A(:,i)))<min(find(B(:,j))) */
136 || Ai [pa] > Bi [pbend-1]) /* min(find(A(:,i)))>max(find(B(:,j))) */
137 {
138 Cx [i+j*n] = 0 ; /* no work to do */
139 if (C_is_complex)
140 {
141 Cz [i+j*n] = 0 ;
142 }
143 continue ;
144 }
145 cx = 0 ;
146 cz = 0 ;
147 while (pa < paend && pb < pbend)
148 {
149 /* The dot product looks like the merge in ssmergesort, except*/
150 /* no "clean-up" phase is need when one list is exhausted. */
151 ai = Ai [pa] ;
152 bi = Bi [pb] ;
153 if (ai == bi)
154 {
155 /* c += A (ai,i) * B (ai,j), and "consume" both entries */
156 if (!C_is_complex)
157 {
158 cx += Ax [pa] * Bx [pb] ;
159 }
160 else
161 {
162 /* complex case */
163 ax = Ax [pa] ;
164 bx = Bx [pb] ;
165 az = Az ? (ac ? (-Az [pa]) : Az [pa]) : 0.0 ;
166 bz = Bz ? (bc ? (-Bz [pb]) : Bz [pb]) : 0.0 ;
167 cx += ax * bx - az * bz ;
168 cz += az * bx + ax * bz ;
169 }
170 pa++ ;
171 pb++ ;
172 }
173 else if (ai < bi)
174 {
175 /* consume A(ai,i) and discard it, since B(ai,j) is zero */
176 pa++ ;
177 }
178 else
179 {
180 /* consume B(bi,j) and discard it, since A(ai,i) is zero */
181 pb++ ;
182 }
183 }
184 Cx [i+j*n] = cx ;
185 if (C_is_complex)
186 {
187 Cz [i+j*n] = cz ;
188 }
189 }
190
191 /* count the number of nonzeros in C(:,j) */
192 for (i = 0 ; i < n ; i++)
193 {
194 /* This could be done above, except for the gcc compiler bug when
195 cx is an 80-bit nonzero in register above, but becomes zero here
196 when stored into memory. We need the latter, to correctly handle
197 the case when cx underflows to zero in 64-bit floating-point.
198 Do not attempt to "optimize" this code by doing this test above,
199 unless the gcc compiler bug is fixed (as of gcc version 4.1.0).
200 */
201 if (Cx [i+j*n] != 0 || (C_is_complex && Cz [i+j*n] != 0))
202 {
203 cnz++ ;
204 }
205 }
206 }
207
208 /* ---------------------------------------------------------------------- */
209 /* convert C to real if the imaginary part is all zero */
210 /* ---------------------------------------------------------------------- */
211
212 if (C_is_complex)
213 {
214 zallzero = 1 ;
215 for (p = 0 ; zallzero && p < cnzmax ; p++)
216 {
217 if (Cz [p] != 0)
218 {
219 zallzero = 0 ;
220 }
221 }
222 if (zallzero)
223 {
224 /* the imaginary part of C is all zero */
225 C_is_complex = 0 ;
226 mxFree (Cz) ;
227 Cz = NULL ;
228 }
229 }
230
231 /* ---------------------------------------------------------------------- */
232 /* allocate integer part of C but do not initialize it */
233 /* ---------------------------------------------------------------------- */
234
235 cnz = MAX (cnz, 1) ;
236 C = mxCreateSparse (0, 0, 0, C_is_complex ? mxCOMPLEX : mxREAL) ;
237 MXFREE (mxGetJc (C)) ;
238 MXFREE (mxGetIr (C)) ;
239 MXFREE (mxGetPr (C)) ;
240 MXFREE (mxGetPi (C)) ;
241 Cp = mxMalloc ((k + 1) * sizeof (Int)) ;
242 Ci = mxMalloc (MAX (cnz,1) * sizeof (Int)) ;
243 mxSetJc (C, (mwIndex *) Cp) ;
244 mxSetIr (C, (mwIndex *) Ci) ;
245 mxSetM (C, n) ;
246 mxSetN (C, k) ;
247
248 /* ---------------------------------------------------------------------- */
249 /* C = sparse (C). Note that this is done in-place. */
250 /* ---------------------------------------------------------------------- */
251
252 p = 0 ;
253 for (j = 0 ; j < k ; j++)
254 {
255 Cp [j] = p ;
256 for (i = 0 ; i < n ; i++)
257 {
258 cx = Cx [i+j*n] ;
259 cz = (C_is_complex ? Cz [i+j*n] : 0) ;
260 if (cx != 0 || cz != 0)
261 {
262 Ci [p] = i ;
263 Cx [p] = cx ;
264 if (C_is_complex) Cz [p] = (cc ? (-cz) : cz) ;
265 p++ ;
266 }
267 }
268 }
269 Cp [k] = p ;
270
271 /* ---------------------------------------------------------------------- */
272 /* reduce the size of Cx and Cz and return result */
273 /* ---------------------------------------------------------------------- */
274
275 if (cnz < cnzmax)
276 {
277 Cx = mxRealloc (Cx, cnz * sizeof (double)) ;
278 if (C_is_complex)
279 {
280 Cz = mxRealloc (Cz, cnz * sizeof (double)) ;
281 }
282 }
283
284 mxSetNzmax (C, cnz) ;
285 mxSetPr (C, Cx) ;
286 if (C_is_complex)
287 {
288 mxSetPi (C, Cz) ;
289 }
290 return (C) ;
291 }
292