1 /****************************************************************************** 2 * Copyright (c) 2011, Duane Merrill. All rights reserved. 3 * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * * Redistributions of source code must retain the above copyright 8 * notice, this list of conditions and the following disclaimer. 9 * * Redistributions in binary form must reproduce the above copyright 10 * notice, this list of conditions and the following disclaimer in the 11 * documentation and/or other materials provided with the distribution. 12 * * Neither the name of the NVIDIA CORPORATION nor the 13 * names of its contributors may be used to endorse or promote products 14 * derived from this software without specific prior written permission. 15 * 16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 20 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 * 27 ******************************************************************************/ 28 29 /** 30 * \file 31 * AgentRadixSortDownsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort downsweep . 32 */ 33 34 35 #pragma once 36 37 #include <stdint.h> 38 39 #include "../thread/thread_load.cuh" 40 #include "../block/block_load.cuh" 41 #include "../block/block_store.cuh" 42 #include "../block/block_radix_rank.cuh" 43 #include "../block/block_exchange.cuh" 44 #include "../util_type.cuh" 45 #include "../iterator/cache_modified_input_iterator.cuh" 46 #include "../util_namespace.cuh" 47 48 /// Optional outer namespace(s) 49 CUB_NS_PREFIX 50 51 /// CUB namespace 52 namespace cub { 53 54 55 /****************************************************************************** 56 * Tuning policy types 57 ******************************************************************************/ 58 59 /** 60 * Radix ranking algorithm 61 */ 62 enum RadixRankAlgorithm 63 { 64 RADIX_RANK_BASIC, 65 RADIX_RANK_MEMOIZE, 66 RADIX_RANK_MATCH 67 }; 68 69 /** 70 * Parameterizable tuning policy type for AgentRadixSortDownsweep 71 */ 72 template < 73 int _BLOCK_THREADS, ///< Threads per thread block 74 int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) 75 BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use 76 CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading keys (and values) 77 RadixRankAlgorithm _RANK_ALGORITHM, ///< The radix ranking algorithm to use 78 BlockScanAlgorithm _SCAN_ALGORITHM, ///< The block scan algorithm to use 79 int _RADIX_BITS> ///< The number of radix bits, i.e., log2(bins) 80 struct AgentRadixSortDownsweepPolicy 81 { 82 enum 83 { 84 BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block 85 ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) 86 RADIX_BITS = _RADIX_BITS, ///< The number of radix bits, i.e., log2(bins) 87 }; 88 89 static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use 90 static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading keys (and values) 91 static const RadixRankAlgorithm RANK_ALGORITHM = _RANK_ALGORITHM; ///< The radix ranking algorithm to use 92 static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use 93 }; 94 95 96 /****************************************************************************** 97 * Thread block abstractions 98 ******************************************************************************/ 99 100 101 102 103 104 /** 105 * \brief AgentRadixSortDownsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort downsweep . 106 */ 107 template < 108 typename AgentRadixSortDownsweepPolicy, ///< Parameterized AgentRadixSortDownsweepPolicy tuning policy type 109 bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low 110 typename KeyT, ///< KeyT type 111 typename ValueT, ///< ValueT type 112 typename OffsetT> ///< Signed integer type for global offsets 113 struct AgentRadixSortDownsweep 114 { 115 //--------------------------------------------------------------------- 116 // Type definitions and constants 117 //--------------------------------------------------------------------- 118 119 // Appropriate unsigned-bits representation of KeyT 120 typedef typename Traits<KeyT>::UnsignedBits UnsignedBits; 121 122 static const UnsignedBits LOWEST_KEY = Traits<KeyT>::LOWEST_KEY; 123 static const UnsignedBits MAX_KEY = Traits<KeyT>::MAX_KEY; 124 125 static const BlockLoadAlgorithm LOAD_ALGORITHM = AgentRadixSortDownsweepPolicy::LOAD_ALGORITHM; 126 static const CacheLoadModifier LOAD_MODIFIER = AgentRadixSortDownsweepPolicy::LOAD_MODIFIER; 127 static const RadixRankAlgorithm RANK_ALGORITHM = AgentRadixSortDownsweepPolicy::RANK_ALGORITHM; 128 static const BlockScanAlgorithm SCAN_ALGORITHM = AgentRadixSortDownsweepPolicy::SCAN_ALGORITHM; 129 130 enum 131 { 132 BLOCK_THREADS = AgentRadixSortDownsweepPolicy::BLOCK_THREADS, 133 ITEMS_PER_THREAD = AgentRadixSortDownsweepPolicy::ITEMS_PER_THREAD, 134 RADIX_BITS = AgentRadixSortDownsweepPolicy::RADIX_BITS, 135 TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, 136 137 RADIX_DIGITS = 1 << RADIX_BITS, 138 KEYS_ONLY = Equals<ValueT, NullType>::VALUE, 139 }; 140 141 // Input iterator wrapper type (for applying cache modifier)s 142 typedef CacheModifiedInputIterator<LOAD_MODIFIER, UnsignedBits, OffsetT> KeysItr; 143 typedef CacheModifiedInputIterator<LOAD_MODIFIER, ValueT, OffsetT> ValuesItr; 144 145 // Radix ranking type to use 146 typedef typename If<(RANK_ALGORITHM == RADIX_RANK_BASIC), 147 BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, false, SCAN_ALGORITHM>, 148 typename If<(RANK_ALGORITHM == RADIX_RANK_MEMOIZE), 149 BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, true, SCAN_ALGORITHM>, 150 BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM> 151 >::Type 152 >::Type BlockRadixRankT; 153 154 enum 155 { 156 /// Number of bin-starting offsets tracked per thread 157 BINS_TRACKED_PER_THREAD = BlockRadixRankT::BINS_TRACKED_PER_THREAD 158 }; 159 160 // BlockLoad type (keys) 161 typedef BlockLoad< 162 UnsignedBits, 163 BLOCK_THREADS, 164 ITEMS_PER_THREAD, 165 LOAD_ALGORITHM> BlockLoadKeysT; 166 167 // BlockLoad type (values) 168 typedef BlockLoad< 169 ValueT, 170 BLOCK_THREADS, 171 ITEMS_PER_THREAD, 172 LOAD_ALGORITHM> BlockLoadValuesT; 173 174 // Value exchange array type 175 typedef ValueT ValueExchangeT[TILE_ITEMS]; 176 177 /** 178 * Shared memory storage layout 179 */ 180 union __align__(16) _TempStorage 181 { 182 typename BlockLoadKeysT::TempStorage load_keys; 183 typename BlockLoadValuesT::TempStorage load_values; 184 typename BlockRadixRankT::TempStorage radix_rank; 185 186 struct 187 { 188 UnsignedBits exchange_keys[TILE_ITEMS]; 189 OffsetT relative_bin_offsets[RADIX_DIGITS]; 190 }; 191 192 Uninitialized<ValueExchangeT> exchange_values; 193 194 OffsetT exclusive_digit_prefix[RADIX_DIGITS]; 195 }; 196 197 198 /// Alias wrapper allowing storage to be unioned 199 struct TempStorage : Uninitialized<_TempStorage> {}; 200 201 202 //--------------------------------------------------------------------- 203 // Thread fields 204 //--------------------------------------------------------------------- 205 206 // Shared storage for this CTA 207 _TempStorage &temp_storage; 208 209 // Input and output device pointers 210 KeysItr d_keys_in; 211 ValuesItr d_values_in; 212 UnsignedBits *d_keys_out; 213 ValueT *d_values_out; 214 215 // The global scatter base offset for each digit (valid in the first RADIX_DIGITS threads) 216 OffsetT bin_offset[BINS_TRACKED_PER_THREAD]; 217 218 // The least-significant bit position of the current digit to extract 219 int current_bit; 220 221 // Number of bits in current digit 222 int num_bits; 223 224 // Whether to short-cirucit 225 int short_circuit; 226 227 //--------------------------------------------------------------------- 228 // Utility methods 229 //--------------------------------------------------------------------- 230 231 232 /** 233 * Scatter ranked keys through shared memory, then to device-accessible memory 234 */ 235 template <bool FULL_TILE> ScatterKeyscub::AgentRadixSortDownsweep236 __device__ __forceinline__ void ScatterKeys( 237 UnsignedBits (&twiddled_keys)[ITEMS_PER_THREAD], 238 OffsetT (&relative_bin_offsets)[ITEMS_PER_THREAD], 239 int (&ranks)[ITEMS_PER_THREAD], 240 OffsetT valid_items) 241 { 242 #pragma unroll 243 for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) 244 { 245 temp_storage.exchange_keys[ranks[ITEM]] = twiddled_keys[ITEM]; 246 } 247 248 CTA_SYNC(); 249 250 #pragma unroll 251 for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) 252 { 253 UnsignedBits key = temp_storage.exchange_keys[threadIdx.x + (ITEM * BLOCK_THREADS)]; 254 UnsignedBits digit = BFE(key, current_bit, num_bits); 255 relative_bin_offsets[ITEM] = temp_storage.relative_bin_offsets[digit]; 256 257 // Un-twiddle 258 key = Traits<KeyT>::TwiddleOut(key); 259 260 if (FULL_TILE || 261 (static_cast<OffsetT>(threadIdx.x + (ITEM * BLOCK_THREADS)) < valid_items)) 262 { 263 d_keys_out[relative_bin_offsets[ITEM] + threadIdx.x + (ITEM * BLOCK_THREADS)] = key; 264 } 265 } 266 } 267 268 269 /** 270 * Scatter ranked values through shared memory, then to device-accessible memory 271 */ 272 template <bool FULL_TILE> ScatterValuescub::AgentRadixSortDownsweep273 __device__ __forceinline__ void ScatterValues( 274 ValueT (&values)[ITEMS_PER_THREAD], 275 OffsetT (&relative_bin_offsets)[ITEMS_PER_THREAD], 276 int (&ranks)[ITEMS_PER_THREAD], 277 OffsetT valid_items) 278 { 279 CTA_SYNC(); 280 281 ValueExchangeT &exchange_values = temp_storage.exchange_values.Alias(); 282 283 #pragma unroll 284 for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) 285 { 286 exchange_values[ranks[ITEM]] = values[ITEM]; 287 } 288 289 CTA_SYNC(); 290 291 #pragma unroll 292 for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) 293 { 294 ValueT value = exchange_values[threadIdx.x + (ITEM * BLOCK_THREADS)]; 295 296 if (FULL_TILE || 297 (static_cast<OffsetT>(threadIdx.x + (ITEM * BLOCK_THREADS)) < valid_items)) 298 { 299 d_values_out[relative_bin_offsets[ITEM] + threadIdx.x + (ITEM * BLOCK_THREADS)] = value; 300 } 301 } 302 } 303 304 /** 305 * Load a tile of keys (specialized for full tile, any ranking algorithm) 306 */ 307 template <int _RANK_ALGORITHM> LoadKeyscub::AgentRadixSortDownsweep308 __device__ __forceinline__ void LoadKeys( 309 UnsignedBits (&keys)[ITEMS_PER_THREAD], 310 OffsetT block_offset, 311 OffsetT valid_items, 312 UnsignedBits oob_item, 313 Int2Type<true> is_full_tile, 314 Int2Type<_RANK_ALGORITHM> rank_algorithm) 315 { 316 BlockLoadKeysT(temp_storage.load_keys).Load( 317 d_keys_in + block_offset, keys); 318 319 CTA_SYNC(); 320 } 321 322 323 /** 324 * Load a tile of keys (specialized for partial tile, any ranking algorithm) 325 */ 326 template <int _RANK_ALGORITHM> LoadKeyscub::AgentRadixSortDownsweep327 __device__ __forceinline__ void LoadKeys( 328 UnsignedBits (&keys)[ITEMS_PER_THREAD], 329 OffsetT block_offset, 330 OffsetT valid_items, 331 UnsignedBits oob_item, 332 Int2Type<false> is_full_tile, 333 Int2Type<_RANK_ALGORITHM> rank_algorithm) 334 { 335 // Register pressure work-around: moving valid_items through shfl prevents compiler 336 // from reusing guards/addressing from prior guarded loads 337 valid_items = ShuffleIndex<CUB_PTX_WARP_THREADS>(valid_items, 0, 0xffffffff); 338 339 BlockLoadKeysT(temp_storage.load_keys).Load( 340 d_keys_in + block_offset, keys, valid_items, oob_item); 341 342 CTA_SYNC(); 343 } 344 345 346 /** 347 * Load a tile of keys (specialized for full tile, match ranking algorithm) 348 */ LoadKeyscub::AgentRadixSortDownsweep349 __device__ __forceinline__ void LoadKeys( 350 UnsignedBits (&keys)[ITEMS_PER_THREAD], 351 OffsetT block_offset, 352 OffsetT valid_items, 353 UnsignedBits oob_item, 354 Int2Type<true> is_full_tile, 355 Int2Type<RADIX_RANK_MATCH> rank_algorithm) 356 { 357 LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys); 358 } 359 360 361 /** 362 * Load a tile of keys (specialized for partial tile, match ranking algorithm) 363 */ LoadKeyscub::AgentRadixSortDownsweep364 __device__ __forceinline__ void LoadKeys( 365 UnsignedBits (&keys)[ITEMS_PER_THREAD], 366 OffsetT block_offset, 367 OffsetT valid_items, 368 UnsignedBits oob_item, 369 Int2Type<false> is_full_tile, 370 Int2Type<RADIX_RANK_MATCH> rank_algorithm) 371 { 372 // Register pressure work-around: moving valid_items through shfl prevents compiler 373 // from reusing guards/addressing from prior guarded loads 374 valid_items = ShuffleIndex<CUB_PTX_WARP_THREADS>(valid_items, 0, 0xffffffff); 375 376 LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys, valid_items, oob_item); 377 } 378 379 380 /** 381 * Load a tile of values (specialized for full tile, any ranking algorithm) 382 */ 383 template <int _RANK_ALGORITHM> LoadValuescub::AgentRadixSortDownsweep384 __device__ __forceinline__ void LoadValues( 385 ValueT (&values)[ITEMS_PER_THREAD], 386 OffsetT block_offset, 387 OffsetT valid_items, 388 Int2Type<true> is_full_tile, 389 Int2Type<_RANK_ALGORITHM> rank_algorithm) 390 { 391 BlockLoadValuesT(temp_storage.load_values).Load( 392 d_values_in + block_offset, values); 393 394 CTA_SYNC(); 395 } 396 397 398 /** 399 * Load a tile of values (specialized for partial tile, any ranking algorithm) 400 */ 401 template <int _RANK_ALGORITHM> LoadValuescub::AgentRadixSortDownsweep402 __device__ __forceinline__ void LoadValues( 403 ValueT (&values)[ITEMS_PER_THREAD], 404 OffsetT block_offset, 405 OffsetT valid_items, 406 Int2Type<false> is_full_tile, 407 Int2Type<_RANK_ALGORITHM> rank_algorithm) 408 { 409 // Register pressure work-around: moving valid_items through shfl prevents compiler 410 // from reusing guards/addressing from prior guarded loads 411 valid_items = ShuffleIndex<CUB_PTX_WARP_THREADS>(valid_items, 0, 0xffffffff); 412 413 BlockLoadValuesT(temp_storage.load_values).Load( 414 d_values_in + block_offset, values, valid_items); 415 416 CTA_SYNC(); 417 } 418 419 420 /** 421 * Load a tile of items (specialized for full tile, match ranking algorithm) 422 */ LoadValuescub::AgentRadixSortDownsweep423 __device__ __forceinline__ void LoadValues( 424 ValueT (&values)[ITEMS_PER_THREAD], 425 OffsetT block_offset, 426 OffsetT valid_items, 427 Int2Type<true> is_full_tile, 428 Int2Type<RADIX_RANK_MATCH> rank_algorithm) 429 { 430 LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values); 431 } 432 433 434 /** 435 * Load a tile of items (specialized for partial tile, match ranking algorithm) 436 */ LoadValuescub::AgentRadixSortDownsweep437 __device__ __forceinline__ void LoadValues( 438 ValueT (&values)[ITEMS_PER_THREAD], 439 OffsetT block_offset, 440 OffsetT valid_items, 441 Int2Type<false> is_full_tile, 442 Int2Type<RADIX_RANK_MATCH> rank_algorithm) 443 { 444 // Register pressure work-around: moving valid_items through shfl prevents compiler 445 // from reusing guards/addressing from prior guarded loads 446 valid_items = ShuffleIndex<CUB_PTX_WARP_THREADS>(valid_items, 0, 0xffffffff); 447 448 LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values, valid_items); 449 } 450 451 452 /** 453 * Truck along associated values 454 */ 455 template <bool FULL_TILE> GatherScatterValuescub::AgentRadixSortDownsweep456 __device__ __forceinline__ void GatherScatterValues( 457 OffsetT (&relative_bin_offsets)[ITEMS_PER_THREAD], 458 int (&ranks)[ITEMS_PER_THREAD], 459 OffsetT block_offset, 460 OffsetT valid_items, 461 Int2Type<false> /*is_keys_only*/) 462 { 463 ValueT values[ITEMS_PER_THREAD]; 464 465 CTA_SYNC(); 466 467 LoadValues( 468 values, 469 block_offset, 470 valid_items, 471 Int2Type<FULL_TILE>(), 472 Int2Type<RANK_ALGORITHM>()); 473 474 ScatterValues<FULL_TILE>( 475 values, 476 relative_bin_offsets, 477 ranks, 478 valid_items); 479 } 480 481 482 /** 483 * Truck along associated values (specialized for key-only sorting) 484 */ 485 template <bool FULL_TILE> GatherScatterValuescub::AgentRadixSortDownsweep486 __device__ __forceinline__ void GatherScatterValues( 487 OffsetT (&/*relative_bin_offsets*/)[ITEMS_PER_THREAD], 488 int (&/*ranks*/)[ITEMS_PER_THREAD], 489 OffsetT /*block_offset*/, 490 OffsetT /*valid_items*/, 491 Int2Type<true> /*is_keys_only*/) 492 {} 493 494 495 /** 496 * Process tile 497 */ 498 template <bool FULL_TILE> ProcessTilecub::AgentRadixSortDownsweep499 __device__ __forceinline__ void ProcessTile( 500 OffsetT block_offset, 501 const OffsetT &valid_items = TILE_ITEMS) 502 { 503 UnsignedBits keys[ITEMS_PER_THREAD]; 504 int ranks[ITEMS_PER_THREAD]; 505 OffsetT relative_bin_offsets[ITEMS_PER_THREAD]; 506 507 // Assign default (min/max) value to all keys 508 UnsignedBits default_key = (IS_DESCENDING) ? LOWEST_KEY : MAX_KEY; 509 510 // Load tile of keys 511 LoadKeys( 512 keys, 513 block_offset, 514 valid_items, 515 default_key, 516 Int2Type<FULL_TILE>(), 517 Int2Type<RANK_ALGORITHM>()); 518 519 // Twiddle key bits if necessary 520 #pragma unroll 521 for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) 522 { 523 keys[KEY] = Traits<KeyT>::TwiddleIn(keys[KEY]); 524 } 525 526 // Rank the twiddled keys 527 int exclusive_digit_prefix[BINS_TRACKED_PER_THREAD]; 528 BlockRadixRankT(temp_storage.radix_rank).RankKeys( 529 keys, 530 ranks, 531 current_bit, 532 num_bits, 533 exclusive_digit_prefix); 534 535 CTA_SYNC(); 536 537 // Share exclusive digit prefix 538 #pragma unroll 539 for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) 540 { 541 int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; 542 if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) 543 { 544 // Store exclusive prefix 545 temp_storage.exclusive_digit_prefix[bin_idx] = 546 exclusive_digit_prefix[track]; 547 } 548 } 549 550 CTA_SYNC(); 551 552 // Get inclusive digit prefix 553 int inclusive_digit_prefix[BINS_TRACKED_PER_THREAD]; 554 555 #pragma unroll 556 for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) 557 { 558 int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; 559 if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) 560 { 561 if (IS_DESCENDING) 562 { 563 // Get inclusive digit prefix from exclusive prefix (higher bins come first) 564 inclusive_digit_prefix[track] = (bin_idx == 0) ? 565 (BLOCK_THREADS * ITEMS_PER_THREAD) : 566 temp_storage.exclusive_digit_prefix[bin_idx - 1]; 567 } 568 else 569 { 570 // Get inclusive digit prefix from exclusive prefix (lower bins come first) 571 inclusive_digit_prefix[track] = (bin_idx == RADIX_DIGITS - 1) ? 572 (BLOCK_THREADS * ITEMS_PER_THREAD) : 573 temp_storage.exclusive_digit_prefix[bin_idx + 1]; 574 } 575 } 576 } 577 578 CTA_SYNC(); 579 580 // Update global scatter base offsets for each digit 581 #pragma unroll 582 for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) 583 { 584 int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; 585 if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) 586 { 587 bin_offset[track] -= exclusive_digit_prefix[track]; 588 temp_storage.relative_bin_offsets[bin_idx] = bin_offset[track]; 589 bin_offset[track] += inclusive_digit_prefix[track]; 590 } 591 } 592 593 CTA_SYNC(); 594 595 // Scatter keys 596 ScatterKeys<FULL_TILE>(keys, relative_bin_offsets, ranks, valid_items); 597 598 // Gather/scatter values 599 GatherScatterValues<FULL_TILE>(relative_bin_offsets , ranks, block_offset, valid_items, Int2Type<KEYS_ONLY>()); 600 } 601 602 //--------------------------------------------------------------------- 603 // Copy shortcut 604 //--------------------------------------------------------------------- 605 606 /** 607 * Copy tiles within the range of input 608 */ 609 template < 610 typename InputIteratorT, 611 typename T> Copycub::AgentRadixSortDownsweep612 __device__ __forceinline__ void Copy( 613 InputIteratorT d_in, 614 T *d_out, 615 OffsetT block_offset, 616 OffsetT block_end) 617 { 618 // Simply copy the input 619 while (block_offset + TILE_ITEMS <= block_end) 620 { 621 T items[ITEMS_PER_THREAD]; 622 623 LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_in + block_offset, items); 624 CTA_SYNC(); 625 StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_out + block_offset, items); 626 627 block_offset += TILE_ITEMS; 628 } 629 630 // Clean up last partial tile with guarded-I/O 631 if (block_offset < block_end) 632 { 633 OffsetT valid_items = block_end - block_offset; 634 635 T items[ITEMS_PER_THREAD]; 636 637 LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_in + block_offset, items, valid_items); 638 CTA_SYNC(); 639 StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_out + block_offset, items, valid_items); 640 } 641 } 642 643 644 /** 645 * Copy tiles within the range of input (specialized for NullType) 646 */ 647 template <typename InputIteratorT> Copycub::AgentRadixSortDownsweep648 __device__ __forceinline__ void Copy( 649 InputIteratorT /*d_in*/, 650 NullType * /*d_out*/, 651 OffsetT /*block_offset*/, 652 OffsetT /*block_end*/) 653 {} 654 655 656 //--------------------------------------------------------------------- 657 // Interface 658 //--------------------------------------------------------------------- 659 660 /** 661 * Constructor 662 */ AgentRadixSortDownsweepcub::AgentRadixSortDownsweep663 __device__ __forceinline__ AgentRadixSortDownsweep( 664 TempStorage &temp_storage, 665 OffsetT (&bin_offset)[BINS_TRACKED_PER_THREAD], 666 OffsetT num_items, 667 const KeyT *d_keys_in, 668 KeyT *d_keys_out, 669 const ValueT *d_values_in, 670 ValueT *d_values_out, 671 int current_bit, 672 int num_bits) 673 : 674 temp_storage(temp_storage.Alias()), 675 d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)), 676 d_values_in(d_values_in), 677 d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)), 678 d_values_out(d_values_out), 679 current_bit(current_bit), 680 num_bits(num_bits), 681 short_circuit(1) 682 { 683 #pragma unroll 684 for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) 685 { 686 this->bin_offset[track] = bin_offset[track]; 687 688 int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; 689 if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) 690 { 691 // Short circuit if the histogram has only bin counts of only zeros or problem-size 692 short_circuit = short_circuit && ((bin_offset[track] == 0) || (bin_offset[track] == num_items)); 693 } 694 } 695 696 short_circuit = CTA_SYNC_AND(short_circuit); 697 } 698 699 700 /** 701 * Constructor 702 */ AgentRadixSortDownsweepcub::AgentRadixSortDownsweep703 __device__ __forceinline__ AgentRadixSortDownsweep( 704 TempStorage &temp_storage, 705 OffsetT num_items, 706 OffsetT *d_spine, 707 const KeyT *d_keys_in, 708 KeyT *d_keys_out, 709 const ValueT *d_values_in, 710 ValueT *d_values_out, 711 int current_bit, 712 int num_bits) 713 : 714 temp_storage(temp_storage.Alias()), 715 d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)), 716 d_values_in(d_values_in), 717 d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)), 718 d_values_out(d_values_out), 719 current_bit(current_bit), 720 num_bits(num_bits), 721 short_circuit(1) 722 { 723 #pragma unroll 724 for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track) 725 { 726 int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track; 727 728 // Load digit bin offsets (each of the first RADIX_DIGITS threads will load an offset for that digit) 729 if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS)) 730 { 731 if (IS_DESCENDING) 732 bin_idx = RADIX_DIGITS - bin_idx - 1; 733 734 // Short circuit if the first block's histogram has only bin counts of only zeros or problem-size 735 OffsetT first_block_bin_offset = d_spine[gridDim.x * bin_idx]; 736 short_circuit = short_circuit && ((first_block_bin_offset == 0) || (first_block_bin_offset == num_items)); 737 738 // Load my block's bin offset for my bin 739 bin_offset[track] = d_spine[(gridDim.x * bin_idx) + blockIdx.x]; 740 } 741 } 742 743 short_circuit = CTA_SYNC_AND(short_circuit); 744 } 745 746 747 /** 748 * Distribute keys from a segment of input tiles. 749 */ ProcessRegioncub::AgentRadixSortDownsweep750 __device__ __forceinline__ void ProcessRegion( 751 OffsetT block_offset, 752 OffsetT block_end) 753 { 754 if (short_circuit) 755 { 756 // Copy keys 757 Copy(d_keys_in, d_keys_out, block_offset, block_end); 758 759 // Copy values 760 Copy(d_values_in, d_values_out, block_offset, block_end); 761 } 762 else 763 { 764 // Process full tiles of tile_items 765 #pragma unroll 1 766 while (block_offset + TILE_ITEMS <= block_end) 767 { 768 ProcessTile<true>(block_offset); 769 block_offset += TILE_ITEMS; 770 771 CTA_SYNC(); 772 } 773 774 // Clean up last partial tile with guarded-I/O 775 if (block_offset < block_end) 776 { 777 ProcessTile<false>(block_offset, block_end - block_offset); 778 } 779 780 } 781 } 782 783 }; 784 785 786 787 } // CUB namespace 788 CUB_NS_POSTFIX // Optional outer namespace(s) 789 790