1 /****************************************************************/
2 /* Parallel Combinatorial BLAS Library (for Graph Computations) */
3 /* version 1.6 -------------------------------------------------*/
4 /* date: 6/15/2017 ---------------------------------------------*/
5 /* authors: Ariful Azad, Aydin Buluc --------------------------*/
6 /****************************************************************/
7 /*
8 Copyright (c) 2010-2017, The Regents of the University of California
9
10 Permission is hereby granted, free of charge, to any person obtaining a copy
11 of this software and associated documentation files (the "Software"), to deal
12 in the Software without restriction, including without limitation the rights
13 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14 copies of the Software, and to permit persons to whom the Software is
15 furnished to do so, subject to the following conditions:
16
17 The above copyright notice and this permission notice shall be included in
18 all copies or substantial portions of the Software.
19
20 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
26 THE SOFTWARE.
27 */
28
29
30 #ifndef _FRIENDS_H_
31 #define _FRIENDS_H_
32
33 #include <iostream>
34 #include "SpMat.h" // Best to include the base class first
35 #include "SpHelper.h"
36 #include "StackEntry.h"
37 #include "Isect.h"
38 #include "Deleter.h"
39 #include "SpImpl.h"
40 #include "SpParHelper.h"
41 #include "Compare.h"
42 #include "CombBLAS.h"
43 #include "PreAllocatedSPA.h"
44
45 namespace combblas {
46
47 template <class IU, class NU>
48 class SpTuples;
49
50 template <class IU, class NU>
51 class SpDCCols;
52
53 template <class IU, class NU>
54 class Dcsc;
55
56 /*************************************************************************************************/
57 /**************************** SHARED ADDRESS SPACE FRIEND FUNCTIONS ******************************/
58 /****************************** MULTITHREADED LOGIC ALSO GOES HERE *******************************/
59 /*************************************************************************************************/
60
61
62 //! SpMV with dense vector
63 template <typename SR, typename IU, typename NU, typename RHS, typename LHS>
dcsc_gespmv(const SpDCCols<IU,NU> & A,const RHS * x,LHS * y)64 void dcsc_gespmv (const SpDCCols<IU, NU> & A, const RHS * x, LHS * y)
65 {
66 if(A.nnz > 0)
67 {
68 for(IU j =0; j<A.dcsc->nzc; ++j) // for all nonzero columns
69 {
70 IU colid = A.dcsc->jc[j];
71 for(IU i = A.dcsc->cp[j]; i< A.dcsc->cp[j+1]; ++i)
72 {
73 IU rowid = A.dcsc->ir[i];
74 SR::axpy(A.dcsc->numx[i], x[colid], y[rowid]);
75 }
76 }
77 }
78 }
79
80 //! SpMV with dense vector (multithreaded version)
81 template <typename SR, typename IU, typename NU, typename RHS, typename LHS>
dcsc_gespmv_threaded(const SpDCCols<IU,NU> & A,const RHS * x,LHS * y)82 void dcsc_gespmv_threaded (const SpDCCols<IU, NU> & A, const RHS * x, LHS * y)
83 {
84 if(A.nnz > 0)
85 {
86 int nthreads=1;
87 #ifdef _OPENMP
88 #pragma omp parallel
89 {
90 nthreads = omp_get_num_threads();
91 }
92 #endif
93
94 IU nlocrows = A.getnrow();
95 LHS ** tomerge = SpHelper::allocate2D<LHS>(nthreads, nlocrows);
96 auto id = SR::id();
97
98 for(int i=0; i<nthreads; ++i)
99 {
100 std::fill_n(tomerge[i], nlocrows, id);
101 }
102
103 #pragma omp parallel for
104 for(IU j =0; j<A.dcsc->nzc; ++j) // for all nonzero columns
105 {
106 int curthread = 1;
107 #ifdef _OPENMP
108 curthread = omp_get_thread_num();
109 #endif
110
111 LHS * loc2merge = tomerge[curthread];
112
113 IU colid = A.dcsc->jc[j];
114 for(IU i = A.dcsc->cp[j]; i< A.dcsc->cp[j+1]; ++i)
115 {
116 IU rowid = A.dcsc->ir[i];
117 SR::axpy(A.dcsc->numx[i], x[colid], loc2merge[rowid]);
118 }
119 }
120
121 #pragma omp parallel for
122 for(IU j=0; j < nlocrows; ++j)
123 {
124 for(int i=0; i< nthreads; ++i)
125 {
126 y[j] = SR::add(y[j], tomerge[i][j]);
127 }
128 }
129 SpHelper::deallocate2D(tomerge, nthreads);
130 }
131 }
132
133
134 /**
135 * Multithreaded SpMV with sparse vector
136 * the assembly of outgoing buffers sendindbuf/sendnumbuf are done here
137 */
138 template <typename SR, typename IU, typename NUM, typename DER, typename IVT, typename OVT>
generic_gespmv_threaded(const SpMat<IU,NUM,DER> & A,const int32_t * indx,const IVT * numx,int32_t nnzx,int32_t * & sendindbuf,OVT * & sendnumbuf,int * & sdispls,int p_c,PreAllocatedSPA<OVT> & SPA)139 int generic_gespmv_threaded (const SpMat<IU,NUM,DER> & A, const int32_t * indx, const IVT * numx, int32_t nnzx,
140 int32_t * & sendindbuf, OVT * & sendnumbuf, int * & sdispls, int p_c, PreAllocatedSPA<OVT> & SPA)
141 {
142 // FACTS: Split boundaries (for multithreaded execution) are independent of recipient boundaries
143 // Two splits might create output to the same recipient (needs to be merged)
144 // However, each split's output is distinct (no duplicate elimination is needed after merge)
145
146 sdispls = new int[p_c](); // initialize to zero (as all indy might be empty)
147 if(A.getnnz() > 0 && nnzx > 0)
148 {
149 int splits = A.getnsplit();
150 if(splits > 0)
151 {
152 int32_t nlocrows = static_cast<int32_t>(A.getnrow());
153 int32_t perpiece = nlocrows / splits;
154 std::vector< std::vector< int32_t > > indy(splits);
155 std::vector< std::vector< OVT > > numy(splits);
156
157 // Parallelize with OpenMP
158 #ifdef _OPENMP
159 #pragma omp parallel for // num_threads(6)
160 #endif
161 for(int i=0; i<splits; ++i)
162 {
163 if(SPA.initialized)
164 {
165 if(i != splits-1)
166 SpMXSpV_ForThreading<SR>(*(A.GetInternal(i)), perpiece, indx, numx, nnzx, indy[i], numy[i], i*perpiece, SPA.V_localy[i], SPA.V_isthere[i], SPA.V_inds[i]);
167 else
168 SpMXSpV_ForThreading<SR>(*(A.GetInternal(i)), nlocrows - perpiece*i, indx, numx, nnzx, indy[i], numy[i], i*perpiece, SPA.V_localy[i], SPA.V_isthere[i], SPA.V_inds[i]);
169 }
170 else
171 {
172 if(i != splits-1)
173 SpMXSpV_ForThreading<SR>(*(A.GetInternal(i)), perpiece, indx, numx, nnzx, indy[i], numy[i], i*perpiece);
174 else
175 SpMXSpV_ForThreading<SR>(*(A.GetInternal(i)), nlocrows - perpiece*i, indx, numx, nnzx, indy[i], numy[i], i*perpiece);
176 }
177 }
178
179 std::vector<int> accum(splits+1, 0);
180 for(int i=0; i<splits; ++i)
181 accum[i+1] = accum[i] + indy[i].size();
182
183 sendindbuf = new int32_t[accum[splits]];
184 sendnumbuf = new OVT[accum[splits]];
185 int32_t perproc = nlocrows / p_c;
186 int32_t last_rec = p_c-1;
187
188 // keep recipients of last entries in each split (-1 for an empty split)
189 // so that we can delete indy[] and numy[] contents as soon as they are processed
190 std::vector<int32_t> end_recs(splits);
191 for(int i=0; i<splits; ++i)
192 {
193 if(indy[i].empty())
194 end_recs[i] = -1;
195 else
196 end_recs[i] = std::min(indy[i].back() / perproc, last_rec);
197 }
198 #ifdef _OPENMP
199 #pragma omp parallel for // num_threads(6)
200 #endif
201 for(int i=0; i<splits; ++i)
202 {
203 if(!indy[i].empty()) // guarantee that .begin() and .end() are not null
204 {
205 // FACT: Data is sorted, so if the recipient of begin is the same as the owner of end,
206 // then the whole data is sent to the same processor
207 int32_t beg_rec = std::min( indy[i].front() / perproc, last_rec);
208
209 // We have to test the previous "split", to see if we are marking a "recipient head"
210 // set displacement markers for the completed (previous) buffers only
211 if(i != 0)
212 {
213 int k = i-1;
214 while (k >= 0 && end_recs[k] == -1) k--; // loop backwards until seeing an non-empty split
215 if(k >= 0) // we found a non-empty split
216 {
217 std::fill(sdispls+end_recs[k]+1, sdispls+beg_rec+1, accum[i]); // last entry to be set is sdispls[beg_rec]
218 }
219 // else fill sdispls[1...beg_rec] with zero (already done)
220 }
221 // else set sdispls[0] to zero (already done)
222 if(beg_rec == end_recs[i]) // fast case
223 {
224 std::transform(indy[i].begin(), indy[i].end(), indy[i].begin(), std::bind2nd(std::minus<int32_t>(), perproc*beg_rec));
225 std::copy(indy[i].begin(), indy[i].end(), sendindbuf+accum[i]);
226 std::copy(numy[i].begin(), numy[i].end(), sendnumbuf+accum[i]);
227 }
228 else // slow case
229 {
230 // FACT: No matter how many splits or threads, there will be only one "recipient head"
231 // Therefore there are no race conditions for marking send displacements (sdispls)
232 int end = indy[i].size();
233 for(int cur=0; cur< end; ++cur)
234 {
235 int32_t cur_rec = std::min( indy[i][cur] / perproc, last_rec);
236 while(beg_rec != cur_rec)
237 {
238 sdispls[++beg_rec] = accum[i] + cur; // first entry to be set is sdispls[beg_rec+1]
239 }
240 sendindbuf[ accum[i] + cur ] = indy[i][cur] - perproc*beg_rec; // convert to receiver's local index
241 sendnumbuf[ accum[i] + cur ] = numy[i][cur];
242 }
243 }
244 std::vector<int32_t>().swap(indy[i]);
245 std::vector<OVT>().swap(numy[i]);
246 bool lastnonzero = true; // am I the last nonzero split?
247 for(int k=i+1; k < splits; ++k)
248 {
249 if(end_recs[k] != -1)
250 lastnonzero = false;
251 }
252 if(lastnonzero)
253 std::fill(sdispls+end_recs[i]+1, sdispls+p_c, accum[i+1]);
254 } // end_if(!indy[i].empty)
255 } // end parallel for
256 return accum[splits];
257 }
258 else
259 {
260 std::cout << "Something is wrong, splits should be nonzero for multithreaded execution" << std::endl;
261 return 0;
262 }
263 }
264 else
265 {
266 sendindbuf = NULL;
267 sendnumbuf = NULL;
268 return 0;
269 }
270 }
271
272
273 /**
274 * Multithreaded SpMV with sparse vector and preset buffers
275 * the assembly of outgoing buffers sendindbuf/sendnumbuf are done here
276 * IVT: input vector numerical type
277 * OVT: output vector numerical type
278 */
279 template <typename SR, typename IU, typename NUM, typename DER, typename IVT, typename OVT>
generic_gespmv_threaded_setbuffers(const SpMat<IU,NUM,DER> & A,const int32_t * indx,const IVT * numx,int32_t nnzx,int32_t * sendindbuf,OVT * sendnumbuf,int * cnts,int * dspls,int p_c)280 void generic_gespmv_threaded_setbuffers (const SpMat<IU,NUM,DER> & A, const int32_t * indx, const IVT * numx, int32_t nnzx,
281 int32_t * sendindbuf, OVT * sendnumbuf, int * cnts, int * dspls, int p_c)
282 {
283 if(A.getnnz() > 0 && nnzx > 0)
284 {
285 int splits = A.getnsplit();
286 if(splits > 0)
287 {
288 std::vector< std::vector<int32_t> > indy(splits);
289 std::vector< std::vector< OVT > > numy(splits);
290 int32_t nlocrows = static_cast<int32_t>(A.getnrow());
291 int32_t perpiece = nlocrows / splits;
292
293 #ifdef _OPENMP
294 #pragma omp parallel for
295 #endif
296 for(int i=0; i<splits; ++i)
297 {
298 if(i != splits-1)
299 SpMXSpV_ForThreading<SR>(*(A.GetInternal(i)), perpiece, indx, numx, nnzx, indy[i], numy[i], i*perpiece);
300 else
301 SpMXSpV_ForThreading<SR>(*(A.GetInternal(i)), nlocrows - perpiece*i, indx, numx, nnzx, indy[i], numy[i], i*perpiece);
302 }
303
304 int32_t perproc = nlocrows / p_c;
305 int32_t last_rec = p_c-1;
306
307 // keep recipients of last entries in each split (-1 for an empty split)
308 // so that we can delete indy[] and numy[] contents as soon as they are processed
309 std::vector<int32_t> end_recs(splits);
310 for(int i=0; i<splits; ++i)
311 {
312 if(indy[i].empty())
313 end_recs[i] = -1;
314 else
315 end_recs[i] = std::min(indy[i].back() / perproc, last_rec);
316 }
317
318 int ** loc_rec_cnts = new int *[splits];
319 #ifdef _OPENMP
320 #pragma omp parallel for
321 #endif
322 for(int i=0; i<splits; ++i)
323 {
324 loc_rec_cnts[i] = new int[p_c](); // thread-local recipient data
325 if(!indy[i].empty()) // guarantee that .begin() and .end() are not null
326 {
327 int32_t cur_rec = std::min( indy[i].front() / perproc, last_rec);
328 int32_t lastdata = (cur_rec+1) * perproc; // one past last entry that goes to this current recipient
329 for(typename std::vector<int32_t>::iterator it = indy[i].begin(); it != indy[i].end(); ++it)
330 {
331
332 if( ( (*it) >= lastdata ) && cur_rec != last_rec )
333 {
334 cur_rec = std::min( (*it) / perproc, last_rec);
335 lastdata = (cur_rec+1) * perproc;
336 }
337 ++loc_rec_cnts[i][cur_rec];
338 }
339 }
340 }
341 #ifdef _OPENMP
342 #pragma omp parallel for
343 #endif
344 for(int i=0; i<splits; ++i)
345 {
346 if(!indy[i].empty()) // guarantee that .begin() and .end() are not null
347 {
348 // FACT: Data is sorted, so if the recipient of begin is the same as the owner of end,
349 // then the whole data is sent to the same processor
350 int32_t beg_rec = std::min( indy[i].front() / perproc, last_rec);
351 int32_t alreadysent = 0; // already sent per recipient
352 for(int before = i-1; before >= 0; before--)
353 alreadysent += loc_rec_cnts[before][beg_rec];
354
355 if(beg_rec == end_recs[i]) // fast case
356 {
357 std::transform(indy[i].begin(), indy[i].end(), indy[i].begin(), std::bind2nd(std::minus<int32_t>(), perproc*beg_rec));
358 std::copy(indy[i].begin(), indy[i].end(), sendindbuf + dspls[beg_rec] + alreadysent);
359 std::copy(numy[i].begin(), numy[i].end(), sendnumbuf + dspls[beg_rec] + alreadysent);
360 }
361 else // slow case
362 {
363 int32_t cur_rec = beg_rec;
364 int32_t lastdata = (cur_rec+1) * perproc; // one past last entry that goes to this current recipient
365 for(typename std::vector<int32_t>::iterator it = indy[i].begin(); it != indy[i].end(); ++it)
366 {
367 if( ( (*it) >= lastdata ) && cur_rec != last_rec )
368 {
369 cur_rec = std::min( (*it) / perproc, last_rec);
370 lastdata = (cur_rec+1) * perproc;
371
372 // if this split switches to a new recipient after sending some data
373 // then it's sure that no data has been sent to that recipient yet
374 alreadysent = 0;
375 }
376 sendindbuf[ dspls[cur_rec] + alreadysent ] = (*it) - perproc*cur_rec; // convert to receiver's local index
377 sendnumbuf[ dspls[cur_rec] + (alreadysent++) ] = *(numy[i].begin() + (it-indy[i].begin()));
378 }
379 }
380 }
381 }
382 // Deallocated rec counts serially once all threads complete
383 for(int i=0; i< splits; ++i)
384 {
385 for(int j=0; j< p_c; ++j)
386 cnts[j] += loc_rec_cnts[i][j];
387 delete [] loc_rec_cnts[i];
388 }
389 delete [] loc_rec_cnts;
390 }
391 else
392 {
393 std::cout << "Something is wrong, splits should be nonzero for multithreaded execution" << std::endl;
394 }
395 }
396 }
397
398 //! SpMV with sparse vector
399 //! MIND: Matrix index type
400 //! VIND: Vector index type (optimized: int32_t, general: int64_t)
401 template <typename SR, typename MIND, typename VIND, typename DER, typename NUM, typename IVT, typename OVT>
generic_gespmv(const SpMat<MIND,NUM,DER> & A,const VIND * indx,const IVT * numx,VIND nnzx,std::vector<VIND> & indy,std::vector<OVT> & numy,PreAllocatedSPA<OVT> & SPA)402 void generic_gespmv (const SpMat<MIND,NUM,DER> & A, const VIND * indx, const IVT * numx, VIND nnzx, std::vector<VIND> & indy, std::vector<OVT> & numy, PreAllocatedSPA<OVT> & SPA)
403 {
404 if(A.getnnz() > 0 && nnzx > 0)
405 {
406 if(A.getnsplit() > 0)
407 {
408 std::cout << "Call dcsc_gespmv_threaded instead" << std::endl;
409 }
410 else
411 {
412 SpMXSpV<SR>(*(A.GetInternal()), (VIND) A.getnrow(), indx, numx, nnzx, indy, numy, SPA);
413 }
414 }
415 }
416
417 /** SpMV with sparse vector
418 * @param[in] indexisvalue is only used for BFS-like computations, if true then we can call the optimized version that skips SPA
419 */
420 template <typename SR, typename IU, typename DER, typename NUM, typename IVT, typename OVT>
generic_gespmv(const SpMat<IU,NUM,DER> & A,const int32_t * indx,const IVT * numx,int32_t nnzx,int32_t * indy,OVT * numy,int * cnts,int * dspls,int p_c,bool indexisvalue)421 void generic_gespmv (const SpMat<IU,NUM,DER> & A, const int32_t * indx, const IVT * numx, int32_t nnzx,
422 int32_t * indy, OVT * numy, int * cnts, int * dspls, int p_c, bool indexisvalue)
423 {
424 if(A.getnnz() > 0 && nnzx > 0)
425 {
426 if(A.getnsplit() > 0)
427 {
428 SpParHelper::Print("Call dcsc_gespmv_threaded instead\n");
429 }
430 else
431 {
432 SpMXSpV<SR>(*(A.GetInternal()), (int32_t) A.getnrow(), indx, numx, nnzx, indy, numy, cnts, dspls, p_c);
433 }
434 }
435 }
436
437
// Destructively converts a boolean matrix A from a single Dcsc into `numsplits`
// row-wise splits (A.dcscarr), so that the threaded kernels can operate on
// disjoint row ranges. Row r is assigned to split min(r/perpiece, numsplits-1);
// row indices are rebased to be local to their split. A.dcsc is deleted here.
template<typename IU>
void BooleanRowSplit(SpDCCols<IU, bool> & A, int numsplits)
{
	A.splits = numsplits;
	IU perpiece = A.m / A.splits;	// nominal rows per split (last split takes the remainder)
	std::vector<IU> prevcolids(A.splits, -1);	// previous column id's are set to -1
	std::vector<IU> nzcs(A.splits, 0);		// nonzero-column count per split
	std::vector<IU> nnzs(A.splits, 0);		// nonzero count per split
	std::vector < std::vector < std::pair<IU,IU> > > colrowpairs(A.splits);	// (col, local row) pairs per split
	if(A.nnz > 0 && A.dcsc != NULL)
	{
		// distribute every nonzero to its owner split; column-major traversal means
		// each split sees its columns in nondecreasing order, so comparing against
		// prevcolids is enough to count distinct nonzero columns
		for(IU i=0; i< A.dcsc->nzc; ++i)
		{
			for(IU j = A.dcsc->cp[i]; j< A.dcsc->cp[i+1]; ++j)
			{
				IU colid = A.dcsc->jc[i];
				IU rowid = A.dcsc->ir[j];
				IU owner = std::min(rowid / perpiece, static_cast<IU>(A.splits-1));
				colrowpairs[owner].push_back(std::make_pair(colid, rowid - owner*perpiece));

				if(prevcolids[owner] != colid)
				{
					prevcolids[owner] = colid;
					++nzcs[owner];
				}
				++nnzs[owner];
			}
		}
	}
	delete A.dcsc;	// reclaim memory (A now lives only in colrowpairs)
	//copy(nzcs.begin(), nzcs.end(), ostream_iterator<IU>(cout," " )); cout << endl;
	//copy(nnzs.begin(), nnzs.end(), ostream_iterator<IU>(cout," " )); cout << endl;
	A.dcscarr = new Dcsc<IU,bool>*[A.splits];

	// To be parallelized with OpenMP
	// NOTE(review): colrowpairs[i][0] below assumes every split receives at least
	// one nonzero; an empty split would index an empty vector (UB) — confirm
	// callers guarantee nnz >= numsplits worth of populated splits.
	for(int i=0; i< A.splits; ++i)
	{
		sort(colrowpairs[i].begin(), colrowpairs[i].end());	// sort w.r.t. columns
		A.dcscarr[i] = new Dcsc<IU,bool>(nnzs[i],nzcs[i]);
		std::fill(A.dcscarr[i]->numx, A.dcscarr[i]->numx+nnzs[i], static_cast<bool>(1));	// boolean matrix: all values are true
		IU curnzc = 0;		// number of nonzero columns constructed so far
		IU cindex = colrowpairs[i][0].first;
		IU rindex = colrowpairs[i][0].second;

		// seed the Dcsc with the first (column, row) pair
		A.dcscarr[i]->ir[0] = rindex;
		A.dcscarr[i]->jc[curnzc] = cindex;
		A.dcscarr[i]->cp[curnzc++] = 0;

		// append remaining pairs; open a new column whenever the column id changes
		for(IU j=1; j<nnzs[i]; ++j)
		{
			cindex = colrowpairs[i][j].first;
			rindex = colrowpairs[i][j].second;

			A.dcscarr[i]->ir[j] = rindex;
			if(cindex != A.dcscarr[i]->jc[curnzc-1])
			{
				A.dcscarr[i]->jc[curnzc] = cindex;
				A.dcscarr[i]->cp[curnzc++] = j;
			}
		}
		A.dcscarr[i]->cp[curnzc] = nnzs[i];	// close the last column
	}
}
501
502
503 /**
504 * SpTuples(A*B') (Using OuterProduct Algorithm)
505 * Returns the tuples for efficient merging later
506 * Support mixed precision multiplication
507 * The multiplication is on the specified semiring (passed as parameter)
508 */
509 template<class SR, class NUO, class IU, class NU1, class NU2>
510 SpTuples<IU, NUO> * Tuples_AnXBt
511 (const SpDCCols<IU, NU1> & A,
512 const SpDCCols<IU, NU2> & B,
513 bool clearA = false, bool clearB = false)
514 {
515 IU mdim = A.m;
516 IU ndim = B.m; // B is already transposed
517
518 if(A.isZero() || B.isZero())
519 {
520 if(clearA) delete const_cast<SpDCCols<IU, NU1> *>(&A);
521 if(clearB) delete const_cast<SpDCCols<IU, NU2> *>(&B);
522 return new SpTuples< IU, NUO >(0, mdim, ndim); // just return an empty matrix
523 }
524 Isect<IU> *isect1, *isect2, *itr1, *itr2, *cols, *rows;
525 SpHelper::SpIntersect(*(A.dcsc), *(B.dcsc), cols, rows, isect1, isect2, itr1, itr2);
526
527 IU kisect = static_cast<IU>(itr1-isect1); // size of the intersection ((itr1-isect1) == (itr2-isect2))
528 if(kisect == 0)
529 {
530 if(clearA) delete const_cast<SpDCCols<IU, NU1> *>(&A);
531 if(clearB) delete const_cast<SpDCCols<IU, NU2> *>(&B);
532 DeleteAll(isect1, isect2, cols, rows);
533 return new SpTuples< IU, NUO >(0, mdim, ndim);
534 }
535
536 StackEntry< NUO, std::pair<IU,IU> > * multstack;
537
538 IU cnz = SpHelper::SpCartesian< SR > (*(A.dcsc), *(B.dcsc), kisect, isect1, isect2, multstack);
539 DeleteAll(isect1, isect2, cols, rows);
540
541 if(clearA) delete const_cast<SpDCCols<IU, NU1> *>(&A);
542 if(clearB) delete const_cast<SpDCCols<IU, NU2> *>(&B);
543 return new SpTuples<IU, NUO> (cnz, mdim, ndim, multstack);
544 }
545
546 /**
547 * SpTuples(A*B) (Using ColByCol Algorithm)
548 * Returns the tuples for efficient merging later
549 * Support mixed precision multiplication
550 * The multiplication is on the specified semiring (passed as parameter)
551 */
552 template<class SR, class NUO, class IU, class NU1, class NU2>
553 SpTuples<IU, NUO> * Tuples_AnXBn
554 (const SpDCCols<IU, NU1> & A,
555 const SpDCCols<IU, NU2> & B,
556 bool clearA = false, bool clearB = false)
557 {
558 IU mdim = A.m;
559 IU ndim = B.n;
560 if(A.isZero() || B.isZero())
561 {
562 return new SpTuples<IU, NUO>(0, mdim, ndim);
563 }
564 StackEntry< NUO, std::pair<IU,IU> > * multstack;
565 IU cnz = SpHelper::SpColByCol< SR > (*(A.dcsc), *(B.dcsc), A.n, multstack);
566
567 if(clearA)
568 delete const_cast<SpDCCols<IU, NU1> *>(&A);
569 if(clearB)
570 delete const_cast<SpDCCols<IU, NU2> *>(&B);
571
572 return new SpTuples<IU, NUO> (cnz, mdim, ndim, multstack);
573 }
574
575
576 template<class SR, class NUO, class IU, class NU1, class NU2>
577 SpTuples<IU, NUO> * Tuples_AtXBt
578 (const SpDCCols<IU, NU1> & A,
579 const SpDCCols<IU, NU2> & B,
580 bool clearA = false, bool clearB = false)
581 {
582 IU mdim = A.n;
583 IU ndim = B.m;
584 std::cout << "Tuples_AtXBt function has not been implemented yet !" << std::endl;
585
586 return new SpTuples<IU, NUO> (0, mdim, ndim);
587 }
588
589 template<class SR, class NUO, class IU, class NU1, class NU2>
590 SpTuples<IU, NUO> * Tuples_AtXBn
591 (const SpDCCols<IU, NU1> & A,
592 const SpDCCols<IU, NU2> & B,
593 bool clearA = false, bool clearB = false)
594 {
595 IU mdim = A.n;
596 IU ndim = B.n;
597 std::cout << "Tuples_AtXBn function has not been implemented yet !" << std::endl;
598
599 return new SpTuples<IU, NUO> (0, mdim, ndim);
600 }
601
// Performs a balanced merge of the array of SpTuples
// Assumes the input parameters are already column sorted
// A k-way merge driven by a min-heap keyed on (column, row); duplicate
// (row, col) coordinates are combined with SR::add.
// @param mstar/nstar only used as result dimensions when ArrSpTups is empty;
//        otherwise overwritten with the dimensions of ArrSpTups[0]
// @param delarrs if true, the input SpTuples objects are deleted before returning
template<class SR, class IU, class NU>
SpTuples<IU,NU> MergeAll( const std::vector<SpTuples<IU,NU> *> & ArrSpTups, IU mstar = 0, IU nstar = 0, bool delarrs = false )
{
	int hsize =  ArrSpTups.size();
	if(hsize == 0)
	{
		return SpTuples<IU,NU>(0, mstar,nstar);
	}
	else
	{
		mstar = ArrSpTups[0]->m;
		nstar = ArrSpTups[0]->n;
	}
	// all inputs must agree on dimensions
	for(int i=1; i< hsize; ++i)
	{
		if((mstar != ArrSpTups[i]->m) || nstar != ArrSpTups[i]->n)
		{
			std::cerr << "Dimensions do not match on MergeAll()" << std::endl;
			return SpTuples<IU,NU>(0,0,0);
		}
	}
	if(hsize > 1)
	{
		// min-heap of one cursor per input list; std::not2 flips the max-heap
		// convention of make_heap into a min-heap on (col, row)
		// NOTE(review): std::not2 is deprecated in C++17 (removed in C++20);
		// also heap[i] below reads tuples[0], assuming no input list is empty.
		ColLexiCompare<IU,int> heapcomp;
		std::tuple<IU, IU, int> * heap = new std::tuple<IU, IU, int> [hsize];	// (rowindex, colindex, source-id)
		IU * curptr = new IU[hsize];	// per-source read cursor
		std::fill_n(curptr, hsize, static_cast<IU>(0));
		IU estnnz = 0;	// upper bound on output size (no duplicates eliminated yet)

		for(int i=0; i< hsize; ++i)
		{
			estnnz += ArrSpTups[i]->getnnz();
			heap[i] = std::make_tuple(std::get<0>(ArrSpTups[i]->tuples[0]), std::get<1>(ArrSpTups[i]->tuples[0]), i);
		}
		std::make_heap(heap, heap+hsize, std::not2(heapcomp));

		std::tuple<IU, IU, NU> * ntuples = new std::tuple<IU,IU,NU>[estnnz];
		IU cnz = 0;	// output nonzeros emitted so far

		// hsize doubles as the live heap size; it shrinks as sources are depleted
		while(hsize > 0)
		{
			std::pop_heap(heap, heap + hsize, std::not2(heapcomp));	// result is stored in heap[hsize-1]
			int source = std::get<2>(heap[hsize-1]);

			// same coordinate as the previously emitted tuple: combine on the semiring
			if( (cnz != 0) &&
				((std::get<0>(ntuples[cnz-1]) == std::get<0>(heap[hsize-1])) && (std::get<1>(ntuples[cnz-1]) == std::get<1>(heap[hsize-1]))) )
			{
				std::get<2>(ntuples[cnz-1]) = SR::add(std::get<2>(ntuples[cnz-1]), ArrSpTups[source]->numvalue(curptr[source]++));
			}
			else
			{
				ntuples[cnz++] = ArrSpTups[source]->tuples[curptr[source]++];
			}

			if(curptr[source] != ArrSpTups[source]->getnnz())	// That array has not been depleted
			{
				// re-insert the advanced cursor for this source
				heap[hsize-1] = std::make_tuple(std::get<0>(ArrSpTups[source]->tuples[curptr[source]]),
								std::get<1>(ArrSpTups[source]->tuples[curptr[source]]), source);
				std::push_heap(heap, heap+hsize, std::not2(heapcomp));
			}
			else
			{
				--hsize;	// source exhausted: shrink the heap
			}
		}
		SpHelper::ShrinkArray(ntuples, cnz);	// trim the overestimate down to cnz
		DeleteAll(heap, curptr);

		if(delarrs)
		{
			for(size_t i=0; i<ArrSpTups.size(); ++i)
				delete ArrSpTups[i];
		}
		return SpTuples<IU,NU> (cnz, mstar, nstar, ntuples);	// takes ownership of ntuples
	}
	else
	{
		// single input: copy it out (and optionally delete the original)
		SpTuples<IU,NU> ret = *ArrSpTups[0];
		if(delarrs)
			delete ArrSpTups[0];
		return ret;
	}
}
687
688 /**
689 * @param[in] exclude if false,
690 * \n then operation is A = A .* B
691 * \n else operation is A = A .* not(B)
692 **/
693 template <typename IU, typename NU1, typename NU2>
EWiseMult(const Dcsc<IU,NU1> & A,const Dcsc<IU,NU2> * B,bool exclude)694 Dcsc<IU, typename promote_trait<NU1,NU2>::T_promote> EWiseMult(const Dcsc<IU,NU1> & A, const Dcsc<IU,NU2> * B, bool exclude)
695 {
696 typedef typename promote_trait<NU1,NU2>::T_promote N_promote;
697 IU estnzc, estnz;
698 if(exclude)
699 {
700 estnzc = A.nzc;
701 estnz = A.nz;
702 }
703 else
704 {
705 estnzc = std::min(A.nzc, B->nzc);
706 estnz = std::min(A.nz, B->nz);
707 }
708
709 Dcsc<IU,N_promote> temp(estnz, estnzc);
710
711 IU curnzc = 0;
712 IU curnz = 0;
713 IU i = 0;
714 IU j = 0;
715 temp.cp[0] = 0;
716
717 if(!exclude) // A = A .* B
718 {
719 while(i< A.nzc && B != NULL && j<B->nzc)
720 {
721 if(A.jc[i] > B->jc[j]) ++j;
722 else if(A.jc[i] < B->jc[j]) ++i;
723 else
724 {
725 IU ii = A.cp[i];
726 IU jj = B->cp[j];
727 IU prevnz = curnz;
728 while (ii < A.cp[i+1] && jj < B->cp[j+1])
729 {
730 if (A.ir[ii] < B->ir[jj]) ++ii;
731 else if (A.ir[ii] > B->ir[jj]) ++jj;
732 else
733 {
734 temp.ir[curnz] = A.ir[ii];
735 temp.numx[curnz++] = A.numx[ii++] * B->numx[jj++];
736 }
737 }
738 if(prevnz < curnz) // at least one nonzero exists in this column
739 {
740 temp.jc[curnzc++] = A.jc[i];
741 temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
742 }
743 ++i;
744 ++j;
745 }
746 }
747 }
748 else // A = A .* not(B)
749 {
750 while(i< A.nzc && B != NULL && j< B->nzc)
751 {
752 if(A.jc[i] > B->jc[j]) ++j;
753 else if(A.jc[i] < B->jc[j])
754 {
755 temp.jc[curnzc++] = A.jc[i++];
756 for(IU k = A.cp[i-1]; k< A.cp[i]; k++)
757 {
758 temp.ir[curnz] = A.ir[k];
759 temp.numx[curnz++] = A.numx[k];
760 }
761 temp.cp[curnzc] = temp.cp[curnzc-1] + (A.cp[i] - A.cp[i-1]);
762 }
763 else
764 {
765 IU ii = A.cp[i];
766 IU jj = B->cp[j];
767 IU prevnz = curnz;
768 while (ii < A.cp[i+1] && jj < B->cp[j+1])
769 {
770 if (A.ir[ii] > B->ir[jj]) ++jj;
771 else if (A.ir[ii] < B->ir[jj])
772 {
773 temp.ir[curnz] = A.ir[ii];
774 temp.numx[curnz++] = A.numx[ii++];
775 }
776 else // eliminate those existing nonzeros
777 {
778 ++ii;
779 ++jj;
780 }
781 }
782 while (ii < A.cp[i+1])
783 {
784 temp.ir[curnz] = A.ir[ii];
785 temp.numx[curnz++] = A.numx[ii++];
786 }
787
788 if(prevnz < curnz) // at least one nonzero exists in this column
789 {
790 temp.jc[curnzc++] = A.jc[i];
791 temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
792 }
793 ++i;
794 ++j;
795 }
796 }
797 while(i< A.nzc)
798 {
799 temp.jc[curnzc++] = A.jc[i++];
800 for(IU k = A.cp[i-1]; k< A.cp[i]; ++k)
801 {
802 temp.ir[curnz] = A.ir[k];
803 temp.numx[curnz++] = A.numx[k];
804 }
805 temp.cp[curnzc] = temp.cp[curnzc-1] + (A.cp[i] - A.cp[i-1]);
806 }
807 }
808
809 temp.Resize(curnzc, curnz);
810 return temp;
811 }
812
// Generalization of EWiseMult: applies __binary_op(a, b) elementwise over the
// intersection of A and B (notB == false), or __binary_op(a, defaultBVal) over
// the entries of A not present in B (notB == true). Same two-pointer merge
// structure as EWiseMult above; result value type is the explicit N_promote.
// @param[in] B may be NULL; then (notB==true) applies __binary_op(a, defaultBVal) to all of A
template <typename N_promote, typename IU, typename NU1, typename NU2, typename _BinaryOperation>
Dcsc<IU, N_promote> EWiseApply(const Dcsc<IU,NU1> & A, const Dcsc<IU,NU2> * B, _BinaryOperation __binary_op, bool notB, const NU2& defaultBVal)
{
	//typedef typename promote_trait<NU1,NU2>::T_promote N_promote;
	// upper bounds for the result; trimmed by Resize at the end
	IU estnzc, estnz;
	if(notB)
	{
		estnzc = A.nzc;
		estnz = A.nz;
	}
	else
	{
		estnzc = std::min(A.nzc, B->nzc);
		estnz = std::min(A.nz, B->nz);
	}

	Dcsc<IU,N_promote> temp(estnz, estnzc);

	IU curnzc = 0;	// result nonzero columns so far
	IU curnz = 0;	// result nonzeros so far
	IU i = 0;	// cursor into A's columns
	IU j = 0;	// cursor into B's columns
	temp.cp[0] = 0;

	if(!notB)	// A = A .* B
	{
		// advance both column cursors; only matching columns can contribute
		while(i< A.nzc && B != NULL && j<B->nzc)
		{
			if(A.jc[i] > B->jc[j])		++j;
			else if(A.jc[i] < B->jc[j])	++i;
			else
			{
				// columns match: intersect the two sorted row lists
				IU ii = A.cp[i];
				IU jj = B->cp[j];
				IU prevnz = curnz;
				while (ii < A.cp[i+1] && jj < B->cp[j+1])
				{
					if (A.ir[ii] < B->ir[jj])	++ii;
					else if (A.ir[ii] > B->ir[jj])	++jj;
					else
					{
						temp.ir[curnz] = A.ir[ii];
						temp.numx[curnz++] = __binary_op(A.numx[ii++], B->numx[jj++]);
					}
				}
				if(prevnz < curnz)	// at least one nonzero exists in this column
				{
					temp.jc[curnzc++] = A.jc[i];
					temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
				}
				++i;
				++j;
			}
		}
	}
	else	// A = A .* not(B)
	{
		while(i< A.nzc && B != NULL && j< B->nzc)
		{
			if(A.jc[i] > B->jc[j])		++j;
			else if(A.jc[i] < B->jc[j])
			{
				// column exists only in A: apply the op against the default B value
				temp.jc[curnzc++] = A.jc[i++];
				for(IU k = A.cp[i-1]; k< A.cp[i]; k++)
				{
					temp.ir[curnz] = A.ir[k];
					temp.numx[curnz++] = __binary_op(A.numx[k], defaultBVal);
				}
				temp.cp[curnzc] = temp.cp[curnzc-1] + (A.cp[i] - A.cp[i-1]);
			}
			else
			{
				// columns match: keep only A's entries whose row is absent in B
				IU ii = A.cp[i];
				IU jj = B->cp[j];
				IU prevnz = curnz;
				while (ii < A.cp[i+1] && jj < B->cp[j+1])
				{
					if (A.ir[ii] > B->ir[jj])	++jj;
					else if (A.ir[ii] < B->ir[jj])
					{
						temp.ir[curnz] = A.ir[ii];
						temp.numx[curnz++] = __binary_op(A.numx[ii++], defaultBVal);
					}
					else	// eliminate those existing nonzeros
					{
						++ii;
						++jj;
					}
				}
				// B's rows exhausted: the rest of A's column survives
				while (ii < A.cp[i+1])
				{
					temp.ir[curnz] = A.ir[ii];
					temp.numx[curnz++] = __binary_op(A.numx[ii++], defaultBVal);
				}

				if(prevnz < curnz)	// at least one nonzero exists in this column
				{
					temp.jc[curnzc++] = A.jc[i];
					temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
				}
				++i;
				++j;
			}
		}
		// B's columns exhausted (or B is NULL): apply the op to A's remaining columns
		while(i< A.nzc)
		{
			temp.jc[curnzc++] = A.jc[i++];
			for(IU k = A.cp[i-1]; k< A.cp[i]; ++k)
			{
				temp.ir[curnz] = A.ir[k];
				temp.numx[curnz++] = __binary_op(A.numx[k], defaultBVal);
			}
			temp.cp[curnzc] = temp.cp[curnzc-1] + (A.cp[i] - A.cp[i-1]);
		}
	}

	temp.Resize(curnzc, curnz);	// shrink to actual size
	return temp;
}
932
933
934 template<typename IU, typename NU1, typename NU2>
EWiseMult(const SpDCCols<IU,NU1> & A,const SpDCCols<IU,NU2> & B,bool exclude)935 SpDCCols<IU, typename promote_trait<NU1,NU2>::T_promote > EWiseMult (const SpDCCols<IU,NU1> & A, const SpDCCols<IU,NU2> & B, bool exclude)
936 {
937 typedef typename promote_trait<NU1,NU2>::T_promote N_promote;
938 assert(A.m == B.m);
939 assert(A.n == B.n);
940
941 Dcsc<IU, N_promote> * tdcsc = NULL;
942 if(A.nnz > 0 && B.nnz > 0)
943 {
944 tdcsc = new Dcsc<IU, N_promote>(EWiseMult(*(A.dcsc), B.dcsc, exclude));
945 return SpDCCols<IU, N_promote> (A.m , A.n, tdcsc);
946 }
947 else if (A.nnz > 0 && exclude) // && B.nnz == 0
948 {
949 tdcsc = new Dcsc<IU, N_promote>(EWiseMult(*(A.dcsc), (const Dcsc<IU,NU2>*)NULL, exclude));
950 return SpDCCols<IU, N_promote> (A.m , A.n, tdcsc);
951 }
952 else
953 {
954 return SpDCCols<IU, N_promote> (A.m , A.n, tdcsc);
955 }
956 }
957
958
959 template<typename N_promote, typename IU, typename NU1, typename NU2, typename _BinaryOperation>
EWiseApply(const SpDCCols<IU,NU1> & A,const SpDCCols<IU,NU2> & B,_BinaryOperation __binary_op,bool notB,const NU2 & defaultBVal)960 SpDCCols<IU, N_promote> EWiseApply (const SpDCCols<IU,NU1> & A, const SpDCCols<IU,NU2> & B, _BinaryOperation __binary_op, bool notB, const NU2& defaultBVal)
961 {
962 //typedef typename promote_trait<NU1,NU2>::T_promote N_promote;
963 assert(A.m == B.m);
964 assert(A.n == B.n);
965
966 Dcsc<IU, N_promote> * tdcsc = NULL;
967 if(A.nnz > 0 && B.nnz > 0)
968 {
969 tdcsc = new Dcsc<IU, N_promote>(EWiseApply<N_promote>(*(A.dcsc), B.dcsc, __binary_op, notB, defaultBVal));
970 return SpDCCols<IU, N_promote> (A.m , A.n, tdcsc);
971 }
972 else if (A.nnz > 0 && notB) // && B.nnz == 0
973 {
974 tdcsc = new Dcsc<IU, N_promote>(EWiseApply<N_promote>(*(A.dcsc), (const Dcsc<IU,NU2>*)NULL, __binary_op, notB, defaultBVal));
975 return SpDCCols<IU, N_promote> (A.m , A.n, tdcsc);
976 }
977 else
978 {
979 return SpDCCols<IU, N_promote> (A.m , A.n, tdcsc);
980 }
981 }
982
983 /**
984 * Implementation based on operator +=
985 * Element wise apply with the following constraints
986 * The operation to be performed is __binary_op
987 * The operation `c = __binary_op(a, b)` is only performed if `do_op(a, b)` returns true
988 * If allowANulls is true, then if A is missing an element that B has, then ANullVal is used
989 * In that case the operation becomes c[i,j] = __binary_op(ANullVal, b[i,j])
990 * If both allowANulls and allowBNulls is false then the function degenerates into intersection
991 */
992 template <typename RETT, typename IU, typename NU1, typename NU2, typename _BinaryOperation, typename _BinaryPredicate>
EWiseApply(const Dcsc<IU,NU1> * Ap,const Dcsc<IU,NU2> * Bp,_BinaryOperation __binary_op,_BinaryPredicate do_op,bool allowANulls,bool allowBNulls,const NU1 & ANullVal,const NU2 & BNullVal,const bool allowIntersect)993 Dcsc<IU, RETT> EWiseApply(const Dcsc<IU,NU1> * Ap, const Dcsc<IU,NU2> * Bp, _BinaryOperation __binary_op, _BinaryPredicate do_op, bool allowANulls, bool allowBNulls, const NU1& ANullVal, const NU2& BNullVal, const bool allowIntersect)
994 {
995 if (Ap == NULL && Bp == NULL)
996 return Dcsc<IU,RETT>(0, 0);
997
998 if (Ap == NULL && Bp != NULL)
999 {
1000 if (!allowANulls)
1001 return Dcsc<IU,RETT>(0, 0);
1002
1003 const Dcsc<IU,NU2> & B = *Bp;
1004 IU estnzc = B.nzc;
1005 IU estnz = B.nz;
1006 Dcsc<IU,RETT> temp(estnz, estnzc);
1007
1008 IU curnzc = 0;
1009 IU curnz = 0;
1010 //IU i = 0;
1011 IU j = 0;
1012 temp.cp[0] = 0;
1013 while(j<B.nzc)
1014 {
1015 // Based on the if statement below which handles A null values.
1016 j++;
1017 IU prevnz = curnz;
1018 temp.jc[curnzc++] = B.jc[j-1];
1019 for(IU k = B.cp[j-1]; k< B.cp[j]; ++k)
1020 {
1021 if (do_op(ANullVal, B.numx[k], true, false))
1022 {
1023 temp.ir[curnz] = B.ir[k];
1024 temp.numx[curnz++] = __binary_op(ANullVal, B.numx[k], true, false);
1025 }
1026 }
1027 //temp.cp[curnzc] = temp.cp[curnzc-1] + (B.cp[j] - B.cp[j-1]);
1028 temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
1029 }
1030 temp.Resize(curnzc, curnz);
1031 return temp;
1032 }
1033
1034 if (Ap != NULL && Bp == NULL)
1035 {
1036 if (!allowBNulls)
1037 return Dcsc<IU,RETT>(0, 0);
1038
1039 const Dcsc<IU,NU1> & A = *Ap;
1040 IU estnzc = A.nzc;
1041 IU estnz = A.nz;
1042 Dcsc<IU,RETT> temp(estnz, estnzc);
1043
1044 IU curnzc = 0;
1045 IU curnz = 0;
1046 IU i = 0;
1047 //IU j = 0;
1048 temp.cp[0] = 0;
1049 while(i< A.nzc)
1050 {
1051 i++;
1052 IU prevnz = curnz;
1053 temp.jc[curnzc++] = A.jc[i-1];
1054 for(IU k = A.cp[i-1]; k< A.cp[i]; k++)
1055 {
1056 if (do_op(A.numx[k], BNullVal, false, true))
1057 {
1058 temp.ir[curnz] = A.ir[k];
1059 temp.numx[curnz++] = __binary_op(A.numx[k], BNullVal, false, true);
1060 }
1061 }
1062 //temp.cp[curnzc] = temp.cp[curnzc-1] + (A.cp[i] - A.cp[i-1]);
1063 temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
1064 }
1065 temp.Resize(curnzc, curnz);
1066 return temp;
1067 }
1068
1069 // both A and B are non-NULL at this point
1070 const Dcsc<IU,NU1> & A = *Ap;
1071 const Dcsc<IU,NU2> & B = *Bp;
1072
1073 IU estnzc = A.nzc + B.nzc;
1074 IU estnz = A.nz + B.nz;
1075 Dcsc<IU,RETT> temp(estnz, estnzc);
1076
1077 IU curnzc = 0;
1078 IU curnz = 0;
1079 IU i = 0;
1080 IU j = 0;
1081 temp.cp[0] = 0;
1082 while(i< A.nzc && j<B.nzc)
1083 {
1084 if(A.jc[i] > B.jc[j])
1085 {
1086 j++;
1087 if (allowANulls)
1088 {
1089 IU prevnz = curnz;
1090 temp.jc[curnzc++] = B.jc[j-1];
1091 for(IU k = B.cp[j-1]; k< B.cp[j]; ++k)
1092 {
1093 if (do_op(ANullVal, B.numx[k], true, false))
1094 {
1095 temp.ir[curnz] = B.ir[k];
1096 temp.numx[curnz++] = __binary_op(ANullVal, B.numx[k], true, false);
1097 }
1098 }
1099 //temp.cp[curnzc] = temp.cp[curnzc-1] + (B.cp[j] - B.cp[j-1]);
1100 temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
1101 }
1102 }
1103 else if(A.jc[i] < B.jc[j])
1104 {
1105 i++;
1106 if (allowBNulls)
1107 {
1108 IU prevnz = curnz;
1109 temp.jc[curnzc++] = A.jc[i-1];
1110 for(IU k = A.cp[i-1]; k< A.cp[i]; k++)
1111 {
1112 if (do_op(A.numx[k], BNullVal, false, true))
1113 {
1114 temp.ir[curnz] = A.ir[k];
1115 temp.numx[curnz++] = __binary_op(A.numx[k], BNullVal, false, true);
1116 }
1117 }
1118 //temp.cp[curnzc] = temp.cp[curnzc-1] + (A.cp[i] - A.cp[i-1]);
1119 temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
1120 }
1121 }
1122 else
1123 {
1124 temp.jc[curnzc++] = A.jc[i];
1125 IU ii = A.cp[i];
1126 IU jj = B.cp[j];
1127 IU prevnz = curnz;
1128 while (ii < A.cp[i+1] && jj < B.cp[j+1])
1129 {
1130 if (A.ir[ii] < B.ir[jj])
1131 {
1132 if (allowBNulls && do_op(A.numx[ii], BNullVal, false, true))
1133 {
1134 temp.ir[curnz] = A.ir[ii];
1135 temp.numx[curnz++] = __binary_op(A.numx[ii++], BNullVal, false, true);
1136 }
1137 else
1138 ii++;
1139 }
1140 else if (A.ir[ii] > B.ir[jj])
1141 {
1142 if (allowANulls && do_op(ANullVal, B.numx[jj], true, false))
1143 {
1144 temp.ir[curnz] = B.ir[jj];
1145 temp.numx[curnz++] = __binary_op(ANullVal, B.numx[jj++], true, false);
1146 }
1147 else
1148 jj++;
1149 }
1150 else
1151 {
1152 if (allowIntersect && do_op(A.numx[ii], B.numx[jj], false, false))
1153 {
1154 temp.ir[curnz] = A.ir[ii];
1155 temp.numx[curnz++] = __binary_op(A.numx[ii++], B.numx[jj++], false, false); // might include zeros
1156 }
1157 else
1158 {
1159 ii++;
1160 jj++;
1161 }
1162 }
1163 }
1164 while (ii < A.cp[i+1])
1165 {
1166 if (allowBNulls && do_op(A.numx[ii], BNullVal, false, true))
1167 {
1168 temp.ir[curnz] = A.ir[ii];
1169 temp.numx[curnz++] = __binary_op(A.numx[ii++], BNullVal, false, true);
1170 }
1171 else
1172 ii++;
1173 }
1174 while (jj < B.cp[j+1])
1175 {
1176 if (allowANulls && do_op(ANullVal, B.numx[jj], true, false))
1177 {
1178 temp.ir[curnz] = B.ir[jj];
1179 temp.numx[curnz++] = __binary_op(ANullVal, B.numx[jj++], true, false);
1180 }
1181 else
1182 jj++;
1183 }
1184 temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
1185 ++i;
1186 ++j;
1187 }
1188 }
1189 while(allowBNulls && i< A.nzc) // remaining A elements after B ran out
1190 {
1191 IU prevnz = curnz;
1192 temp.jc[curnzc++] = A.jc[i++];
1193 for(IU k = A.cp[i-1]; k< A.cp[i]; ++k)
1194 {
1195 if (do_op(A.numx[k], BNullVal, false, true))
1196 {
1197 temp.ir[curnz] = A.ir[k];
1198 temp.numx[curnz++] = __binary_op(A.numx[k], BNullVal, false, true);
1199 }
1200 }
1201 //temp.cp[curnzc] = temp.cp[curnzc-1] + (A.cp[i] - A.cp[i-1]);
1202 temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
1203 }
1204 while(allowANulls && j < B.nzc) // remaining B elements after A ran out
1205 {
1206 IU prevnz = curnz;
1207 temp.jc[curnzc++] = B.jc[j++];
1208 for(IU k = B.cp[j-1]; k< B.cp[j]; ++k)
1209 {
1210 if (do_op(ANullVal, B.numx[k], true, false))
1211 {
1212 temp.ir[curnz] = B.ir[k];
1213 temp.numx[curnz++] = __binary_op(ANullVal, B.numx[k], true, false);
1214 }
1215 }
1216 //temp.cp[curnzc] = temp.cp[curnzc-1] + (B.cp[j] - B.cp[j-1]);
1217 temp.cp[curnzc] = temp.cp[curnzc-1] + curnz-prevnz;
1218 }
1219 temp.Resize(curnzc, curnz);
1220 return temp;
1221 }
1222
1223 template <typename RETT, typename IU, typename NU1, typename NU2, typename _BinaryOperation, typename _BinaryPredicate>
EWiseApply(const SpDCCols<IU,NU1> & A,const SpDCCols<IU,NU2> & B,_BinaryOperation __binary_op,_BinaryPredicate do_op,bool allowANulls,bool allowBNulls,const NU1 & ANullVal,const NU2 & BNullVal,const bool allowIntersect)1224 SpDCCols<IU,RETT> EWiseApply (const SpDCCols<IU,NU1> & A, const SpDCCols<IU,NU2> & B, _BinaryOperation __binary_op, _BinaryPredicate do_op, bool allowANulls, bool allowBNulls, const NU1& ANullVal, const NU2& BNullVal, const bool allowIntersect)
1225 {
1226 assert(A.m == B.m);
1227 assert(A.n == B.n);
1228
1229 Dcsc<IU, RETT> * tdcsc = new Dcsc<IU, RETT>(EWiseApply<RETT>(A.dcsc, B.dcsc, __binary_op, do_op, allowANulls, allowBNulls, ANullVal, BNullVal, allowIntersect));
1230 return SpDCCols<IU, RETT> (A.m , A.n, tdcsc);
1231 }
1232
1233
1234 }
1235
1236 #endif
1237