/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * AgentRadixSortUpsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort upsweep.
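 *
 * The upsweep pass computes per-block histograms of radix digit counts from a
 * segment of input tiles; a subsequent scan of these histograms provides the
 * global offsets consumed by the downsweep pass when scattering keys.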
 */

#pragma once

#include "../thread/thread_reduce.cuh"
#include "../thread/thread_load.cuh"
#include "../warp/warp_reduce.cuh"
#include "../block/block_load.cuh"
#include "../util_type.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {

/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentRadixSortUpsweep
 */
template <
    int                 _BLOCK_THREADS,     ///< Threads per thread block
    int                 _ITEMS_PER_THREAD,  ///< Items per thread (per tile of input)
    CacheLoadModifier   _LOAD_MODIFIER,     ///< Cache load modifier for reading keys
    int                 _RADIX_BITS>        ///< The number of radix bits, i.e., log2(bins)
struct AgentRadixSortUpsweepPolicy
{
    enum
    {
        BLOCK_THREADS       = _BLOCK_THREADS,       ///< Threads per thread block
        ITEMS_PER_THREAD    = _ITEMS_PER_THREAD,    ///< Items per thread (per tile of input)
        RADIX_BITS          = _RADIX_BITS,          ///< The number of radix bits, i.e., log2(bins)
    };

    static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER;  ///< Cache load modifier for reading keys
};


/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentRadixSortUpsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort upsweep.
 */
template <
    typename AgentRadixSortUpsweepPolicy,   ///< Parameterized AgentRadixSortUpsweepPolicy tuning policy type
    typename KeyT,                          ///< KeyT type
    typename OffsetT>                       ///< Signed integer type for global offsets
struct AgentRadixSortUpsweep
{

    //---------------------------------------------------------------------
    // Type definitions and constants
    //---------------------------------------------------------------------

    typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;

    // Integer type for digit counters (to be packed into words of PackedCounters)
    typedef unsigned char DigitCounter;

    // Integer type for packing DigitCounters into columns of shared memory banks
    typedef unsigned int PackedCounter;

    static const CacheLoadModifier LOAD_MODIFIER = AgentRadixSortUpsweepPolicy::LOAD_MODIFIER;

    enum
    {
        RADIX_BITS              = AgentRadixSortUpsweepPolicy::RADIX_BITS,
        BLOCK_THREADS           = AgentRadixSortUpsweepPolicy::BLOCK_THREADS,
        KEYS_PER_THREAD         = AgentRadixSortUpsweepPolicy::ITEMS_PER_THREAD,

        RADIX_DIGITS            = 1 << RADIX_BITS,

        LOG_WARP_THREADS        = CUB_PTX_LOG_WARP_THREADS,
        WARP_THREADS            = 1 << LOG_WARP_THREADS,
        WARPS                   = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        TILE_ITEMS              = BLOCK_THREADS * KEYS_PER_THREAD,

        BYTES_PER_COUNTER       = sizeof(DigitCounter),
        LOG_BYTES_PER_COUNTER   = Log2<BYTES_PER_COUNTER>::VALUE,

        PACKING_RATIO           = sizeof(PackedCounter) / sizeof(DigitCounter),
        LOG_PACKING_RATIO       = Log2<PACKING_RATIO>::VALUE,

        LOG_COUNTER_LANES       = CUB_MAX(0, RADIX_BITS - LOG_PACKING_RATIO),
        COUNTER_LANES           = 1 << LOG_COUNTER_LANES,

        // To prevent counter overflow, we must periodically unpack and aggregate the
        // digit counters back into registers.  Each counter lane is assigned to a
        // warp for aggregation.

        LANES_PER_WARP          = CUB_MAX(1, (COUNTER_LANES + WARPS - 1) / WARPS),

        // Unroll tiles in batches without risk of counter overflow
        UNROLL_COUNT            = CUB_MIN(64, 255 / KEYS_PER_THREAD),
        UNROLLED_ELEMENTS       = UNROLL_COUNT * TILE_ITEMS,
    };
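
    // For a hypothetical tuning of BLOCK_THREADS = 128, KEYS_PER_THREAD = 4,
    // and RADIX_BITS = 5 (example values, not a prescribed configuration):
    //
    //   RADIX_DIGITS      = 1 << 5 = 32
    //   PACKING_RATIO     = sizeof(unsigned int) / sizeof(unsigned char) = 4
    //   LOG_PACKING_RATIO = 2
    //   COUNTER_LANES     = 1 << (5 - 2) = 8 lanes of 4 packed counters each,
    //                       covering all 32 digit bins
    //   UNROLL_COUNT      = CUB_MIN(64, 255 / 4) = 63
    //
    // Each 8-bit DigitCounter can hold at most 255 hits, and a thread adds at
    // most KEYS_PER_THREAD hits to any one of its counters per tile, so a
    // batch of UNROLL_COUNT tiles adds at most 63 * 4 = 252 <= 255 hits before
    // the counters must be unpacked into registers (see UnpackDigitCounts below).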

    // Input iterator wrapper type (for applying cache modifiers)
    typedef CacheModifiedInputIterator<LOAD_MODIFIER, UnsignedBits, OffsetT> KeysItr;

    /**
     * Shared memory storage layout
     */
    union __align__(16) _TempStorage
    {
        DigitCounter    thread_counters[COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO];
        PackedCounter   packed_thread_counters[COUNTER_LANES][BLOCK_THREADS];
        OffsetT         block_counters[WARP_THREADS][RADIX_DIGITS];
    };


    /// Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Thread fields (aggregate state bundle)
    //---------------------------------------------------------------------

    // Shared storage for this CTA
    _TempStorage    &temp_storage;

    // Thread-local counters for periodically aggregating composite-counter lanes
    OffsetT         local_counts[LANES_PER_WARP][PACKING_RATIO];

    // Input device pointer
    KeysItr         d_keys_in;

    // The least-significant bit position of the current digit to extract
    int             current_bit;

    // Number of bits in current digit
    int             num_bits;


    //---------------------------------------------------------------------
    // Helper structure for templated iteration
    //---------------------------------------------------------------------

    // Iterate
    template <int COUNT, int MAX>
    struct Iterate
    {
        // BucketKeys
        static __device__ __forceinline__ void BucketKeys(
            AgentRadixSortUpsweep   &cta,
            UnsignedBits            keys[KEYS_PER_THREAD])
        {
            cta.Bucket(keys[COUNT]);

            // Next
            Iterate<COUNT + 1, MAX>::BucketKeys(cta, keys);
        }
    };

    // Terminate
    template <int MAX>
    struct Iterate<MAX, MAX>
    {
        // BucketKeys
        static __device__ __forceinline__ void BucketKeys(AgentRadixSortUpsweep &/*cta*/, UnsignedBits /*keys*/[KEYS_PER_THREAD]) {}
    };


    //---------------------------------------------------------------------
    // Utility methods
    //---------------------------------------------------------------------

    /**
     * Decode a key and increment corresponding smem digit counter
     */
    __device__ __forceinline__ void Bucket(UnsignedBits key)
    {
        // Perform transform op
        UnsignedBits converted_key = Traits<KeyT>::TwiddleIn(key);

        // Extract current digit bits
        UnsignedBits digit = BFE(converted_key, current_bit, num_bits);

        // Get sub-counter offset
        UnsignedBits sub_counter = digit & (PACKING_RATIO - 1);

        // Get row offset
        UnsignedBits row_offset = digit >> LOG_PACKING_RATIO;

        // Increment counter
        temp_storage.thread_counters[row_offset][threadIdx.x][sub_counter]++;
    }


    /**
     * Reset composite counters
     */
    __device__ __forceinline__ void ResetDigitCounters()
    {
        #pragma unroll
        for (int LANE = 0; LANE < COUNTER_LANES; LANE++)
        {
            temp_storage.packed_thread_counters[LANE][threadIdx.x] = 0;
        }
    }
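
    // Worked example of the packed-counter indexing in Bucket(), assuming
    // PACKING_RATIO = 4: digit 13 yields sub_counter = 13 & 3 = 1 and
    // row_offset = 13 >> 2 = 3, incrementing thread_counters[3][threadIdx.x][1].
    // Because thread_counters and packed_thread_counters alias the same union
    // storage, ResetDigitCounters() clears all PACKING_RATIO byte-sized
    // counters in a row with a single word-sized store.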

    /**
     * Reset the unpacked counters in each thread
     */
    __device__ __forceinline__ void ResetUnpackedCounters()
    {
        #pragma unroll
        for (int LANE = 0; LANE < LANES_PER_WARP; LANE++)
        {
            #pragma unroll
            for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++)
            {
                local_counts[LANE][UNPACKED_COUNTER] = 0;
            }
        }
    }


    /**
     * Extracts and aggregates the digit counters for each counter lane
     * owned by this warp
     */
    __device__ __forceinline__ void UnpackDigitCounts()
    {
        unsigned int warp_id = threadIdx.x >> LOG_WARP_THREADS;
        unsigned int warp_tid = LaneId();

        #pragma unroll
        for (int LANE = 0; LANE < LANES_PER_WARP; LANE++)
        {
            const int counter_lane = (LANE * WARPS) + warp_id;
            if (counter_lane < COUNTER_LANES)
            {
                #pragma unroll
                for (int PACKED_COUNTER = 0; PACKED_COUNTER < BLOCK_THREADS; PACKED_COUNTER += WARP_THREADS)
                {
                    #pragma unroll
                    for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++)
                    {
                        OffsetT counter = temp_storage.thread_counters[counter_lane][warp_tid + PACKED_COUNTER][UNPACKED_COUNTER];
                        local_counts[LANE][UNPACKED_COUNTER] += counter;
                    }
                }
            }
        }
    }


    /**
     * Processes a single, full tile
     */
    __device__ __forceinline__ void ProcessFullTile(OffsetT block_offset)
    {
        // Tile of keys
        UnsignedBits keys[KEYS_PER_THREAD];

        LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_keys_in + block_offset, keys);

        // Prevent hoisting
        CTA_SYNC();

        // Bucket tile of keys
        Iterate<0, KEYS_PER_THREAD>::BucketKeys(*this, keys);
    }


    /**
     * Processes a single load (may have some threads masked off)
     */
    __device__ __forceinline__ void ProcessPartialTile(
        OffsetT         block_offset,
        const OffsetT   &block_end)
    {
        // Process partial tile if necessary using single loads
        block_offset += threadIdx.x;
        while (block_offset < block_end)
        {
            // Load and bucket key
            UnsignedBits key = d_keys_in[block_offset];
            Bucket(key);
            block_offset += BLOCK_THREADS;
        }
    }


    //---------------------------------------------------------------------
    // Interface
    //---------------------------------------------------------------------

    /**
     * Constructor
     */
    __device__ __forceinline__ AgentRadixSortUpsweep(
        TempStorage &temp_storage,
        const KeyT  *d_keys_in,
        int         current_bit,
        int         num_bits)
    :
        temp_storage(temp_storage.Alias()),
        d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
        current_bit(current_bit),
        num_bits(num_bits)
    {}
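
    // A minimal usage sketch (hypothetical example kernel, not part of this
    // header): ExamplePolicy, ExampleUpsweepKernel, and d_counts are
    // illustrative names, and d_counts must provide RADIX_DIGITS slots.
    //
    //   template <typename ExamplePolicy, typename KeyT, typename OffsetT>
    //   __global__ void ExampleUpsweepKernel(
    //       const KeyT  *d_keys,
    //       OffsetT     *d_counts,
    //       int         current_bit,
    //       int         num_bits,
    //       OffsetT     block_offset,
    //       OffsetT     block_end)
    //   {
    //       typedef AgentRadixSortUpsweep<ExamplePolicy, KeyT, OffsetT> AgentT;
    //
    //       // Shared memory for the agent
    //       __shared__ typename AgentT::TempStorage temp_storage;
    //
    //       // Accumulate digit histograms over this block's segment of tiles
    //       AgentT agent(temp_storage, d_keys, current_bit, num_bits);
    //       agent.ProcessRegion(block_offset, block_end);
    //
    //       CTA_SYNC();
    //
    //       // Write per-digit totals in ascending digit order
    //       agent.template ExtractCounts<false>(d_counts);
    //   }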

    /**
     * Compute radix digit histograms from a segment of input tiles.
     */
    __device__ __forceinline__ void ProcessRegion(
        OffsetT          block_offset,
        const OffsetT    &block_end)
    {
        // Reset digit counters in smem and unpacked counters in registers
        ResetDigitCounters();
        ResetUnpackedCounters();

        // Unroll batches of full tiles
        while (block_offset + UNROLLED_ELEMENTS <= block_end)
        {
            for (int i = 0; i < UNROLL_COUNT; ++i)
            {
                ProcessFullTile(block_offset);
                block_offset += TILE_ITEMS;
            }

            CTA_SYNC();

            // Aggregate back into local_count registers to prevent overflow
            UnpackDigitCounts();

            CTA_SYNC();

            // Reset composite counters in lanes
            ResetDigitCounters();
        }

        // Unroll single full tiles
        while (block_offset + TILE_ITEMS <= block_end)
        {
            ProcessFullTile(block_offset);
            block_offset += TILE_ITEMS;
        }

        // Process partial tile if necessary
        ProcessPartialTile(
            block_offset,
            block_end);

        CTA_SYNC();

        // Aggregate back into local_count registers
        UnpackDigitCounts();
    }


    /**
     * Extract counts (saving them to the external array)
     */
    template <bool IS_DESCENDING>
    __device__ __forceinline__ void ExtractCounts(
        OffsetT *counters,
        int     bin_stride = 1,
        int     bin_offset = 0)
    {
        unsigned int warp_id = threadIdx.x >> LOG_WARP_THREADS;
        unsigned int warp_tid = LaneId();

        // Place unpacked digit counters in shared memory
        #pragma unroll
        for (int LANE = 0; LANE < LANES_PER_WARP; LANE++)
        {
            int counter_lane = (LANE * WARPS) + warp_id;
            if (counter_lane < COUNTER_LANES)
            {
                int digit_row = counter_lane << LOG_PACKING_RATIO;

                #pragma unroll
                for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++)
                {
                    int bin_idx = digit_row + UNPACKED_COUNTER;

                    temp_storage.block_counters[warp_tid][bin_idx] =
                        local_counts[LANE][UNPACKED_COUNTER];
                }
            }
        }

        CTA_SYNC();

        // Rake-reduce bin_count reductions

        // Whole blocks
        #pragma unroll
        for (int BIN_BASE   = RADIX_DIGITS % BLOCK_THREADS;
            (BIN_BASE + BLOCK_THREADS) <= RADIX_DIGITS;
             BIN_BASE += BLOCK_THREADS)
        {
            int bin_idx = BIN_BASE + threadIdx.x;

            OffsetT bin_count = 0;
            #pragma unroll
            for (int i = 0; i < WARP_THREADS; ++i)
                bin_count += temp_storage.block_counters[i][bin_idx];

            if (IS_DESCENDING)
                bin_idx = RADIX_DIGITS - bin_idx - 1;

            counters[(bin_stride * bin_idx) + bin_offset] = bin_count;
        }

        // Remainder
        if ((RADIX_DIGITS % BLOCK_THREADS != 0) && (threadIdx.x < RADIX_DIGITS))
        {
            int bin_idx = threadIdx.x;

            OffsetT bin_count = 0;
            #pragma unroll
            for (int i = 0; i < WARP_THREADS; ++i)
                bin_count += temp_storage.block_counters[i][bin_idx];

            if (IS_DESCENDING)
                bin_idx = RADIX_DIGITS - bin_idx - 1;

            counters[(bin_stride * bin_idx) + bin_offset] = bin_count;
        }
    }
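
    // Note on the raking reduction above: thread lane i of each warp deposits
    // its register-accumulated counts into row i of block_counters, with
    // different warps covering different digit columns.  The total for a bin
    // is therefore the sum of column bin_idx across all WARP_THREADS rows.
    // When IS_DESCENDING, bin b is written to slot RADIX_DIGITS - b - 1 (e.g.,
    // with RADIX_DIGITS = 32, bin 0 lands in slot 31), laying the spine out
    // for a descending sort.  The overload below instead returns each thread's
    // BINS_TRACKED_PER_THREAD bin totals in registers.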

    /**
     * Extract counts
     */
    template <int BINS_TRACKED_PER_THREAD>
    __device__ __forceinline__ void ExtractCounts(
        OffsetT (&bin_count)[BINS_TRACKED_PER_THREAD])  ///< [out] The digit count for each of the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        unsigned int warp_id = threadIdx.x >> LOG_WARP_THREADS;
        unsigned int warp_tid = LaneId();

        // Place unpacked digit counters in shared memory
        #pragma unroll
        for (int LANE = 0; LANE < LANES_PER_WARP; LANE++)
        {
            int counter_lane = (LANE * WARPS) + warp_id;
            if (counter_lane < COUNTER_LANES)
            {
                int digit_row = counter_lane << LOG_PACKING_RATIO;

                #pragma unroll
                for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++)
                {
                    int bin_idx = digit_row + UNPACKED_COUNTER;

                    temp_storage.block_counters[warp_tid][bin_idx] =
                        local_counts[LANE][UNPACKED_COUNTER];
                }
            }
        }

        CTA_SYNC();

        // Rake-reduce bin_count reductions
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                bin_count[track] = 0;

                #pragma unroll
                for (int i = 0; i < WARP_THREADS; ++i)
                    bin_count[track] += temp_storage.block_counters[i][bin_idx];
            }
        }
    }
};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)