1 /* ========================================================================== */
2 /* == ssmult_dot ============================================================ */
3 /* ========================================================================== */
4 
5 /*
6    C = A'*B using the sparse-dot-product method.  Computes C as a full matrix
7    first and then converts the result to sparse format.  It is thus useful only
8    if C is small compared to A and/or B.  It is very fast if A and B are long
9    column vectors, because in that case, computing A' takes a long time.
10 
11    A is m-by-n, B is m-by-k, and thus C is n-by-k.
12 
13    The time taken by this function is at least proportional to n*k + flops(A*B),
14    but it can be higher because a sparse dot product of x'*y where x and y are
15    column vectors can take up to O (nnz(x) + nnz(y)).  The sparse dot product
16    can terminate early.  In particular, the time is precisely
17    O (nnz (x (1:t)) + nnz (y (1:t))) where t = min (max (find (x), find (y))).
18    This sparse dot is used for each pair of columns of A and B.  The workspace
19    required by this method is n*k*sizeof(double) or twice that if A or B are
20    complex.
21 
22    By comparison, the saxpy method to compute C=A*B takes O (m+n+k+flops(A*B))
23    time and uses only O(m) workspace.  However, C=A'*B using that method must
24    transpose A first, taking another O(m+n+nnz(A)) time and adding O(m+nnz(A))
25    workspace.
26 
27    Note that m does NOT appear in the time or memory complexity of ssmult_dot
28    when C=A'*B is computed.  Thus, if m is huge compared to n, k, nnz(A), and
29    so on, then it can be far faster and use far less memory.  For example, if
30    A and B are very long and very sparse column vectors, the dot product method
31    is much faster than the saxpy method.
32 
33    Comparing flop counts of the two methods is not trivial.  Thus, when
34    computing C=A'*B, ssmult uses whichever method requires the least workspace.
35 
36    Sparse dot product based matrix multiplication algorithm in MATLAB notation:
37 
38         function C = ssmult_dot (A,B)
39         % C = A'*B                  A is m-by-n, B is m-by-k, C is n-by-k
40         C = zeros (n,k) ;
41         for i = 1:n
42             for j = 1:k
43                 C(i,j) = A (:,i)'*B(:,j) ;
44             end
45         end
46         C = sparse (C) ;
47  */
48 
49 #include "ssmult.h"
50 
51 /* -------------------------------------------------------------------------- */
52 /* ssmult_dot */
53 /* -------------------------------------------------------------------------- */
54 
ssmult_dot(const mxArray * A,const mxArray * B,int ac,int bc,int cc)55 mxArray *ssmult_dot     /* returns C = A'*B */
56 (
57     const mxArray *A,
58     const mxArray *B,
59     int ac,             /* if true: conj(A)   if false: A. ignored if A real */
60     int bc,             /* if true: conj(B)   if false: B. ignored if B real */
61     int cc              /* if true: conj(C)   if false: C. ignored if C real */
62 )
63 {
64     double cx, cz, ax, az, bx, bz ;
65     mxArray *C ;
66     double *Ax, *Az, *Bx, *Bz, *Cx, *Cz ;
67     Int *Ap, *Ai, *Bp, *Bi, *Cp, *Ci ;
68     Int m, n, k, cnzmax, i, j, p, paend, pbend, ai, bi, cnz, pa, pb, zallzero,
69         A_is_complex, B_is_complex, C_is_complex ;
70 
71     /* ---------------------------------------------------------------------- */
72     /* get inputs */
73     /* ---------------------------------------------------------------------- */
74 
75     m = mxGetM (A) ;
76     n = mxGetN (A) ;
77     k = mxGetN (B) ;
78 
79     if (m != mxGetM (B)) ssmult_invalid (ERROR_DIMENSIONS) ;
80 
81     Ap = (Int *) mxGetJc (A) ;
82     Ai = (Int *) mxGetIr (A) ;
83     Ax = mxGetPr (A) ;
84     Az = mxGetPi (A) ;
85     A_is_complex = mxIsComplex (A) ;
86 
87     Bp = (Int *) mxGetJc (B) ;
88     Bi = (Int *) mxGetIr (B) ;
89     Bx = mxGetPr (B) ;
90     Bz = mxGetPi (B) ;
91     B_is_complex = mxIsComplex (B) ;
92 
93     /* ---------------------------------------------------------------------- */
94     /* allocate C as an n-by-k full matrix but do not initialize it */
95     /* ---------------------------------------------------------------------- */
96 
97     /* NOTE: integer overflow cannot occur here, because this function is not
98        called unless O(n*k) is less than O(m+nnz(A)).  The test is done
99        in the caller, not here.
100      */
101 
102     cnzmax = n*k ;
103     cnzmax = MAX (cnzmax, 1) ;
104     Cx = mxMalloc (cnzmax * sizeof (double)) ;
105     C_is_complex = A_is_complex || B_is_complex ;
106     Cz = C_is_complex ?  mxMalloc (cnzmax * sizeof (double)) : NULL ;
107 
108     /* ---------------------------------------------------------------------- */
109     /* C = A'*B using sparse dot products */
110     /* ---------------------------------------------------------------------- */
111 
112     /*
113        NOTE:  this method REQUIRES the columns of A and B to be sorted on input.
114        That is, the row indices in each column must appear in ascending order.
115        This is the standard in all versions of MATLAB to date, and likely will
116        be for some time.  However, if MATLAB were to use unsorted sparse
117        matrices in the future (a lazy sort) then a test should be included in
118        ssmult to not use ssmult_dot if A or B are unsorted, or they should be
119        sorted on input.
120      */
121 
122     cnz = 0 ;
123     for (j = 0 ; j < k ; j++)
124     {
125         for (i = 0 ; i < n ; i++)
126         {
127             /* compute C (i,j) = A (:,i)' * B (:,j) */
128             pa = Ap [i] ;
129             paend = Ap [i+1] ;
130             pb = Bp [j] ;
131             pbend = Bp [j+1] ;
132 
133             if (pa == paend            /* nnz (A (:,i)) == 0 */
134             || pb == pbend             /* nnz (B (:,j)) == 0 */
135             || Ai [paend-1] < Bi [pb]  /* max(find(A(:,i)))<min(find(B(:,j))) */
136             || Ai [pa] > Bi [pbend-1]) /* min(find(A(:,i)))>max(find(B(:,j))) */
137             {
138                 Cx [i+j*n] = 0 ;        /* no work to do */
139                 if (C_is_complex)
140                 {
141                     Cz [i+j*n] = 0 ;
142                 }
143                 continue ;
144             }
145             cx = 0 ;
146             cz = 0 ;
147             while (pa < paend && pb < pbend)
148             {
149                 /* The dot product looks like the merge in ssmergesort, except*/
150                 /* no "clean-up" phase is need when one list is exhausted. */
151                 ai = Ai [pa] ;
152                 bi = Bi [pb] ;
153                 if (ai == bi)
154                 {
155                     /* c += A (ai,i) * B (ai,j), and "consume" both entries */
156                     if (!C_is_complex)
157                     {
158                         cx += Ax [pa] * Bx [pb] ;
159                     }
160                     else
161                     {
162                         /* complex case */
163                         ax = Ax [pa] ;
164                         bx = Bx [pb] ;
165                         az = Az ? (ac ? (-Az [pa]) : Az [pa]) : 0.0 ;
166                         bz = Bz ? (bc ? (-Bz [pb]) : Bz [pb]) : 0.0 ;
167                         cx += ax * bx - az * bz ;
168                         cz += az * bx + ax * bz ;
169                     }
170                     pa++ ;
171                     pb++ ;
172                 }
173                 else if (ai < bi)
174                 {
175                     /* consume A(ai,i) and discard it, since B(ai,j) is zero */
176                     pa++ ;
177                 }
178                 else
179                 {
180                     /* consume B(bi,j) and discard it, since A(ai,i) is zero */
181                     pb++ ;
182                 }
183             }
184             Cx [i+j*n] = cx ;
185             if (C_is_complex)
186             {
187                 Cz [i+j*n] = cz ;
188             }
189         }
190 
191         /* count the number of nonzeros in C(:,j) */
192         for (i = 0 ; i < n ; i++)
193         {
194             /* This could be done above, except for the gcc compiler bug when
195                cx is an 80-bit nonzero in register above, but becomes zero here
196                when stored into memory.  We need the latter, to correctly handle
197                the case when cx underflows to zero in 64-bit floating-point.
198                Do not attempt to "optimize" this code by doing this test above,
199                unless the gcc compiler bug is fixed (as of gcc version 4.1.0).
200              */
201             if (Cx [i+j*n] != 0 || (C_is_complex && Cz [i+j*n] != 0))
202             {
203                 cnz++ ;
204             }
205         }
206     }
207 
208     /* ---------------------------------------------------------------------- */
209     /* convert C to real if the imaginary part is all zero */
210     /* ---------------------------------------------------------------------- */
211 
212     if (C_is_complex)
213     {
214         zallzero = 1 ;
215         for (p = 0 ; zallzero && p < cnzmax ; p++)
216         {
217             if (Cz [p] != 0)
218             {
219                 zallzero = 0 ;
220             }
221         }
222         if (zallzero)
223         {
224             /* the imaginary part of C is all zero */
225             C_is_complex = 0 ;
226             mxFree (Cz) ;
227             Cz = NULL ;
228         }
229     }
230 
231     /* ---------------------------------------------------------------------- */
232     /* allocate integer part of C but do not initialize it */
233     /* ---------------------------------------------------------------------- */
234 
235     cnz = MAX (cnz, 1) ;
236     C = mxCreateSparse (0, 0, 0, C_is_complex ? mxCOMPLEX : mxREAL) ;
237     MXFREE (mxGetJc (C)) ;
238     MXFREE (mxGetIr (C)) ;
239     MXFREE (mxGetPr (C)) ;
240     MXFREE (mxGetPi (C)) ;
241     Cp = mxMalloc ((k + 1) * sizeof (Int)) ;
242     Ci = mxMalloc (MAX (cnz,1) * sizeof (Int)) ;
243     mxSetJc (C, (mwIndex *) Cp) ;
244     mxSetIr (C, (mwIndex *) Ci) ;
245     mxSetM (C, n) ;
246     mxSetN (C, k) ;
247 
248     /* ---------------------------------------------------------------------- */
249     /* C = sparse (C).  Note that this is done in-place. */
250     /* ---------------------------------------------------------------------- */
251 
252     p = 0 ;
253     for (j = 0 ; j < k ; j++)
254     {
255         Cp [j] = p ;
256         for (i = 0 ; i < n ; i++)
257         {
258             cx = Cx [i+j*n] ;
259             cz = (C_is_complex ? Cz [i+j*n] : 0) ;
260             if (cx != 0 || cz != 0)
261             {
262                 Ci [p] = i ;
263                 Cx [p] = cx ;
264                 if (C_is_complex) Cz [p] = (cc ? (-cz) : cz) ;
265                 p++ ;
266             }
267         }
268     }
269     Cp [k] = p ;
270 
271     /* ---------------------------------------------------------------------- */
272     /* reduce the size of Cx and Cz and return result */
273     /* ---------------------------------------------------------------------- */
274 
275     if (cnz < cnzmax)
276     {
277         Cx = mxRealloc (Cx, cnz * sizeof (double)) ;
278         if (C_is_complex)
279         {
280             Cz = mxRealloc (Cz, cnz * sizeof (double)) ;
281         }
282     }
283 
284     mxSetNzmax (C, cnz) ;
285     mxSetPr (C, Cx) ;
286     if (C_is_complex)
287     {
288         mxSetPi (C, Cz) ;
289     }
290     return (C) ;
291 }
292