1 /****************************************************************************** 2 * Copyright (c) 2011, Duane Merrill. All rights reserved. 3 * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * * Redistributions of source code must retain the above copyright 8 * notice, this list of conditions and the following disclaimer. 9 * * Redistributions in binary form must reproduce the above copyright 10 * notice, this list of conditions and the following disclaimer in the 11 * documentation and/or other materials provided with the distribution. 12 * * Neither the name of the NVIDIA CORPORATION nor the 13 * names of its contributors may be used to endorse or promote products 14 * derived from this software without specific prior written permission. 15 * 16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 20 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 * 27 ******************************************************************************/ 28 29 /** 30 * \file 31 * cub::AgentSpmv implements a stateful abstraction of CUDA thread blocks for participating in device-wide SpMV. 
*/

#pragma once

#include <iterator>

#include "../util_type.cuh"
#include "../block/block_reduce.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_exchange.cuh"
#include "../thread/thread_search.cuh"
#include "../thread/thread_operators.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../iterator/counting_input_iterator.cuh"
#include "../iterator/tex_ref_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentSpmv.
 *
 * Bundles the compile-time knobs (block shape, cache modifiers for each input
 * stream, nonzero staging strategy, and block-scan algorithm) consumed by
 * AgentSpmv.
 */
template <
    int                             _BLOCK_THREADS,                         ///< Threads per thread block
    int                             _ITEMS_PER_THREAD,                      ///< Items per thread (per tile of input)
    CacheLoadModifier               _ROW_OFFSETS_SEARCH_LOAD_MODIFIER,      ///< Cache load modifier for reading CSR row-offsets during search
    CacheLoadModifier               _ROW_OFFSETS_LOAD_MODIFIER,             ///< Cache load modifier for reading CSR row-offsets
    CacheLoadModifier               _COLUMN_INDICES_LOAD_MODIFIER,          ///< Cache load modifier for reading CSR column-indices
    CacheLoadModifier               _VALUES_LOAD_MODIFIER,                  ///< Cache load modifier for reading CSR values
    CacheLoadModifier               _VECTOR_VALUES_LOAD_MODIFIER,           ///< Cache load modifier for reading vector values
    bool                            _DIRECT_LOAD_NONZEROS,                  ///< Whether to load nonzeros directly from global during sequential merging (vs. pre-staged through shared memory)
    BlockScanAlgorithm              _SCAN_ALGORITHM>                        ///< The BlockScan algorithm to use
struct AgentSpmvPolicy
{
    enum
    {
        BLOCK_THREADS           = _BLOCK_THREADS,                   ///< Threads per thread block
        ITEMS_PER_THREAD        = _ITEMS_PER_THREAD,                ///< Items per thread (per tile of input)
        DIRECT_LOAD_NONZEROS    = _DIRECT_LOAD_NONZEROS,            ///< Whether to load nonzeros directly from global during sequential merging (vs. pre-staged through shared memory)
    };

    static const CacheLoadModifier  ROW_OFFSETS_SEARCH_LOAD_MODIFIER    = _ROW_OFFSETS_SEARCH_LOAD_MODIFIER;    ///< Cache load modifier for reading CSR row-offsets during search
    static const CacheLoadModifier  ROW_OFFSETS_LOAD_MODIFIER           = _ROW_OFFSETS_LOAD_MODIFIER;           ///< Cache load modifier for reading CSR row-offsets
    static const CacheLoadModifier  COLUMN_INDICES_LOAD_MODIFIER        = _COLUMN_INDICES_LOAD_MODIFIER;        ///< Cache load modifier for reading CSR column-indices
    static const CacheLoadModifier  VALUES_LOAD_MODIFIER                = _VALUES_LOAD_MODIFIER;                ///< Cache load modifier for reading CSR values
    static const CacheLoadModifier  VECTOR_VALUES_LOAD_MODIFIER         = _VECTOR_VALUES_LOAD_MODIFIER;         ///< Cache load modifier for reading vector values
    static const BlockScanAlgorithm SCAN_ALGORITHM                      = _SCAN_ALGORITHM;                      ///< The BlockScan algorithm to use

};


/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * Input parameter bundle for AgentSpmv: a CSR-format sparse matrix <b>A</b>,
 * the dense input vector <em>x</em>, the dense output vector <em>y</em>, and
 * the \p alpha / \p beta scalars.
 */
template <
    typename        ValueT,              ///< Matrix and vector value type
    typename        OffsetT>             ///< Signed integer type for sequence offsets
struct SpmvParams
{
    ValueT*         d_values;            ///< Pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix <b>A</b>.
    OffsetT*        d_row_end_offsets;   ///< Pointer to the array of \p num_rows offsets demarcating the end of every row in \p d_column_indices and \p d_values
    OffsetT*        d_column_indices;    ///< Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix <b>A</b>. (Indices are zero-based.)
    ValueT*         d_vector_x;          ///< Pointer to the array of \p num_cols values corresponding to the dense input vector <em>x</em>
    ValueT*         d_vector_y;          ///< Pointer to the array of \p num_rows values corresponding to the dense output vector <em>y</em>
    int             num_rows;            ///< Number of rows of matrix <b>A</b>.
    int             num_cols;            ///< Number of columns of matrix <b>A</b>.
    int             num_nonzeros;        ///< Number of nonzero elements of matrix <b>A</b>.
    ValueT          alpha;               ///< Alpha multiplicand
    ValueT          beta;                ///< Beta addend-multiplicand

    /// Texture-reference wrapper used by AgentSpmv to read <em>x</em> when compiled
    /// for PTX architectures below 350.  NOTE(review): 66778899 appears to be an
    /// arbitrary unique texture id, and the texture presumably must be bound to
    /// \p d_vector_x by the dispatch layer -- confirm.
    TexRefInputIterator<ValueT, 66778899, OffsetT> t_vector_x;
};


/**
 * \brief AgentSpmv implements a stateful abstraction of CUDA thread blocks for participating in device-wide SpMV.
*/
template <
    typename    AgentSpmvPolicyT,           ///< Parameterized AgentSpmvPolicy tuning policy type
    typename    ValueT,                     ///< Matrix and vector value type
    typename    OffsetT,                    ///< Signed integer type for sequence offsets
    bool        HAS_ALPHA,                  ///< Whether to scale row results by \p alpha (when false, \p alpha is treated as 1)
    bool        HAS_BETA,                   ///< Whether to add \p beta times the prior contents of <em>y</em> (when false, \p beta is treated as 0)
    int         PTX_ARCH = CUB_PTX_ARCH>    ///< PTX compute capability
struct AgentSpmv
{
    // Overview: each thread block consumes one tile of TILE_ITEMS consecutive
    // steps along the merge path formed by the CSR row end-offsets (list A) and
    // the nonzero indices (list B).  Rows that end inside the tile are written to
    // d_vector_y; the partial sum of the row still open at the tile's end is
    // emitted as the tile's carry-out pair.

    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    /// Constants
    enum
    {
        BLOCK_THREADS           = AgentSpmvPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD        = AgentSpmvPolicyT::ITEMS_PER_THREAD,
        TILE_ITEMS              = BLOCK_THREADS * ITEMS_PER_THREAD,
    };

    /// 2D merge path coordinate type
    typedef typename CubVector<OffsetT, 2>::Type CoordinateT;

    /// Input iterator wrapper types (for applying cache modifiers)

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::ROW_OFFSETS_SEARCH_LOAD_MODIFIER,
            OffsetT,
            OffsetT>
        RowOffsetsSearchIteratorT;

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::ROW_OFFSETS_LOAD_MODIFIER,
            OffsetT,
            OffsetT>
        RowOffsetsIteratorT;

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::COLUMN_INDICES_LOAD_MODIFIER,
            OffsetT,
            OffsetT>
        ColumnIndicesIteratorT;

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::VALUES_LOAD_MODIFIER,
            ValueT,
            OffsetT>
        ValueIteratorT;

    typedef CacheModifiedInputIterator<
            AgentSpmvPolicyT::VECTOR_VALUES_LOAD_MODIFIER,
            ValueT,
            OffsetT>
        VectorValueIteratorT;

    // Tuple type for scanning (pairs accumulated segment-value with segment-index)
    typedef KeyValuePair<OffsetT, ValueT> KeyValuePairT;

    // Reduce-value-by-segment scan operator
    typedef ReduceByKeyOp<cub::Sum> ReduceBySegmentOpT;

    // BlockReduce specialization
    typedef BlockReduce<
            ValueT,
            BLOCK_THREADS,
            BLOCK_REDUCE_WARP_REDUCTIONS>
        BlockReduceT;

    // BlockScan specialization
    typedef BlockScan<
            KeyValuePairT,
            BLOCK_THREADS,
            AgentSpmvPolicyT::SCAN_ALGORITHM>
        BlockScanT;

    // BlockScan specialization
    typedef BlockScan<
            ValueT,
            BLOCK_THREADS,
            AgentSpmvPolicyT::SCAN_ALGORITHM>
        BlockPrefixSumT;

    // BlockExchange specialization
    typedef BlockExchange<
            ValueT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD>
        BlockExchangeT;

    /// Merge item type (either a non-zero value or a row-end offset)
    union MergeItem
    {
        // Value type to pair with index type OffsetT (NullType if loading values directly during merge)
        typedef typename If<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS, NullType, ValueT>::Type MergeValueT;

        OffsetT     row_end_offset;
        MergeValueT nonzero;
    };

    /// Shared memory type required by this thread block
    struct _TempStorage
    {
        CoordinateT tile_coords[2];

        union Aliasable
        {
            // Smem needed for tile of merge items
            MergeItem merge_items[ITEMS_PER_THREAD + TILE_ITEMS + 1];

            // Smem needed for block exchange
            typename BlockExchangeT::TempStorage exchange;

            // Smem needed for block-wide reduction
            typename BlockReduceT::TempStorage reduce;

            // Smem needed for tile scanning
            typename BlockScanT::TempStorage scan;

            // Smem needed for tile prefix sum
            typename BlockPrefixSumT::TempStorage prefix_sum;

        } aliasable;
    };

    /// Temporary storage type (unionable)
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------


    _TempStorage&                   temp_storage;       ///< Reference to temp_storage

    SpmvParams<ValueT, OffsetT>&    spmv_params;        ///< SpMV input parameter bundle

    ValueIteratorT                  wd_values;          ///< Wrapped pointer to the array of \p num_nonzeros values of the corresponding nonzero elements of matrix <b>A</b>.
    RowOffsetsIteratorT             wd_row_end_offsets; ///< Wrapped Pointer to the array of \p num_rows offsets demarcating the end of every row in \p d_column_indices and \p d_values
    ColumnIndicesIteratorT          wd_column_indices;  ///< Wrapped Pointer to the array of \p num_nonzeros column-indices of the corresponding nonzero elements of matrix <b>A</b>. (Indices are zero-based.)
    VectorValueIteratorT            wd_vector_x;        ///< Wrapped Pointer to the array of \p num_cols values corresponding to the dense input vector <em>x</em>
    VectorValueIteratorT            wd_vector_y;        ///< Wrapped Pointer to the array of \p num_rows values corresponding to the dense output vector <em>y</em> (read for the \p beta addend)


    //---------------------------------------------------------------------
    // Interface
    //---------------------------------------------------------------------

    /**
     * Constructor
     */
    __device__ __forceinline__ AgentSpmv(
        TempStorage&                    temp_storage,           ///< Reference to temp_storage
        SpmvParams<ValueT, OffsetT>&    spmv_params)            ///< SpMV input parameter bundle
    :
        temp_storage(temp_storage.Alias()),
        spmv_params(spmv_params),
        wd_values(spmv_params.d_values),
        wd_row_end_offsets(spmv_params.d_row_end_offsets),
        wd_column_indices(spmv_params.d_column_indices),
        wd_vector_x(spmv_params.d_vector_x),
        wd_vector_y(spmv_params.d_vector_y)
    {}


    /**
     * Finalize one element of the output vector: optionally scale the row's
     * in-tile dot product by \p alpha, optionally add \p beta times the current
     * <em>y</em>[row], then store the result to <em>y</em>[row].
     *
     * Shared by both tile-consumption strategies so alpha/beta are applied
     * uniformly (and consistently with the alpha-scaled tile carry-out).
     * NOTE(review): contributions to the row from earlier tiles are presumably
     * folded in later from the carry-out pairs -- confirm against the dispatch layer.
     */
    __device__ __forceinline__ void StoreRowValue(
        OffsetT row,            ///< [in] Global row index into <em>y</em>
        ValueT  row_value)      ///< [in] Row's dot product accumulated within this tile
    {
        if (HAS_ALPHA)
        {
            row_value *= spmv_params.alpha;
        }

        if (HAS_BETA)
        {
            // Update the output vector element
            ValueT addend = spmv_params.beta * wd_vector_y[row];
            row_value += addend;
        }

        // Set the output vector element
        spmv_params.d_vector_y[row] = row_value;
    }


    /**
     * Consume a merge tile, specialized for direct-load of nonzeros.
     *
     * Each thread walks ITEMS_PER_THREAD steps of the tile's merge path, loading
     * each nonzero and its <em>x</em> element straight from global memory.
     * Rows that end inside the tile are written to <em>y</em>.
     *
     * \return The tile's carry-out pair (tile-local row index, partial sum)
     */
    __device__ __forceinline__ KeyValuePairT ConsumeTile(
        int             tile_idx,
        CoordinateT     tile_start_coord,
        CoordinateT     tile_end_coord,
        Int2Type<true>  is_direct_load)     ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch
    {
        int         tile_num_rows           = tile_end_coord.x - tile_start_coord.x;
        int         tile_num_nonzeros       = tile_end_coord.y - tile_start_coord.y;
        OffsetT*    s_tile_row_end_offsets  = &temp_storage.aliasable.merge_items[0].row_end_offset;

        // Gather the row end-offsets for the merge tile into shared memory.
        // ITEMS_PER_THREAD extra entries are filled because threads whose path
        // segment runs past the end of the final tile still read them; clamping
        // to the last row keeps every global read inside d_row_end_offsets
        // (the previous "item <= tile_num_rows" bound read d_row_end_offsets[num_rows]).
        // NOTE(review): assumes num_rows > 0 -- presumably guaranteed by the caller.
        for (int item = threadIdx.x; item < tile_num_rows + ITEMS_PER_THREAD; item += BLOCK_THREADS)
        {
            const OffsetT offset = CUB_MIN(
                OffsetT(tile_start_coord.x + item),
                OffsetT(spmv_params.num_rows - 1));
            s_tile_row_end_offsets[item] = wd_row_end_offsets[offset];
        }

        CTA_SYNC();

        // Search for the thread's starting coordinate within the merge tile
        CountingInputIterator<OffsetT>  tile_nonzero_indices(tile_start_coord.y);
        CoordinateT                     thread_start_coord;

        MergePathSearch(
            OffsetT(threadIdx.x * ITEMS_PER_THREAD),    // Diagonal
            s_tile_row_end_offsets,                     // List A
            tile_nonzero_indices,                       // List B
            tile_num_rows,
            tile_num_nonzeros,
            thread_start_coord);

        CTA_SYNC();            // Perf-sync

        // Compute the thread's merge path segment
        CoordinateT     thread_current_coord = thread_start_coord;
        KeyValuePairT   scan_segment[ITEMS_PER_THREAD];

        ValueT          running_total = 0.0;

        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            // Clamp so steps past the last nonzero still load in-bounds; the
            // product is only consumed on "move down" steps.
            // NOTE(review): assumes num_nonzeros > 0 (the clamp yields -1 otherwise).
            OffsetT nonzero_idx         = CUB_MIN(tile_nonzero_indices[thread_current_coord.y], spmv_params.num_nonzeros - 1);
            OffsetT column_idx          = wd_column_indices[nonzero_idx];
            ValueT  value               = wd_values[nonzero_idx];

#if (CUB_PTX_ARCH >= 350)
            ValueT  vector_value        = wd_vector_x[column_idx];
#else
            ValueT  vector_value        = spmv_params.t_vector_x[column_idx];
#endif
            ValueT  nonzero             = value * vector_value;

            OffsetT row_end_offset      = s_tile_row_end_offsets[thread_current_coord.x];

            if (tile_nonzero_indices[thread_current_coord.y] < row_end_offset)
            {
                // Move down (accumulate)
                running_total += nonzero;
                scan_segment[ITEM].value    = running_total;
                scan_segment[ITEM].key      = tile_num_rows;        // Sentinel: no row ends at this step
                ++thread_current_coord.y;
            }
            else
            {
                // Move right (reset)
                scan_segment[ITEM].value    = running_total;
                scan_segment[ITEM].key      = thread_current_coord.x;
                running_total               = 0.0;
                ++thread_current_coord.x;
            }
        }

        CTA_SYNC();

        // Block-wide reduce-value-by-segment
        KeyValuePairT       tile_carry;
        ReduceBySegmentOpT  scan_op;
        KeyValuePairT       scan_item;

        scan_item.value = running_total;
        scan_item.key   = thread_current_coord.x;

        BlockScanT(temp_storage.aliasable.scan).ExclusiveScan(scan_item, scan_item, scan_op, tile_carry);

        if (tile_num_rows > 0)
        {
            // Thread 0 has no in-tile predecessor; key -1 never matches a row
            if (threadIdx.x == 0)
                scan_item.key = -1;

            // Direct scatter
            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
            {
                if (scan_segment[ITEM].key < tile_num_rows)
                {
                    // Fold in the prefix carried from preceding threads for this row
                    if (scan_item.key == scan_segment[ITEM].key)
                        scan_segment[ITEM].value = scan_item.value + scan_segment[ITEM].value;

                    StoreRowValue(tile_start_coord.x + scan_segment[ITEM].key, scan_segment[ITEM].value);
                }
            }
        }

        // Return the tile's running carry-out
        return tile_carry;
    }


    /**
     * Consume a merge tile, specialized for indirect load of nonzeros.
     *
     * The tile's nonzero products (value * x[column]) are first staged in shared
     * memory, then each thread walks ITEMS_PER_THREAD merge-path steps over the
     * staged data.  Per-row results are reduced into shared memory and written
     * to <em>y</em> cooperatively by the block.
     *
     * \return The tile's carry-out pair (tile-local row index, partial sum)
     */
    __device__ __forceinline__ KeyValuePairT ConsumeTile(
        int             tile_idx,
        CoordinateT     tile_start_coord,
        CoordinateT     tile_end_coord,
        Int2Type<false> is_direct_load)     ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch
    {
        int         tile_num_rows           = tile_end_coord.x - tile_start_coord.x;
        int         tile_num_nonzeros       = tile_end_coord.y - tile_start_coord.y;

        // Shared-memory layout: row end-offsets at the front, a gap of
        // ITEMS_PER_THREAD merge items, then the tile's nonzero products.
        OffsetT*    s_tile_row_end_offsets  = &temp_storage.aliasable.merge_items[0].row_end_offset;
        ValueT*     s_tile_nonzeros         = &temp_storage.aliasable.merge_items[tile_num_rows + ITEMS_PER_THREAD].nonzero;

#if (CUB_PTX_ARCH >= 520)

        // Gather the nonzeros for the merge tile into shared memory
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int nonzero_idx = threadIdx.x + (ITEM * BLOCK_THREADS);

            ValueIteratorT          a   = wd_values + tile_start_coord.y + nonzero_idx;
            ColumnIndicesIteratorT  ci  = wd_column_indices + tile_start_coord.y + nonzero_idx;
            ValueT*                 s   = s_tile_nonzeros + nonzero_idx;

            if (nonzero_idx < tile_num_nonzeros)
            {
                OffsetT column_idx      = *ci;
                ValueT  value           = *a;
                ValueT  vector_value    = wd_vector_x[column_idx];

                *s = value * vector_value;
            }
        }

#else

        // Gather the nonzeros for the merge tile into shared memory
        if (tile_num_nonzeros > 0)
        {
            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
            {
                // Excess threads are clamped onto the tile's last nonzero (they rewrite the same value)
                int nonzero_idx = threadIdx.x + (ITEM * BLOCK_THREADS);
                nonzero_idx     = CUB_MIN(nonzero_idx, tile_num_nonzeros - 1);

                OffsetT column_idx      = wd_column_indices[tile_start_coord.y + nonzero_idx];
                ValueT  value           = wd_values[tile_start_coord.y + nonzero_idx];

#if (CUB_PTX_ARCH >= 350)
                ValueT  vector_value    = wd_vector_x[column_idx];
#else
                ValueT  vector_value    = spmv_params.t_vector_x[column_idx];
#endif
                s_tile_nonzeros[nonzero_idx] = value * vector_value;
            }
        }

#endif

        // Gather the row end-offsets for the merge tile into shared memory.
        // ITEMS_PER_THREAD extra entries are filled because threads whose path
        // segment runs past the end of the final tile still read them; clamping
        // to the last row keeps every global read inside d_row_end_offsets
        // (the previous "item <= tile_num_rows" bound read d_row_end_offsets[num_rows]).
        // NOTE(review): assumes num_rows > 0 -- presumably guaranteed by the caller.
        #pragma unroll 1
        for (int item = threadIdx.x; item < tile_num_rows + ITEMS_PER_THREAD; item += BLOCK_THREADS)
        {
            const OffsetT offset = CUB_MIN(
                OffsetT(tile_start_coord.x + item),
                OffsetT(spmv_params.num_rows - 1));
            s_tile_row_end_offsets[item] = wd_row_end_offsets[offset];
        }

        CTA_SYNC();

        // Search for the thread's starting coordinate within the merge tile
        CountingInputIterator<OffsetT>  tile_nonzero_indices(tile_start_coord.y);
        CoordinateT                     thread_start_coord;

        MergePathSearch(
            OffsetT(threadIdx.x * ITEMS_PER_THREAD),    // Diagonal
            s_tile_row_end_offsets,                     // List A
            tile_nonzero_indices,                       // List B
            tile_num_rows,
            tile_num_nonzeros,
            thread_start_coord);

        CTA_SYNC();            // Perf-sync

        // Compute the thread's merge path segment
        CoordinateT     thread_current_coord = thread_start_coord;
        KeyValuePairT   scan_segment[ITEMS_PER_THREAD];
        ValueT          running_total = 0.0;

        OffsetT row_end_offset  = s_tile_row_end_offsets[thread_current_coord.x];
        ValueT  nonzero         = s_tile_nonzeros[thread_current_coord.y];

        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (tile_nonzero_indices[thread_current_coord.y] < row_end_offset)
            {
                // Move down (accumulate)
                scan_segment[ITEM].value    = nonzero;
                running_total               += nonzero;
                ++thread_current_coord.y;
                nonzero                     = s_tile_nonzeros[thread_current_coord.y];
            }
            else
            {
                // Move right (reset)
                scan_segment[ITEM].value    = 0.0;
                running_total               = 0.0;
                ++thread_current_coord.x;
                row_end_offset              = s_tile_row_end_offsets[thread_current_coord.x];
            }

            scan_segment[ITEM].key = thread_current_coord.x;
        }

        CTA_SYNC();

        // Block-wide reduce-value-by-segment
        KeyValuePairT       tile_carry;
        ReduceBySegmentOpT  scan_op;
        KeyValuePairT       scan_item;

        scan_item.value = running_total;
        scan_item.key   = thread_current_coord.x;

        BlockScanT(temp_storage.aliasable.scan).ExclusiveScan(scan_item, scan_item, scan_op, tile_carry);

        // Thread 0 has no in-tile predecessor: seed it with an empty partial for its starting row
        if (threadIdx.x == 0)
        {
            scan_item.key   = thread_start_coord.x;
            scan_item.value = 0.0;
        }

        if (tile_num_rows > 0)
        {

            CTA_SYNC();

            // Scan downsweep and scatter: each completed row's sum lands in s_partials[row]
            ValueT* s_partials = &temp_storage.aliasable.merge_items[0].nonzero;

            if (scan_item.key != scan_segment[0].key)
            {
                s_partials[scan_item.key] = scan_item.value;
            }
            else
            {
                scan_segment[0].value += scan_item.value;
            }

            #pragma unroll
            for (int ITEM = 1; ITEM < ITEMS_PER_THREAD; ++ITEM)
            {
                if (scan_segment[ITEM - 1].key != scan_segment[ITEM].key)
                {
                    s_partials[scan_segment[ITEM - 1].key] = scan_segment[ITEM - 1].value;
                }
                else
                {
                    scan_segment[ITEM].value += scan_segment[ITEM - 1].value;
                }
            }

            CTA_SYNC();

            // Apply alpha/beta (as the direct-load path does) and write the tile's rows
            #pragma unroll 1
            for (int item = threadIdx.x; item < tile_num_rows; item += BLOCK_THREADS)
            {
                StoreRowValue(tile_start_coord.x + item, s_partials[item]);
            }
        }

        // Return the tile's running carry-out
        return tile_carry;
    }


    /**
     * Consume input tile.
     *
     * Obtains this block's merge-path start/end coordinates (read from
     * \p d_tile_coordinates, or searched for when it is NULL), consumes the tile,
     * and records the tile's carry-out in \p d_tile_carry_pairs[tile_idx] with a
     * global row index and, when HAS_ALPHA, an alpha-scaled value.
     */
    __device__ __forceinline__ void ConsumeTile(
        CoordinateT*    d_tile_coordinates,     ///< [in] Pointer to the temporary array of tile starting coordinates
        KeyValuePairT*  d_tile_carry_pairs,     ///< [out] Pointer to the temporary array carry-out dot product row-ids, one per block
        int             num_merge_tiles)        ///< [in] Number of merge tiles
    {
        int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index

        // Block-uniform early exit (precedes every CTA_SYNC)
        if (tile_idx >= num_merge_tiles)
            return;

        // Read our starting coordinates (thread 0: tile start, thread 1: tile end)
        if (threadIdx.x < 2)
        {
            if (d_tile_coordinates == NULL)
            {
                // Search our starting coordinates
                OffsetT                         diagonal = (tile_idx + threadIdx.x) * TILE_ITEMS;
                CoordinateT                     tile_coord;
                CountingInputIterator<OffsetT>  nonzero_indices(0);

                // Search the merge path
                MergePathSearch(
                    diagonal,
                    RowOffsetsSearchIteratorT(spmv_params.d_row_end_offsets),
                    nonzero_indices,
                    spmv_params.num_rows,
                    spmv_params.num_nonzeros,
                    tile_coord);

                temp_storage.tile_coords[threadIdx.x] = tile_coord;
            }
            else
            {
                temp_storage.tile_coords[threadIdx.x] = d_tile_coordinates[tile_idx + threadIdx.x];
            }
        }

        CTA_SYNC();

        CoordinateT tile_start_coord     = temp_storage.tile_coords[0];
        CoordinateT tile_end_coord       = temp_storage.tile_coords[1];

        // Consume multi-segment tile
        KeyValuePairT tile_carry = ConsumeTile(
            tile_idx,
            tile_start_coord,
            tile_end_coord,
            Int2Type<AgentSpmvPolicyT::DIRECT_LOAD_NONZEROS>());

        // Output the tile's carry-out
        if (threadIdx.x == 0)
        {
            if (HAS_ALPHA)
                tile_carry.value *= spmv_params.alpha;

            tile_carry.key += tile_start_coord.x;
            d_tile_carry_pairs[tile_idx]    = tile_carry;
        }
    }


};




}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)