/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block
 */

#pragma once

#include <stdint.h>

#include "../thread/thread_reduce.cuh"
#include "../thread/thread_scan.cuh"
#include "../block/block_scan.cuh"
#include "../util_ptx.cuh"
#include "../util_arch.cuh"
#include "../util_type.cuh"
#include "../util_namespace.cuh"


/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {

/**
 * \brief BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
 * \ingroup BlockModule
 *
 * \tparam BLOCK_DIM_X          The thread block length in threads along the X dimension
 * \tparam RADIX_BITS           The number of radix bits per digit place
 * \tparam IS_DESCENDING        Whether or not the sorted-order is high-to-low
 * \tparam MEMOIZE_OUTER_SCAN   <b>[optional]</b> Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true for architectures SM35 and newer, false otherwise).  See BlockScanAlgorithm::BLOCK_SCAN_RAKING_MEMOIZE for more details.
 * \tparam INNER_SCAN_ALGORITHM <b>[optional]</b> The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS)
 * \tparam SMEM_CONFIG          <b>[optional]</b> Shared memory bank mode (default: \p cudaSharedMemBankSizeFourByte)
 * \tparam BLOCK_DIM_Y          <b>[optional]</b> The thread block length in threads along the Y dimension (default: 1)
 * \tparam BLOCK_DIM_Z          <b>[optional]</b> The thread block length in threads along the Z dimension (default: 1)
 * \tparam PTX_ARCH             <b>[optional]</b> \ptxversion
 *
 * \par Overview
 * - For the digit found at bits <tt>[current_bit, current_bit + num_bits)</tt> of each key,
 *   RankKeys computes the key's local rank within the tile: the number of keys in the tile
 *   whose digit orders before it, plus the number of keys with the same digit that precede it.
 * - Every thread keeps one 16-bit counter per digit in shared memory.  Counters are packed
 *   into 32-bit or 64-bit words (per \p SMEM_CONFIG) and prefix-summed with a raking scan.
 * - Keys must be in a form suitable for radix ranking (i.e., unsigned bits).
 * - \blocked
 *
 * \par Performance Considerations
 * - \granularity
 *
 * \par Examples
 * \par
 * - <b>Example 1:</b> Simple radix rank of 32-bit integer keys
 *      \code
 *      #include <cub/cub.cuh>
 *
 *      template <int BLOCK_THREADS>
 *      __global__ void ExampleKernel(...)
 *      {
 *          // Specialize BlockRadixRank for BLOCK_THREADS threads ranking 5-bit digits, ascending
 *          typedef cub::BlockRadixRank<BLOCK_THREADS, 5, false> BlockRadixRank;
 *
 *          // Shared memory needed by the collective
 *          __shared__ typename BlockRadixRank::TempStorage temp_storage;
 *
 *          // Four keys per thread
 *          unsigned int keys[4];
 *          int          ranks[4];
 *          ...
 *
 *          // Rank the keys by their five least-significant bits
 *          BlockRadixRank(temp_storage).RankKeys(keys, ranks, 0, 5);
 *      }
 *      \endcode
 */
template <
    int                     BLOCK_DIM_X,
    int                     RADIX_BITS,
    bool                    IS_DESCENDING,
    bool                    MEMOIZE_OUTER_SCAN      = (CUB_PTX_ARCH >= 350) ? true : false,
    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    cudaSharedMemConfig     SMEM_CONFIG             = cudaSharedMemBankSizeFourByte,
    int                     BLOCK_DIM_Y             = 1,
    int                     BLOCK_DIM_Z             = 1,
    int                     PTX_ARCH                = CUB_PTX_ARCH>
class BlockRadixRank
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    // Integer type for digit counters (to be packed into words of type PackedCounters)
    typedef unsigned short DigitCounter;

    // Integer type for packing DigitCounters into columns of shared memory banks
    typedef typename If<(SMEM_CONFIG == cudaSharedMemBankSizeEightByte),
        unsigned long long,
        unsigned int>::Type PackedCounter;

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        // Number of distinct digit values
        RADIX_DIGITS                = 1 << RADIX_BITS,

        LOG_WARP_THREADS            = CUB_LOG_WARP_THREADS(PTX_ARCH),
        WARP_THREADS                = 1 << LOG_WARP_THREADS,
        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        BYTES_PER_COUNTER           = sizeof(DigitCounter),
        LOG_BYTES_PER_COUNTER       = Log2<BYTES_PER_COUNTER>::VALUE,

        // Number of DigitCounters that fit in one PackedCounter
        PACKING_RATIO               = sizeof(PackedCounter) / sizeof(DigitCounter),
        LOG_PACKING_RATIO           = Log2<PACKING_RATIO>::VALUE,

        LOG_COUNTER_LANES           = CUB_MAX((RADIX_BITS - LOG_PACKING_RATIO), 0),                // Always at least one lane
        COUNTER_LANES               = 1 << LOG_COUNTER_LANES,

        // The number of packed counters per thread (plus one for padding)
        PADDED_COUNTER_LANES        = COUNTER_LANES + 1,
        RAKING_SEGMENT              = PADDED_COUNTER_LANES,
    };

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:


    /// BlockScan type
    typedef BlockScan<
            PackedCounter,
            BLOCK_DIM_X,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z,
            PTX_ARCH>
        BlockScan;


    /// Shared memory storage layout type for BlockRadixRank
    struct __align__(16) _TempStorage
    {
        // Two views of the same memory: per-thread digit counters (written while
        // counting) and per-thread raking segments of packed counters (read/written
        // while scanning).
        union Aliasable
        {
            DigitCounter            digit_counters[PADDED_COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO];
            PackedCounter           raking_grid[BLOCK_THREADS][RAKING_SEGMENT];

        } aliasable;

        // Storage for scanning local ranks
        typename BlockScan::TempStorage block_scan;
    };


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;

    /// Copy of raking segment, promoted to registers
    PackedCounter cached_segment[RAKING_SEGMENT];


    /******************************************************************************
     * Utility methods
     ******************************************************************************/

    /**
     * Internal storage allocator.  Returns a reference to a function-local
     * \p __shared__ instance of the storage layout.
     */
    __device__ __forceinline__ _TempStorage& PrivateStorage()
    {
        __shared__ _TempStorage private_storage;
        return private_storage;
    }


    /**
     * Performs upsweep raking reduction, returning the aggregate of the calling
     * thread's raking segment.  When MEMOIZE_OUTER_SCAN is set, the segment is
     * first copied into \p cached_segment so the downsweep can reuse it.
     */
    __device__ __forceinline__ PackedCounter Upsweep()
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];
        PackedCounter *raking_ptr;

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data into registers
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                cached_segment[i] = smem_raking_ptr[i];
            }
            raking_ptr = cached_segment;
        }
        else
        {
            raking_ptr = smem_raking_ptr;
        }

        return internal::ThreadReduce<RAKING_SEGMENT>(raking_ptr, Sum());
    }


    /**
     * Performs exclusive downsweep raking scan of the calling thread's raking
     * segment, seeded with \p raking_partial.  The scanned segment ends up in
     * shared memory in both the memoized and non-memoized configurations.
     */
    __device__ __forceinline__ void ExclusiveDownsweep(
        PackedCounter raking_partial)       ///< [in] Exclusive prefix to seed this thread's segment with
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];

        PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ?
            cached_segment :
            smem_raking_ptr;

        // Exclusive raking downsweep scan
        internal::ThreadScanExclusive<RAKING_SEGMENT>(raking_ptr, raking_ptr, Sum(), raking_partial);

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data back to smem
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                smem_raking_ptr[i] = cached_segment[i];
            }
        }
    }


    /**
     * Reset shared memory digit counters.  Each thread zeroes its own packed
     * word in every counter lane (including the padding lane).
     */
    __device__ __forceinline__ void ResetCounters()
    {
        // Reset shared memory digit counters
        #pragma unroll
        for (int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++)
        {
            // Clear all PACKING_RATIO sub-counters of this lane with a single packed store
            *((PackedCounter*) temp_storage.aliasable.digit_counters[LANE][linear_tid]) = 0;
        }
    }


    /**
     * Block-scan prefix callback.  Given the packed block-wide aggregate, builds
     * the packed prefix in which each sub-counter field receives the totals of
     * all lower sub-counter fields.
     */
    struct PrefixCallBack
    {
        __device__ __forceinline__ PackedCounter operator()(PackedCounter block_aggregate)
        {
            PackedCounter block_prefix = 0;

            // Propagate totals in packed fields: shifting the aggregate left by
            // PACKED fields adds each field's total into the field PACKED places above it
            #pragma unroll
            for (int PACKED = 1; PACKED < PACKING_RATIO; PACKED++)
            {
                block_prefix += block_aggregate << (sizeof(DigitCounter) * 8 * PACKED);
            }

            return block_prefix;
        }
    };


    /**
     * Scan shared memory digit counters: per-thread upsweep reduction, block-wide
     * exclusive sum of the partials, then per-thread downsweep scan.
     */
    __device__ __forceinline__ void ScanCounters()
    {
        // Upsweep scan
        PackedCounter raking_partial = Upsweep();

        // Compute exclusive sum
        PackedCounter exclusive_partial;
        PrefixCallBack prefix_call_back;
        BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back);

        // Downsweep scan with exclusive partial
        ExclusiveDownsweep(exclusive_partial);
    }

public:

    /// \smemstorage{BlockRadixRank}
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{

    /**
     * \brief Collective constructor using a private static allocation of shared memory as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank()
    :
        temp_storage(PrivateStorage()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /**
     * \brief Rank keys.
     *
     * Contains block-wide barriers (CTA_SYNC), so it must be called by every
     * thread of the block.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits)                           ///< [in] The number of bits in the current digit
    {
        DigitCounter    thread_prefixes[KEYS_PER_THREAD];   // For each key, the count of previous keys in this thread having the same digit
        DigitCounter*   digit_counters[KEYS_PER_THREAD];    // For each key, a pointer to its corresponding digit counter in smem

        // Reset shared memory digit counters
        ResetCounters();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Get digit
            unsigned int digit = BFE(keys[ITEM], current_bit, num_bits);

            // Get sub-counter
            unsigned int sub_counter = digit >> LOG_COUNTER_LANES;

            // Get counter lane
            unsigned int counter_lane = digit & (COUNTER_LANES - 1);

            if (IS_DESCENDING)
            {
                // Reverse the counter ordering so that larger digits are scanned first
                sub_counter = PACKING_RATIO - 1 - sub_counter;
                counter_lane = COUNTER_LANES - 1 - counter_lane;
            }

            // Pointer to smem digit counter
            digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane][linear_tid][sub_counter];

            // Load thread-exclusive prefix
            thread_prefixes[ITEM] = *digit_counters[ITEM];

            // Store inclusive prefix
            *digit_counters[ITEM] = thread_prefixes[ITEM] + 1;
        }

        CTA_SYNC();

        // Scan shared memory counters
        ScanCounters();

        CTA_SYNC();

        // Extract the local ranks of each key.  Unrolled (like the counting loop
        // above) so the per-item arrays are indexed with compile-time constants.
        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Add in thread block exclusive prefix
            ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM];
        }
    }


    /**
     * \brief Rank keys.  For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
     *
     * Entries of \p exclusive_digit_prefix whose bin index is not below
     * \p RADIX_DIGITS are left unmodified.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits,                           ///< [in] The number of bits in the current digit
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])            ///< [out] The exclusive prefix sum for the digits [(linear_tid * BINS_TRACKED_PER_THREAD) ... (linear_tid * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        // Rank keys
        RankKeys(keys, ranks, current_bit, num_bits);

        // Get the exclusive digit totals corresponding to the calling thread.
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                // Obtain exclusive digit counts.  (Unfortunately these all reside in the
                // first counter column, resulting in unavoidable bank conflicts.)
                unsigned int counter_lane   = (bin_idx & (COUNTER_LANES - 1));
                unsigned int sub_counter    = bin_idx >> (LOG_COUNTER_LANES);

                if (IS_DESCENDING)
                {
                    // Apply exactly the same reversal the ranking pass applies to each
                    // key's digit.  Reversing bin_idx as a whole is only equivalent when
                    // RADIX_BITS >= LOG_PACKING_RATIO; reversing the two fields
                    // separately is correct for every configuration.
                    sub_counter     = PACKING_RATIO - 1 - sub_counter;
                    counter_lane    = COUNTER_LANES - 1 - counter_lane;
                }

                // Thread 0's scanned counter holds the number of keys in the tile
                // whose digit is ordered before this bin
                exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane][0][sub_counter];
            }
        }
    }
};





/**
 * \brief Radix-rank using match.any
 *
 * Keys with equal digits are grouped within each warp via MatchAny, per-warp
 * digit counters are accumulated in shared memory, and a block-wide exclusive
 * sum over those counters yields each key's rank.
 *
 * \tparam BLOCK_DIM_X          The thread block length in threads along the X dimension
 * \tparam RADIX_BITS           The number of radix bits per digit place
 * \tparam IS_DESCENDING        Whether or not the sorted-order is high-to-low
 * \tparam INNER_SCAN_ALGORITHM <b>[optional]</b> The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS)
 * \tparam BLOCK_DIM_Y          <b>[optional]</b> The thread block length in threads along the Y dimension (default: 1)
 * \tparam BLOCK_DIM_Z          <b>[optional]</b> The thread block length in threads along the Z dimension (default: 1)
 * \tparam PTX_ARCH             <b>[optional]</b> \ptxversion
 */
template <
    int                     BLOCK_DIM_X,
    int                     RADIX_BITS,
    bool                    IS_DESCENDING,
    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    int                     BLOCK_DIM_Y             = 1,
    int                     BLOCK_DIM_Z             = 1,
    int                     PTX_ARCH                = CUB_PTX_ARCH>
class BlockRadixRankMatch
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    typedef int32_t    RankT;
    typedef int32_t    DigitCounterT;

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        // Number of distinct digit values
        RADIX_DIGITS                = 1 << RADIX_BITS,

        LOG_WARP_THREADS            = CUB_LOG_WARP_THREADS(PTX_ARCH),
        WARP_THREADS                = 1 << LOG_WARP_THREADS,
        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        // Warp count rounded up to an odd number
        PADDED_WARPS            = ((WARPS & 0x1) == 0) ?
                                    WARPS + 1 :
                                    WARPS,

        COUNTERS                = PADDED_WARPS * RADIX_DIGITS,
        RAKING_SEGMENT          = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,

        // Raking segment length rounded up to an odd number
        PADDED_RAKING_SEGMENT   = ((RAKING_SEGMENT & 0x1) == 0) ?
                                    RAKING_SEGMENT + 1 :
                                    RAKING_SEGMENT,
    };

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:

    /// BlockScan type.  The X extent passed here must be BLOCK_DIM_X (not
    /// BLOCK_THREADS): BLOCK_DIM_Y and BLOCK_DIM_Z are forwarded as well, so
    /// passing the full thread count would make BlockScan see the wrong block
    /// shape whenever the block is multi-dimensional.
    typedef BlockScan<
            DigitCounterT,
            BLOCK_DIM_X,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z,
            PTX_ARCH>
        BlockScanT;


    /// Shared memory storage layout type for BlockRadixRankMatch
    struct __align__(16) _TempStorage
    {
        typename BlockScanT::TempStorage            block_scan;

        // Two views of the same memory: one counter per (digit, warp) pair, and
        // per-thread raking segments used for zero-initialization and scanning.
        union __align__(16) Aliasable
        {
            volatile DigitCounterT                  warp_digit_counters[RADIX_DIGITS][PADDED_WARPS];
            DigitCounterT                           raking_grid[BLOCK_THREADS][PADDED_RAKING_SEGMENT];

        } aliasable;
    };


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;



public:

    /// \smemstorage{BlockRadixRankMatch}
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{


    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRankMatch(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /**
     * \brief Rank keys.
     *
     * Contains block-wide barriers (CTA_SYNC) and full-mask warp barriers
     * (WARP_SYNC(0xFFFFFFFF)), so it must be called by every thread of the block.
     * NOTE(review): the full-warp mask presumably requires every warp to be fully
     * populated (BLOCK_THREADS a multiple of the warp size) -- confirm against callers.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits)                           ///< [in] The number of bits in the current digit
    {
        // Initialize shared digit counters

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = 0;

        CTA_SYNC();

        // Each warp will strip-mine its section of input, one strip at a time

        volatile DigitCounterT  *digit_counters[KEYS_PER_THREAD];
        uint32_t                warp_id         = linear_tid >> LOG_WARP_THREADS;
        uint32_t                lane_mask_lt    = LaneMaskLt();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // My digit
            uint32_t digit = BFE(keys[ITEM], current_bit, num_bits);

            if (IS_DESCENDING)
                digit = RADIX_DIGITS - digit - 1;

            // Mask of peers who have same digit as me
            uint32_t peer_mask = MatchAny<RADIX_BITS>(digit);

            // Pointer to smem digit counter for this key
            digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit][warp_id];

            // Number of occurrences in previous strips
            DigitCounterT warp_digit_prefix = *digit_counters[ITEM];

            // Warp-sync (all peers must have read the counter before it is updated)
            WARP_SYNC(0xFFFFFFFF);

            // Number of peers having same digit as me
            int32_t digit_count = __popc(peer_mask);

            // Number of lower-ranked peers having same digit seen so far
            int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt);

            if (peer_digit_prefix == 0)
            {
                // First thread for each digit updates the shared warp counter
                *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count);
            }

            // Warp-sync (the update must land before the next strip reads the counter)
            WARP_SYNC(0xFFFFFFFF);

            // Number of prior keys having same digit
            ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
        }

        CTA_SYNC();

        // Scan warp counters

        DigitCounterT scan_counters[PADDED_RAKING_SEGMENT];

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid][ITEM];

        BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters);

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = scan_counters[ITEM];

        CTA_SYNC();

        // Seed ranks with counter values from previous warps
        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
            ranks[ITEM] += *digit_counters[ITEM];
    }


    /**
     * \brief Rank keys.  For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
     *
     * Entries of \p exclusive_digit_prefix whose bin index is not below
     * \p RADIX_DIGITS are left unmodified.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits,                           ///< [in] The number of bits in the current digit
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])            ///< [out] The exclusive prefix sum for the digits [(linear_tid * BINS_TRACKED_PER_THREAD) ... (linear_tid * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        RankKeys(keys, ranks, current_bit, num_bits);

        // Get exclusive count for each digit
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                // Same digit reversal the ranking pass applies for descending order
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                // Warp 0's scanned counter holds the number of keys in the tile
                // whose digit is ordered before this bin
                exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx][0];
            }
        }
    }
};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)