1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill.  All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *     * Redistributions of source code must retain the above copyright
8  *       notice, this list of conditions and the following disclaimer.
9  *     * Redistributions in binary form must reproduce the above copyright
10  *       notice, this list of conditions and the following disclaimer in the
11  *       documentation and/or other materials provided with the distribution.
12  *     * Neither the name of the NVIDIA CORPORATION nor the
13  *       names of its contributors may be used to endorse or promote products
14  *       derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
29 /**
30  * \file
31  * cub::AgentSpmv implements a stateful abstraction of CUDA thread blocks for participating in device-wide SpMV.
32  */
33 
34 #pragma once
35 
36 #include <iterator>
37 
38 #include "../util_type.cuh"
39 #include "../block/block_reduce.cuh"
40 #include "../block/block_scan.cuh"
41 #include "../block/block_exchange.cuh"
42 #include "../thread/thread_search.cuh"
43 #include "../thread/thread_operators.cuh"
44 #include "../iterator/cache_modified_input_iterator.cuh"
45 #include "../iterator/counting_input_iterator.cuh"
46 #include "../iterator/tex_ref_input_iterator.cuh"
47 #include "../util_namespace.cuh"
48 
49 /// Optional outer namespace(s)
50 CUB_NS_PREFIX
51 
52 /// CUB namespace
53 namespace cub {
54 
55 
56 /******************************************************************************
57  * Tuning policy
58  ******************************************************************************/
59 
/**
 * Parameterizable tuning policy type for AgentSpmv
 */
template <
    int                             _BLOCK_THREADS,                         ///< Threads per thread block
    int                             _ITEMS_PER_THREAD,                      ///< Items per thread (per tile of input)
    CacheLoadModifier               _ROW_OFFSETS_SEARCH_LOAD_MODIFIER,      ///< Cache load modifier for reading CSR row-offsets during the merge-path search
    CacheLoadModifier               _ROW_OFFSETS_LOAD_MODIFIER,             ///< Cache load modifier for reading CSR row-offsets
    CacheLoadModifier               _COLUMN_INDICES_LOAD_MODIFIER,          ///< Cache load modifier for reading CSR column-indices
    CacheLoadModifier               _VALUES_LOAD_MODIFIER,                  ///< Cache load modifier for reading CSR values
    CacheLoadModifier               _VECTOR_VALUES_LOAD_MODIFIER,           ///< Cache load modifier for reading vector values
    bool                            _DIRECT_LOAD_NONZEROS,                  ///< Whether to load nonzeros directly from global memory during sequential merging (vs. pre-staging them through shared memory)
    BlockScanAlgorithm              _SCAN_ALGORITHM>                        ///< The BlockScan algorithm to use
struct AgentSpmvPolicy
{
    /// Threads per thread block
    enum { BLOCK_THREADS = _BLOCK_THREADS };

    /// Items per thread (per tile of input)
    enum { ITEMS_PER_THREAD = _ITEMS_PER_THREAD };

    /// Whether to load nonzeros directly from global memory during sequential merging (vs. pre-staging them through shared memory)
    enum { DIRECT_LOAD_NONZEROS = _DIRECT_LOAD_NONZEROS };

    /// Cache load modifier for reading CSR row-offsets during the merge-path search
    static const CacheLoadModifier  ROW_OFFSETS_SEARCH_LOAD_MODIFIER    = _ROW_OFFSETS_SEARCH_LOAD_MODIFIER;

    /// Cache load modifier for reading CSR row-offsets
    static const CacheLoadModifier  ROW_OFFSETS_LOAD_MODIFIER           = _ROW_OFFSETS_LOAD_MODIFIER;

    /// Cache load modifier for reading CSR column-indices
    static const CacheLoadModifier  COLUMN_INDICES_LOAD_MODIFIER        = _COLUMN_INDICES_LOAD_MODIFIER;

    /// Cache load modifier for reading CSR values
    static const CacheLoadModifier  VALUES_LOAD_MODIFIER                = _VALUES_LOAD_MODIFIER;

    /// Cache load modifier for reading vector values
    static const CacheLoadModifier  VECTOR_VALUES_LOAD_MODIFIER         = _VECTOR_VALUES_LOAD_MODIFIER;

    /// The BlockScan algorithm to use
    static const BlockScanAlgorithm SCAN_ALGORITHM                      = _SCAN_ALGORITHM;
};
90 
91 
92 /******************************************************************************
93  * Thread block abstractions
94  ******************************************************************************/
95 
/**
 * SpMV input parameter bundle: y = alpha * A * x + beta * y, with A given in CSR form
 * (values, column indices, and row *end* offsets).
 */
template <
    typename        ValueT,              ///< Matrix and vector value type
    typename        OffsetT>             ///< Signed integer type for sequence offsets
struct SpmvParams
{
    ValueT*         d_values;            ///< Pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix <b>A</b>.
    OffsetT*        d_row_end_offsets;   ///< Pointer to the array of \p m offsets demarcating the end of every row in \p d_column_indices and \p d_values
    OffsetT*        d_column_indices;    ///< Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix <b>A</b>.  (Indices are zero-based.)
    ValueT*         d_vector_x;          ///< Pointer to the array of \p num_cols values corresponding to the dense input vector <em>x</em>
    ValueT*         d_vector_y;          ///< Pointer to the array of \p num_rows values corresponding to the dense output vector <em>y</em>
    int             num_rows;            ///< Number of rows of matrix <b>A</b>.
    int             num_cols;            ///< Number of columns of matrix <b>A</b>.
    int             num_nonzeros;        ///< Number of nonzero elements of matrix <b>A</b>.
    ValueT          alpha;               ///< Alpha multiplicand
    ValueT          beta;                ///< Beta addend-multiplicand

    // Texture-reference iterator over d_vector_x; used as a fallback load path for
    // x-vector gathers on architectures without a suitable read-only cache load
    // (the magic number is simply a unique texture-reference ID).
    TexRefInputIterator<ValueT, 66778899, OffsetT>  t_vector_x;
};
114 
115 
/**
 * \brief AgentSpmv implements a stateful abstraction of CUDA thread blocks for participating in device-wide SpMV.
 *
 * Each thread block consumes one (or more) "merge tiles" of the logical merge path formed
 * by the CSR row-end offsets (list A) and the nonzero indices (list B), producing partial
 * dot-products for the rows it fully covers and a carry-out pair for its partial row.
 */
template <
    typename    AgentSpmvPolicyT,           ///< Parameterized AgentSpmvPolicy tuning policy type
    typename    ValueT,                     ///< Matrix and vector value type
    typename    OffsetT,                    ///< Signed integer type for sequence offsets
    bool        HAS_ALPHA,                  ///< Whether the input parameter \p alpha is not 1 (when true, partial sums are scaled by \p alpha)
    bool        HAS_BETA,                   ///< Whether the input parameter \p beta is not 0 (when true, \p beta * <em>y</em> is added into the output)
    int         PTX_ARCH = CUB_PTX_ARCH>    ///< PTX compute capability
struct AgentSpmv
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    /// Constants
    enum
    {
        BLOCK_THREADS           = AgentSpmvPolicyT::BLOCK_THREADS,      ///< Threads per thread block
        ITEMS_PER_THREAD        = AgentSpmvPolicyT::ITEMS_PER_THREAD,   ///< Merge items processed per thread
        TILE_ITEMS              = BLOCK_THREADS * ITEMS_PER_THREAD,     ///< Merge items processed per tile
    };

    /// 2D merge path coordinate type (x = row-offset index into list A, y = nonzero index into list B)
    typedef typename CubVector<OffsetT, 2>::Type CoordinateT;

    /// Input iterator wrapper types (for applying cache modifiers)

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::ROW_OFFSETS_SEARCH_LOAD_MODIFIER,
            OffsetT,
            OffsetT>
        RowOffsetsSearchIteratorT;

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::ROW_OFFSETS_LOAD_MODIFIER,
            OffsetT,
            OffsetT>
        RowOffsetsIteratorT;

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::COLUMN_INDICES_LOAD_MODIFIER,
            OffsetT,
            OffsetT>
        ColumnIndicesIteratorT;

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::VALUES_LOAD_MODIFIER,
            ValueT,
            OffsetT>
        ValueIteratorT;

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::VECTOR_VALUES_LOAD_MODIFIER,
            ValueT,
            OffsetT>
        VectorValueIteratorT;

    // Tuple type for scanning (pairs accumulated segment-value with segment-index)
    typedef KeyValuePair<OffsetT, ValueT> KeyValuePairT;

    // Reduce-value-by-segment scan operator (sums values sharing a key, resets across keys)
    typedef ReduceByKeyOp<cub::Sum> ReduceBySegmentOpT;

    // BlockReduce specialization
    typedef BlockReduce<
            ValueT,
            BLOCK_THREADS,
            BLOCK_REDUCE_WARP_REDUCTIONS>
        BlockReduceT;

    // BlockScan specialization (keyed scan used for the reduce-by-segment upsweep)
    typedef BlockScan<
            KeyValuePairT,
            BLOCK_THREADS,
            AgentSpmvPolicyT::SCAN_ALGORITHM>
        BlockScanT;

    // BlockScan specialization (plain value prefix-sum)
    typedef BlockScan<
            ValueT,
            BLOCK_THREADS,
            AgentSpmvPolicyT::SCAN_ALGORITHM>
        BlockPrefixSumT;

    // BlockExchange specialization
    typedef BlockExchange<
            ValueT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD>
        BlockExchangeT;

    /// Merge item type (either a non-zero value or a row-end offset).
    /// A union so both interpretations can share one shared-memory allocation.
    union MergeItem
    {
        // Value type to pair with index type OffsetT (NullType if loading values directly during merge,
        // since then no shared-memory staging of nonzeros is needed)
        typedef typename If<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>::Type MergeValueT;

        OffsetT     row_end_offset;
        MergeValueT nonzero;
    };

    /// Shared memory type required by this thread block
    struct _TempStorage
    {
        // Starting and ending merge-path coordinates for this block's tile
        CoordinateT tile_coords[2];

        // The per-phase shared-memory needs are aliased into one union since the
        // phases are separated by barriers and never live simultaneously
        union Aliasable
        {
            // Smem needed for tile of merge items
            // (+1 for the extra row-end offset gathered at the tile boundary;
            // +ITEMS_PER_THREAD of slack between the row-offset and nonzero regions)
            MergeItem merge_items[ITEMS_PER_THREAD + TILE_ITEMS + 1];

            // Smem needed for block exchange
            typename BlockExchangeT::TempStorage exchange;

            // Smem needed for block-wide reduction
            typename BlockReduceT::TempStorage reduce;

            // Smem needed for tile scanning
            typename BlockScanT::TempStorage scan;

            // Smem needed for tile prefix sum
            typename BlockPrefixSumT::TempStorage prefix_sum;

        } aliasable;
    };

    /// Temporary storage type (unionable)
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------


    _TempStorage&                   temp_storage;         ///< Reference to shared-memory temp_storage

    SpmvParams<ValueT, OffsetT>&    spmv_params;          ///< SpMV input parameter bundle

    ValueIteratorT                  wd_values;            ///< Wrapped pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix <b>A</b>.
    RowOffsetsIteratorT             wd_row_end_offsets;   ///< Wrapped pointer to the array of \p m offsets demarcating the end of every row in \p d_column_indices and \p d_values
    ColumnIndicesIteratorT          wd_column_indices;    ///< Wrapped pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix <b>A</b>.  (Indices are zero-based.)
    VectorValueIteratorT            wd_vector_x;          ///< Wrapped pointer to the array of \p num_cols values corresponding to the dense input vector <em>x</em>
    VectorValueIteratorT            wd_vector_y;          ///< Wrapped pointer to the array of \p num_rows values corresponding to the dense output vector <em>y</em>


    //---------------------------------------------------------------------
    // Interface
    //---------------------------------------------------------------------

    /**
     * Constructor.  Wraps the raw device pointers from \p spmv_params with
     * cache-modified iterators per the tuning policy.
     */
    __device__ __forceinline__ AgentSpmv(
        TempStorage&                    temp_storage,           ///< Reference to temp_storage
        SpmvParams<ValueT, OffsetT>&    spmv_params)            ///< SpMV input parameter bundle
    :
        temp_storage(temp_storage.Alias()),
        spmv_params(spmv_params),
        wd_values(spmv_params.d_values),
        wd_row_end_offsets(spmv_params.d_row_end_offsets),
        wd_column_indices(spmv_params.d_column_indices),
        wd_vector_x(spmv_params.d_vector_x),
        wd_vector_y(spmv_params.d_vector_y)
    {}




    /**
     * Consume a merge tile, specialized for direct-load of nonzeros.
     * Returns the tile's carry-out pair (partial sum and relative row-id of the
     * row straddling the tile's end), to be fixed up by a subsequent kernel.
     */
    __device__ __forceinline__ KeyValuePairT ConsumeTile(
        int             tile_idx,
        CoordinateT     tile_start_coord,
        CoordinateT     tile_end_coord,
        Int2Type<true>  is_direct_load)     ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch
    {
        int         tile_num_rows           = tile_end_coord.x - tile_start_coord.x;
        int         tile_num_nonzeros       = tile_end_coord.y - tile_start_coord.y;
        OffsetT*    s_tile_row_end_offsets  = &temp_storage.aliasable.merge_items[0].row_end_offset;

        // Gather the row end-offsets for the merge tile into shared memory
        // (note <=: tile_num_rows + 1 offsets are needed, including the tile's last row boundary)
        for (int item = threadIdx.x; item <= tile_num_rows; item += BLOCK_THREADS)
        {
            s_tile_row_end_offsets[item] = wd_row_end_offsets[tile_start_coord.x + item];
        }

        CTA_SYNC();

        // Search for the thread's starting coordinate within the merge tile
        CountingInputIterator<OffsetT>  tile_nonzero_indices(tile_start_coord.y);
        CoordinateT                     thread_start_coord;

        MergePathSearch(
            OffsetT(threadIdx.x * ITEMS_PER_THREAD),    // Diagonal
            s_tile_row_end_offsets,                     // List A
            tile_nonzero_indices,                       // List B
            tile_num_rows,
            tile_num_nonzeros,
            thread_start_coord);

        CTA_SYNC();            // Perf-sync

        // Compute the thread's merge path segment
        CoordinateT     thread_current_coord = thread_start_coord;
        KeyValuePairT   scan_segment[ITEMS_PER_THREAD];

        ValueT          running_total = 0.0;

        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            // Clamp so that threads past the end of the last tile read a valid
            // (if redundant) nonzero rather than out-of-bounds memory; the merge
            // decision below still uses the unclamped nonzero index
            OffsetT nonzero_idx         = CUB_MIN(tile_nonzero_indices[thread_current_coord.y], spmv_params.num_nonzeros - 1);
            OffsetT column_idx          = wd_column_indices[nonzero_idx];
            ValueT  value               = wd_values[nonzero_idx];

            // Gather x[column_idx]: texture fetch as baseline, overridden by a
            // cache-modified direct load on sm_35+ (which has suitable read-only loads)
            ValueT  vector_value        = spmv_params.t_vector_x[column_idx];
#if (CUB_PTX_ARCH >= 350)
            vector_value                = wd_vector_x[column_idx];
#endif
            ValueT  nonzero             = value * vector_value;

            OffsetT row_end_offset      = s_tile_row_end_offsets[thread_current_coord.x];

            if (tile_nonzero_indices[thread_current_coord.y] < row_end_offset)
            {
                // Move down (accumulate): this nonzero belongs to the current row
                running_total += nonzero;
                scan_segment[ITEM].value    = running_total;
                scan_segment[ITEM].key      = tile_num_rows;    // sentinel: not yet a row terminus
                ++thread_current_coord.y;
            }
            else
            {
                // Move right (reset): current row ends here; record its partial sum
                scan_segment[ITEM].value    = running_total;
                scan_segment[ITEM].key      = thread_current_coord.x;
                running_total               = 0.0;
                ++thread_current_coord.x;
            }
        }

        CTA_SYNC();

        // Block-wide reduce-value-by-segment to propagate partial sums across
        // threads whose segments straddle the same row
        KeyValuePairT       tile_carry;
        ReduceBySegmentOpT  scan_op;
        KeyValuePairT       scan_item;

        scan_item.value = running_total;
        scan_item.key   = thread_current_coord.x;

        BlockScanT(temp_storage.aliasable.scan).ExclusiveScan(scan_item, scan_item, scan_op, tile_carry);

        if (tile_num_rows > 0)
        {
            // Thread 0 has no valid exclusive prefix; mark its key invalid
            if (threadIdx.x == 0)
                scan_item.key = -1;

            // Direct scatter of finalized row totals to y
            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
            {
                // Only items that terminated a row (key < sentinel) are scattered
                if (scan_segment[ITEM].key < tile_num_rows)
                {
                    // Fold in the exclusive prefix if it belongs to the same row
                    if (scan_item.key == scan_segment[ITEM].key)
                        scan_segment[ITEM].value = scan_item.value + scan_segment[ITEM].value;

                    if (HAS_ALPHA)
                    {
                        scan_segment[ITEM].value *= spmv_params.alpha;
                    }

                    if (HAS_BETA)
                    {
                        // Update the output vector element
                        ValueT addend = spmv_params.beta * wd_vector_y[tile_start_coord.x + scan_segment[ITEM].key];
                        scan_segment[ITEM].value += addend;
                    }

                    // Set the output vector element
                    spmv_params.d_vector_y[tile_start_coord.x + scan_segment[ITEM].key] = scan_segment[ITEM].value;
                }
            }
        }

        // Return the tile's running carry-out
        return tile_carry;
    }



    /**
     * Consume a merge tile, specialized for indirect load of nonzeros
     * (nonzero products are pre-staged through shared memory before merging).
     * NOTE(review): unlike the direct-load specialization, this path does not
     * apply alpha/beta when scattering row totals — presumably handled by the
     * dispatch configuration; confirm against the calling dispatch logic.
     */
    __device__ __forceinline__ KeyValuePairT ConsumeTile(
        int             tile_idx,
        CoordinateT     tile_start_coord,
        CoordinateT     tile_end_coord,
        Int2Type<false> is_direct_load)     ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch
    {
        int         tile_num_rows           = tile_end_coord.x - tile_start_coord.x;
        int         tile_num_nonzeros       = tile_end_coord.y - tile_start_coord.y;

#if (CUB_PTX_ARCH >= 520)

        // Row-end offsets occupy the front of the merge-item smem; staged nonzero
        // products start after them (plus ITEMS_PER_THREAD of slack)
        OffsetT*    s_tile_row_end_offsets  = &temp_storage.aliasable.merge_items[0].row_end_offset;
        ValueT*     s_tile_nonzeros         = &temp_storage.aliasable.merge_items[tile_num_rows + ITEMS_PER_THREAD].nonzero;

        // Gather the nonzeros for the merge tile into shared memory
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int nonzero_idx = threadIdx.x + (ITEM * BLOCK_THREADS);

            ValueIteratorT a                = wd_values + tile_start_coord.y + nonzero_idx;
            ColumnIndicesIteratorT ci       = wd_column_indices + tile_start_coord.y + nonzero_idx;
            ValueT* s                       = s_tile_nonzeros + nonzero_idx;

            if (nonzero_idx < tile_num_nonzeros)
            {

                OffsetT column_idx              = *ci;
                ValueT  value                   = *a;

                // Texture fetch immediately overridden by a direct load
                // (sm_52+ always has read-only cache loads available)
                ValueT  vector_value            = spmv_params.t_vector_x[column_idx];
                vector_value                    = wd_vector_x[column_idx];

                ValueT  nonzero                 = value * vector_value;

                *s    = nonzero;
            }
        }


#else

        OffsetT*    s_tile_row_end_offsets  = &temp_storage.aliasable.merge_items[0].row_end_offset;
        ValueT*     s_tile_nonzeros         = &temp_storage.aliasable.merge_items[tile_num_rows + ITEMS_PER_THREAD].nonzero;

        // Gather the nonzeros for the merge tile into shared memory
        if (tile_num_nonzeros > 0)
        {
            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
            {
                int     nonzero_idx             = threadIdx.x + (ITEM * BLOCK_THREADS);
                // Clamp instead of branching: trailing threads redundantly
                // rewrite the last nonzero rather than diverging
                nonzero_idx                     = CUB_MIN(nonzero_idx, tile_num_nonzeros - 1);

                OffsetT column_idx              = wd_column_indices[tile_start_coord.y + nonzero_idx];
                ValueT  value                   = wd_values[tile_start_coord.y + nonzero_idx];

                // Texture fetch as baseline; direct load on sm_35+
                ValueT  vector_value            = spmv_params.t_vector_x[column_idx];
#if (CUB_PTX_ARCH >= 350)
                vector_value                    = wd_vector_x[column_idx];
#endif
                ValueT  nonzero                 = value * vector_value;

                s_tile_nonzeros[nonzero_idx]    = nonzero;
            }
        }

#endif

        // Gather the row end-offsets for the merge tile into shared memory
        // (<=: tile_num_rows + 1 offsets; "#pragma unroll 1" suppresses unrolling)
        #pragma unroll 1
        for (int item = threadIdx.x; item <= tile_num_rows; item += BLOCK_THREADS)
        {
            s_tile_row_end_offsets[item] = wd_row_end_offsets[tile_start_coord.x + item];
        }

        CTA_SYNC();

        // Search for the thread's starting coordinate within the merge tile
        CountingInputIterator<OffsetT>  tile_nonzero_indices(tile_start_coord.y);
        CoordinateT                     thread_start_coord;

        MergePathSearch(
            OffsetT(threadIdx.x * ITEMS_PER_THREAD),    // Diagonal
            s_tile_row_end_offsets,                     // List A
            tile_nonzero_indices,                       // List B
            tile_num_rows,
            tile_num_nonzeros,
            thread_start_coord);

        CTA_SYNC();            // Perf-sync

        // Compute the thread's merge path segment
        CoordinateT     thread_current_coord = thread_start_coord;
        KeyValuePairT   scan_segment[ITEMS_PER_THREAD];
        ValueT          running_total = 0.0;

        OffsetT row_end_offset  = s_tile_row_end_offsets[thread_current_coord.x];
        ValueT  nonzero         = s_tile_nonzeros[thread_current_coord.y];

        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (tile_nonzero_indices[thread_current_coord.y] < row_end_offset)
            {
                // Move down (accumulate): nonzero belongs to the current row
                scan_segment[ITEM].value    = nonzero;
                running_total               += nonzero;
                ++thread_current_coord.y;
                nonzero                     = s_tile_nonzeros[thread_current_coord.y];
            }
            else
            {
                // Move right (reset): advance to the next row
                scan_segment[ITEM].value    = 0.0;
                running_total               = 0.0;
                ++thread_current_coord.x;
                row_end_offset              = s_tile_row_end_offsets[thread_current_coord.x];
            }

            scan_segment[ITEM].key = thread_current_coord.x;
        }

        CTA_SYNC();

        // Block-wide reduce-value-by-segment
        KeyValuePairT       tile_carry;
        ReduceBySegmentOpT  scan_op;
        KeyValuePairT       scan_item;

        scan_item.value = running_total;
        scan_item.key = thread_current_coord.x;

        BlockScanT(temp_storage.aliasable.scan).ExclusiveScan(scan_item, scan_item, scan_op, tile_carry);

        // Thread 0 has no exclusive prefix; substitute a zero contribution for its own segment
        if (threadIdx.x == 0)
        {
            scan_item.key = thread_start_coord.x;
            scan_item.value = 0.0;
        }

        if (tile_num_rows > 0)
        {

            CTA_SYNC();

            // Scan downsweep and scatter: row totals are first collected into
            // shared memory (reusing the merge-item storage), then written out
            // with coalesced strided stores
            ValueT* s_partials = &temp_storage.aliasable.merge_items[0].nonzero;

            if (scan_item.key != scan_segment[0].key)
            {
                // Prefix belongs to an earlier row finished by a preceding thread
                s_partials[scan_item.key] = scan_item.value;
            }
            else
            {
                // Prefix belongs to this thread's first row; fold it in
                scan_segment[0].value += scan_item.value;
            }

            #pragma unroll
            for (int ITEM = 1; ITEM < ITEMS_PER_THREAD; ++ITEM)
            {
                if (scan_segment[ITEM - 1].key != scan_segment[ITEM].key)
                {
                    // Previous item terminated its row; emit its total
                    s_partials[scan_segment[ITEM - 1].key] = scan_segment[ITEM - 1].value;
                }
                else
                {
                    // Same row continues; carry the partial sum forward
                    scan_segment[ITEM].value += scan_segment[ITEM - 1].value;
                }
            }

            CTA_SYNC();

            #pragma unroll 1
            for (int item = threadIdx.x; item < tile_num_rows; item += BLOCK_THREADS)
            {
                spmv_params.d_vector_y[tile_start_coord.x + item] = s_partials[item];
            }
        }

        // Return the tile's running carry-out
        return tile_carry;
    }


    /**
     * Consume input tile: locate this block's merge-tile coordinates (either from a
     * precomputed array or by searching), consume the tile via the policy-selected
     * specialization, and record the tile's carry-out pair for later fix-up.
     */
    __device__ __forceinline__ void ConsumeTile(
        CoordinateT*    d_tile_coordinates,     ///< [in] Pointer to the temporary array of tile starting coordinates
        KeyValuePairT*  d_tile_carry_pairs,     ///< [out] Pointer to the temporary array carry-out dot product row-ids, one per block
        int             num_merge_tiles)        ///< [in] Number of merge tiles
    {
        int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index

        if (tile_idx >= num_merge_tiles)
            return;

        // Read our starting coordinates (threads 0 and 1 fetch the tile's start and end)
        if (threadIdx.x < 2)
        {
            if (d_tile_coordinates == NULL)
            {
                // Search our starting coordinates
                OffsetT                         diagonal = (tile_idx + threadIdx.x) * TILE_ITEMS;
                CoordinateT                     tile_coord;
                CountingInputIterator<OffsetT>  nonzero_indices(0);

                // Search the merge path
                MergePathSearch(
                    diagonal,
                    RowOffsetsSearchIteratorT(spmv_params.d_row_end_offsets),
                    nonzero_indices,
                    spmv_params.num_rows,
                    spmv_params.num_nonzeros,
                    tile_coord);

                temp_storage.tile_coords[threadIdx.x] = tile_coord;
            }
            else
            {
                temp_storage.tile_coords[threadIdx.x] = d_tile_coordinates[tile_idx + threadIdx.x];
            }
        }

        CTA_SYNC();

        CoordinateT tile_start_coord     = temp_storage.tile_coords[0];
        CoordinateT tile_end_coord       = temp_storage.tile_coords[1];

        // Consume multi-segment tile
        KeyValuePairT tile_carry = ConsumeTile(
            tile_idx,
            tile_start_coord,
            tile_end_coord,
            Int2Type<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS>());

        // Output the tile's carry-out (scaled by alpha here, since the fix-up
        // kernel adds it to y without further scaling)
        if (threadIdx.x == 0)
        {
            if (HAS_ALPHA)
                tile_carry.value *= spmv_params.alpha;

            // Convert the tile-relative row-id to an absolute row-id
            tile_carry.key += tile_start_coord.x;
            d_tile_carry_pairs[tile_idx]    = tile_carry;
        }
    }


};
664 
665 
666 
667 
668 }               // CUB namespace
669 CUB_NS_POSTFIX  // Optional outer namespace(s)
670 
671